diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..35569a24b3f51fde1512c567d18f5bf0d1f901dc 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/performance_radar.png filter=lfs diff=lfs merge=lfs -text
+assets/soul_wechat01.jpg filter=lfs diff=lfs merge=lfs -text
+assets/soulx-logo.png filter=lfs diff=lfs merge=lfs -text
+assets/technical-report.pdf filter=lfs diff=lfs merge=lfs -text
+example/audio/music.mp3 filter=lfs diff=lfs merge=lfs -text
+example/audio/yue_target.mp3 filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..5f37ac3b15376b1f82f2edb9e7d95e8bf9c7b5b6
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,38 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+
+dev/
+results/
+wandb/
+.ipynb_checkpoints/
+.vscode/
+.cache
+local/
+outputs/
+
+*.pt
+*.ckpt
+
+# Logs
+logs/
+*.log
+results/
+runs/
+dev*
+local/
+generated/
+
+.DS_Store
+pretrained_models/
+
+*.err
+*.out
+
+# Dev
+dev/
+
+# Data
+data/
+outputs/
+deploy/
+.gradio/
\ No newline at end of file
diff --git a/DEPLOY.md b/DEPLOY.md
new file mode 100644
index 0000000000000000000000000000000000000000..de49e6374a79ea7edc2189f8b208f1c8d9abf234
--- /dev/null
+++ b/DEPLOY.md
@@ -0,0 +1,201 @@
+# 🚀 部署到 Hugging Face Space 指南
+
+本指南将帮助您将 SoulX-Singer 部署到 Hugging Face Space。
+
+## 📋 前置要求
+
+1. **Hugging Face 账号**:如果没有,请先注册 [huggingface.co](https://huggingface.co/join)
+2. **Git**:确保已安装 Git
+3. **Hugging Face CLI**(可选但推荐):`pip install huggingface_hub`
+
+## 🎯 部署步骤
+
+### 方法一:通过 Web 界面创建(推荐)
+
+#### 步骤 1:准备代码仓库
+
+确保您的代码已准备好:
+- ✅ `app.py` - Space 入口文件
+- ✅ `webui.py` - Gradio 界面代码
+- ✅ `requirements.txt` - Python 依赖
+- ✅ `README.md` - 包含 Space 配置的 YAML 头部
+
+#### 步骤 2:创建 Space
+
+1. 访问 [huggingface.co/spaces](https://huggingface.co/spaces)
+2. 点击 **"Create new Space"** 按钮
+3. 填写 Space 信息:
+ - **Space name**: 例如 `SoulX-Singer` 或 `soulx-singer-demo`
+ - **SDK**: 选择 **Gradio**
+ - **Hardware**: 推荐选择 **GPU T4 small**(推理更快,首次下载模型后缓存)
+ - **Visibility**: 选择 Public(公开)或 Private(私有)
+4. 点击 **"Create Space"**
+
+#### 步骤 3:上传代码
+
+**选项 A:使用 Git 推送(推荐)**
+
+```bash
+# 1. 在本地代码目录初始化 Git(如果还没有)
+git init
+git add .
+git commit -m "Initial commit for HF Space"
+
+# 2. 添加 Hugging Face 远程仓库
+# 替换 YOUR_USERNAME 和 YOUR_SPACE_NAME
+git remote add origin https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
+
+# 3. 推送代码
+git push -u origin main
+```
+
+**选项 B:使用 Web 界面上传**
+
+1. 在 Space 页面点击 **"Files and versions"** 标签
+2. 点击 **"Add file"** → **"Upload files"**
+3. 拖拽或选择以下必需文件:
+ - `app.py`
+ - `webui.py`
+ - `requirements.txt`
+ - `README.md`
+ - `soulxsinger/` 目录(整个文件夹)
+ - `preprocess/` 目录(整个文件夹)
+ - `cli/` 目录(整个文件夹)
+ - `example/` 目录(整个文件夹)
+ - `assets/` 目录(整个文件夹)
+ - 其他配置文件(如 `LICENSE`, `.gitignore` 等)
+
+#### 步骤 4:等待构建和首次运行
+
+1. Space 会自动检测到代码并开始构建
+2. 查看 **"Logs"** 标签页监控构建进度
+3. 首次运行会:
+ - 安装 `requirements.txt` 中的依赖
+ - 执行 `app.py`
+ - **自动下载** `Soul-AILab/SoulX-Singer` 和 `Soul-AILab/SoulX-Singer-Preprocess` 模型(可能需要 5-15 分钟,取决于网络速度)
+4. 构建完成后,Space 会自动启动,您可以在 **"App"** 标签页看到界面
+
+### 方法二:使用 Hugging Face CLI
+
+```bash
+# 1. 安装 Hugging Face Hub CLI
+pip install huggingface_hub
+
+# 2. 登录(会打开浏览器)
+huggingface-cli login
+
+# 3. 创建 Space(替换 YOUR_USERNAME 和 YOUR_SPACE_NAME)
+huggingface-cli repo create YOUR_SPACE_NAME --type space --sdk gradio
+
+# 4. 克隆 Space 仓库
+git clone https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
+cd YOUR_SPACE_NAME
+
+# 5. 复制代码文件到 Space 目录
+# (将当前代码目录的所有文件复制过来)
+
+# 6. 提交并推送
+git add .
+git commit -m "Deploy SoulX-Singer to HF Space"
+git push
+```
+
+## ⚙️ Space 配置说明
+
+Space 配置在 `README.md` 的 YAML 头部:
+
+```yaml
+---
+title: SoulX-Singer
+emoji: 🎤
+sdk: gradio
+sdk_version: "6.3.0"
+app_file: app.py
+python_version: "3.10"
+suggested_hardware: t4-small # 取消注释以启用 GPU
+---
+```
+
+### 硬件选择建议
+
+- **CPU Basic**: 免费,但推理速度较慢,适合测试
+- **GPU T4 Small**: 推荐,推理速度快,首次下载模型后缓存
+- **GPU T4 Medium/Large**: 适合高并发或更复杂的推理
+
+### 修改硬件配置
+
+1. 进入 Space 页面
+2. 点击 **"Settings"** 标签
+3. 在 **"Hardware"** 部分选择所需硬件
+4. 保存后 Space 会重启
+
+## 🔍 故障排查
+
+### 问题 1:构建失败
+
+**检查点:**
+- ✅ `requirements.txt` 中所有依赖版本是否兼容
+- ✅ `app.py` 文件是否存在且可执行
+- ✅ `README.md` 的 YAML 配置是否正确
+
+**查看日志:**
+- 在 Space 页面的 **"Logs"** 标签查看详细错误信息
+
+### 问题 2:模型下载失败
+
+**可能原因:**
+- 网络连接问题
+- Hugging Face Hub 认证问题
+
+**解决方案:**
+- 确保 Space 有网络访问权限(默认有)
+- 如果使用私有模型,需要在 Space Settings 中添加 HF Token
+
+### 问题 3:应用启动后无法访问
+
+**检查点:**
+- ✅ `app.py` 中 `server_name="0.0.0.0"` 已设置
+- ✅ 端口使用环境变量 `PORT`(Space 会自动注入)
+- ✅ 查看 **"Logs"** 确认应用是否成功启动
+
+### 问题 4:内存不足
+
+**解决方案:**
+- 升级到更大的硬件(T4 Medium/Large)
+- 或优化代码,减少内存占用
+
+## 📝 重要提示
+
+1. **首次运行时间**:首次部署时,模型下载可能需要 5-15 分钟,请耐心等待
+2. **模型缓存**:下载的模型会缓存在 Space 的存储中,重启后无需重新下载
+3. **存储限制**:免费 Space 有存储限制,确保模型文件不会超过限制
+4. **自动重启**:Space 会在代码更新后自动重启
+5. **日志查看**:遇到问题时,首先查看 **"Logs"** 标签页的详细日志
+
+## 🔗 相关链接
+
+- [Hugging Face Spaces 文档](https://huggingface.co/docs/hub/spaces)
+- [Gradio 文档](https://gradio.app/docs/)
+- [SoulX-Singer 模型页面](https://huggingface.co/Soul-AILab/SoulX-Singer)
+- [SoulX-Singer-Preprocess 模型页面](https://huggingface.co/Soul-AILab/SoulX-Singer-Preprocess)
+
+## ✅ 部署检查清单
+
+部署前确认:
+- [ ] `app.py` 文件存在且正确
+- [ ] `requirements.txt` 包含所有依赖(包括 `huggingface_hub`)
+- [ ] `README.md` 包含正确的 YAML 配置
+- [ ] 所有必需的代码文件都已上传
+- [ ] `.gitignore` 正确配置(排除 `pretrained_models/` 和 `outputs/`)
+- [ ] Space 硬件配置合适(推荐 GPU T4 Small)
+
+部署后验证:
+- [ ] Space 构建成功(无错误日志)
+- [ ] 模型自动下载完成
+- [ ] Web 界面可以正常访问
+- [ ] 可以上传音频文件进行测试
+- [ ] 推理功能正常工作
+
+---
+
+**祝部署顺利!** 🎉
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index e3e6815bdea6b76c83d7191cc7df21e8712fd7be..e53c20d67757a85224198ce14a7f4d6451f480d6 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,231 @@
---
-title: SoulX Singer
-emoji: 👁
-colorFrom: purple
-colorTo: yellow
+title: SoulX-Singer
+emoji: 🎤
sdk: gradio
-sdk_version: 6.5.1
+sdk_version: "6.3.0"
app_file: app.py
-pinned: false
-license: apache-2.0
-short_description: Zero-shot Singing Voice Synthesis
+python_version: "3.10"
+# GPU recommended for inference speed (optional: use CPU for light usage)
+# suggested_hardware: t4-small
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
🎤 SoulX-Singer
+
+ Official inference code for
+ SoulX-Singer: Towards High-Quality Zero-Shot Singing Voice Synthesis
+
+
+
+
+
+
+
+
+
+
+
+
+---
+
+## 🎵 Overview
+
+**SoulX-Singer** is a high-fidelity, zero-shot singing voice synthesis model that enables users to generate realistic singing voices for unseen singers.
+It supports **melody-conditioned (F0 contour)** and **score-conditioned (MIDI notes)** control for precise pitch, rhythm, and expression.
+
+---
+
+## ✨ Key Features
+
+- **🎤 Zero-Shot Singing** – Generate high-fidelity voices for unseen singers, no fine-tuning needed.
+- **🎵 Flexible Control Modes** – Melody (F0) and Score (MIDI) conditioning.
+- **📚 Large-Scale Dataset** – 42,000+ hours of aligned vocals, lyrics, notes across Mandarin, English, Cantonese.
+- **🧑🎤 Timbre Cloning** – Preserve singer identity across languages, styles, and edited lyrics.
+- **✏️ Singing Voice Editing** – Modify lyrics while keeping natural prosody.
+- **🌐 Cross-Lingual Synthesis** – High-fidelity synthesis by disentangling timbre from content.
+
+---
+
+
+
+
+
+---
+
+## 🎬 Demo Examples
+
+
+
+
+
+
+
+
+
+
+
+
+
+---
+
+## 📰 News
+
+- **[2026-02-06]** SoulX-Singer inference code and models released.
+
+---
+
+## 🚀 Quick Start
+
+**Note:** This repo does not ship pretrained weights. SVS and preprocessing models must be downloaded from Hugging Face (see step 3).
+
+### 1. Clone Repository
+
+```bash
+git clone https://github.com/Soul-AILab/SoulX-Singer.git
+cd SoulX-Singer
+```
+
+### 2. Set Up Environment
+
+**1. Install Conda** (if not already installed): https://docs.conda.io/en/latest/miniconda.html
+
+**2. Create and activate a Conda environment:**
+```
+conda create -n soulxsinger -y python=3.10
+conda activate soulxsinger
+```
+**3. Install dependencies:**
+```
+pip install -r requirements.txt
+```
+⚠️ If you are in mainland China, use a PyPI mirror:
+```
+pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
+```
+
+
+---
+
+### 3. Download Pretrained Models
+
+**This repository does not include pretrained models.** You must download them from Hugging Face:
+
+- [Soul-AILab/SoulX-Singer](https://huggingface.co/Soul-AILab/SoulX-Singer) (SVS model)
+- [Soul-AILab/SoulX-Singer-Preprocess](https://huggingface.co/Soul-AILab/SoulX-Singer-Preprocess) (preprocessing models)
+
+Install Hugging Face Hub and download:
+
+```sh
+pip install -U huggingface_hub
+
+# SoulX-Singer SVS model
+huggingface-cli download Soul-AILab/SoulX-Singer --local-dir pretrained_models/SoulX-Singer
+
+# Preprocessing models (vocal separation, F0, ASR, etc.)
+huggingface-cli download Soul-AILab/SoulX-Singer-Preprocess --local-dir pretrained_models/SoulX-Singer-Preprocess
+```
+
+
+### 4. Run the Demo
+
+Run the inference demo:
+``` sh
+bash example/infer.sh
+```
+
+This script relies on metadata generated from the preprocessing pipeline, including vocal separation and transcription. Users should follow the steps in [preprocess](preprocess/README.md) to prepare the necessary metadata before running the demo with their own data.
+
+**⚠️ Important Note**
+The metadata produced by the automatic preprocessing pipeline may not perfectly align the singing audio with the corresponding lyrics and musical notes. For best synthesis quality, we strongly recommend manually correcting the alignment using the 🎼 [Midi-Editor](https://huggingface.co/spaces/Soul-AILab/SoulX-Singer-Midi-Editor).
+
+How to use the Midi-Editor:
+- [Eiditing Metadata with Midi-Editor](preprocess/README.md#L104-L105)
+
+
+### 🌐 WebUI
+
+You can launch the interactive interface with:
+```
+python webui.py
+```
+
+### 🚀 Deploy as Hugging Face Space
+
+This repo is ready to deploy as a [Hugging Face Space](https://huggingface.co/spaces). **Pretrained models are not included;** `app.py` downloads them from the Hub on first run.
+
+**📖 详细部署指南请查看:[DEPLOY.md](DEPLOY.md)**
+
+**快速步骤:**
+
+1. **创建 Space**:访问 [huggingface.co/spaces](https://huggingface.co/spaces),点击 "Create new Space",选择 **Gradio** SDK
+2. **上传代码**:使用 Git 推送或 Web 界面上传代码文件
+3. **配置硬件**:在 Space Settings 中选择 **GPU T4 Small**(推荐)以加快推理速度
+4. **等待启动**:Space 会自动安装依赖、下载模型并启动应用(首次运行可能需要 5-15 分钟)
+
+模型会自动从以下仓库下载:
+- [Soul-AILab/SoulX-Singer](https://huggingface.co/Soul-AILab/SoulX-Singer) (SVS model)
+- [Soul-AILab/SoulX-Singer-Preprocess](https://huggingface.co/Soul-AILab/SoulX-Singer-Preprocess) (preprocessing models)
+
+
+
+## 🚧 Roadmap
+
+- [ ] 🖥️ Web-based UI for easy and interactive inference
+- [ ] 🌐 Online demo deployment on Hugging Face Spaces
+- [ ] 📊 Release the SoulX-Singer-Eval benchmark
+- [ ] 📚 Comprehensive tutorials and usage documentation
+
+
+## 🙏 Acknowledgements
+
+Special thanks to the following open-source projects:
+
+- [F5-TTS](https://github.com/SWivid/F5-TTS)
+- [Amphion](https://github.com/open-mmlab/Amphion/tree/main)
+- [Music Source Separation Training](https://github.com/ZFTurbo/Music-Source-Separation-Training)
+- [Lead Vocal Separation](https://huggingface.co/becruily/mel-band-roformer-karaoke)
+- [Vocal Dereverberation](https://huggingface.co/anvuew/dereverb_mel_band_roformer)
+- [RMVPE](https://github.com/Dream-High/RMVPE)
+[Paraformer](https://modelscope.cn/models/iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch)
+- [Parakeet-tdt-0.6b-v2](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2)
+- [ROSVOT](https://github.com/RickyL-2000/ROSVOT)
+
+
+
+## 📄 License
+
+We use the Apache 2.0 license. Researchers and developers are free to use the codes and model weights of our SoulX-Singer. Check the license at [LICENSE](LICENSE) for more details.
+
+
+## ⚠️ Usage Disclaimer
+
+SoulX-Singer is intended for academic research, educational purposes, and legitimate applications such as personalized singing synthesis and assistive technologies.
+
+Please note:
+
+- 🎤 Respect intellectual property, privacy, and personal consent when generating singing content.
+- 🚫 Do not use the model to impersonate individuals without authorization or to create deceptive audio.
+- ⚠️ The developers assume no liability for any misuse of this model.
+
+We advocate for the responsible development and use of AI and encourage the community to uphold safety and ethical principles. For ethics or misuse concerns, please contact us.
+
+
+## 📬 Contact Us
+
+We welcome your feedback, questions, and collaboration:
+
+- **Email**: qianjiale@soulapp.cn | menghao@soulapp.cn | wangxinsheng@soulapp.cn
+
+- **Join discussions**: WeChat or Soul APP groups for technical discussions and updates:
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..c678df5dd66dc1ad76e511d21c38b7bc941d8502
--- /dev/null
+++ b/app.py
@@ -0,0 +1,63 @@
+"""
+Hugging Face Space entry point for SoulX-Singer.
+Downloads pretrained models from the Hub if needed, then launches the Gradio app.
+"""
+import os
+import sys
+from pathlib import Path
+
+ROOT = Path(__file__).resolve().parent
+PRETRAINED_DIR = ROOT / "pretrained_models"
+MODEL_DIR_SVS = PRETRAINED_DIR / "SoulX-Singer"
+MODEL_DIR_PREPROCESS = PRETRAINED_DIR / "SoulX-Singer-Preprocess"
+
+
+def ensure_pretrained_models():
+ """Download SoulX-Singer and Preprocess models from Hugging Face Hub if not present."""
+ if (MODEL_DIR_SVS / "model.pt").exists() and MODEL_DIR_PREPROCESS.exists():
+ print("Pretrained models already present, skipping download.", flush=True)
+ return
+
+ try:
+ from huggingface_hub import snapshot_download
+ except ImportError:
+ print(
+ "huggingface_hub not installed. Install with: pip install huggingface_hub",
+ file=sys.stderr,
+ flush=True,
+ )
+ raise
+
+ PRETRAINED_DIR.mkdir(parents=True, exist_ok=True)
+
+ if not (MODEL_DIR_SVS / "model.pt").exists():
+ print("Downloading SoulX-Singer model...", flush=True)
+ snapshot_download(
+ repo_id="Soul-AILab/SoulX-Singer",
+ local_dir=str(MODEL_DIR_SVS),
+ local_dir_use_symlinks=False,
+ )
+ print("SoulX-Singer model ready.", flush=True)
+
+ if not MODEL_DIR_PREPROCESS.exists():
+ print("Downloading SoulX-Singer-Preprocess models...", flush=True)
+ snapshot_download(
+ repo_id="Soul-AILab/SoulX-Singer-Preprocess",
+ local_dir=str(MODEL_DIR_PREPROCESS),
+ local_dir_use_symlinks=False,
+ )
+ print("SoulX-Singer-Preprocess models ready.", flush=True)
+
+
+if __name__ == "__main__":
+ os.chdir(ROOT)
+ ensure_pretrained_models()
+
+ from webui import render_interface
+
+ page = render_interface()
+ page.queue()
+ page.launch(
+ server_name="0.0.0.0",
+ server_port=int(os.environ.get("PORT", "7860")),
+ )
diff --git a/assets/performance_radar.png b/assets/performance_radar.png
new file mode 100644
index 0000000000000000000000000000000000000000..8a7666e0a24734331da1d8cd607234f4518d8bcf
--- /dev/null
+++ b/assets/performance_radar.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a5fe64523e65072d7c8014e4584b9f20b5e4f43bbd54edee9f2a068ef174162
+size 137183
diff --git a/assets/soul_wechat01.jpg b/assets/soul_wechat01.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..13e2085bff3906e6e0efdd365254d075709dc1fe
--- /dev/null
+++ b/assets/soul_wechat01.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b452c23c33f4d0771f922aed4ceb92c0d6e893e74061f78b69a222f94bbd3c4a
+size 834816
diff --git a/assets/soulx-logo.png b/assets/soulx-logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..c3e9b86c6f703bb1aca005d86f3c938ab0b83800
--- /dev/null
+++ b/assets/soulx-logo.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4fe6c191a71be0323d52b236d8ed57f346821ee66c4a9bd8b6232cbca9bf3daf
+size 636241
diff --git a/assets/technical-report.pdf b/assets/technical-report.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..cb1c3dcf0f58b8d8c01ac6203a9079c3520fe251
--- /dev/null
+++ b/assets/technical-report.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ab2876f8850ce09e2b8ce7e929f8b9adf7de10f13900cb013f548f9707b80061
+size 7927691
diff --git a/cli/inference.py b/cli/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..d91926be197e8e10f6b4393e4e0c4bc43bd2ca9a
--- /dev/null
+++ b/cli/inference.py
@@ -0,0 +1,147 @@
+import os
+import torch
+import json
+import argparse
+from tqdm import tqdm
+import numpy as np
+import soundfile as sf
+from collections import OrderedDict
+from omegaconf import DictConfig
+
+from soulxsinger.utils.file_utils import load_config
+from soulxsinger.models.soulxsinger import SoulXSinger
+from soulxsinger.utils.data_processor import DataProcessor
+
+
+def build_model(
+ model_path: str,
+ config: DictConfig,
+ device: str = "cuda",
+):
+ """
+ Build the model from the pre-trained model path and model configuration.
+
+ Args:
+ model_path (str): Path to the checkpoint file.
+ config (DictConfig): Model configuration.
+ device (str, optional): Device to use. Defaults to "cuda".
+
+ Returns:
+ Tuple[torch.nn.Module, torch.nn.Module]: The initialized model and vocoder.
+ """
+
+ if not os.path.isfile(model_path):
+ raise FileNotFoundError(
+ f"Model checkpoint not found: {model_path}. "
+ "Please download the pretrained model and place it at the path, or set --model_path."
+ )
+ model = SoulXSinger(config).to(device)
+ print("Model initialized.")
+ print("Model parameters:", sum(p.numel() for p in model.parameters()) / 1e6, "M")
+
+ checkpoint = torch.load(model_path, weights_only=False, map_location=device)
+ if "state_dict" not in checkpoint:
+ raise KeyError(
+ f"Checkpoint at {model_path} has no 'state_dict' key. "
+ "Expected a checkpoint saved with model.state_dict()."
+ )
+ model.load_state_dict(checkpoint["state_dict"], strict=True)
+
+ model.eval()
+ model.to(device)
+ print("Model checkpoint loaded.")
+
+ return model
+
+
+def process(args, config, model: torch.nn.Module):
+ """Run the full inference pipeline given a data_processor and model.
+ """
+ if args.control not in ("melody", "score"):
+ raise ValueError(f"control must be 'melody' or 'score', got: {args.control}")
+
+ print(f"prompt_metadata_path: {args.prompt_metadata_path}")
+ print(f"target_metadata_path: {args.target_metadata_path}")
+
+ os.makedirs(args.save_dir, exist_ok=True)
+ data_processor = DataProcessor(
+ hop_size=config.audio.hop_size,
+ sample_rate=config.audio.sample_rate,
+ phoneset_path=args.phoneset_path,
+ device=args.device,
+ )
+
+ with open(args.prompt_metadata_path, "r", encoding="utf-8") as f:
+ prompt_meta_list = json.load(f)
+ if not prompt_meta_list:
+ raise ValueError("Prompt metadata is empty. Please run preprocess on prompt audio first.")
+ prompt_meta = prompt_meta_list[0] # load the first segment as the prompt
+ with open(args.target_metadata_path, "r", encoding="utf-8") as f:
+ target_meta_list = json.load(f)
+ infer_prompt_data = data_processor.process(prompt_meta, args.prompt_wav_path)
+
+ assert len(target_meta_list) > 0, "No target segments found in the target metadata."
+ generated_len = int(target_meta_list[-1]["time"][1] / 1000 * config.audio.sample_rate)
+ generated_merged = np.zeros(generated_len, dtype=np.float32)
+
+ for idx, target_meta in enumerate(
+ tqdm(target_meta_list, total=len(target_meta_list), desc="Inferring segments"),
+ ):
+ start_sample_idx = int(target_meta["time"][0] / 1000 * config.audio.sample_rate)
+ end_sample_idx = int(target_meta["time"][1] / 1000 * config.audio.sample_rate)
+ infer_target_data = data_processor.process(target_meta, None)
+
+ infer_data = {
+ "prompt": infer_prompt_data,
+ "target": infer_target_data,
+ }
+
+ with torch.no_grad():
+ generated_audio = model.infer(
+ infer_data,
+ auto_shift=args.auto_shift,
+ pitch_shift=args.pitch_shift,
+ n_steps=config.infer.n_steps,
+ cfg=config.infer.cfg,
+ control=args.control,
+ )
+
+ generated_audio = generated_audio.squeeze().cpu().numpy()
+ generated_merged[start_sample_idx : start_sample_idx + generated_audio.shape[0]] = generated_audio
+
+ merged_path = os.path.join(args.save_dir, "generated.wav")
+ sf.write(merged_path, generated_merged, 24000)
+ print(f"Generated audio saved to {merged_path}")
+
+
+def main(args, config):
+ model = build_model(
+ model_path=args.model_path,
+ config=config,
+ device=args.device,
+ )
+ process(args, config, model)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--model_path", type=str, default='pretrained_models/soulx-singer/model.pt')
+ parser.add_argument("--config", type=str, default='soulxsinger/config/soulxsinger.yaml')
+ parser.add_argument("--prompt_wav_path", type=str, default='example/audio/zh_prompt.wav')
+ parser.add_argument("--prompt_metadata_path", type=str, default='example/metadata/zh_prompt.json')
+ parser.add_argument("--target_metadata_path", type=str, default='example/metadata/zh_target.json')
+ parser.add_argument("--phoneset_path", type=str, default='soulxsinger/utils/phoneme/phone_set.json')
+ parser.add_argument("--save_dir", type=str, default='outputs')
+ parser.add_argument("--auto_shift", action="store_true")
+ parser.add_argument("--pitch_shift", type=int, default=0)
+ parser.add_argument(
+ "--control",
+ type=str,
+ default="melody",
+ choices=["melody", "score"],
+ help="Control mode: melody or score only",
+ )
+ args = parser.parse_args()
+
+ config = load_config(args.config)
+ main(args, config)
diff --git a/deploy_to_hf.sh b/deploy_to_hf.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6175737832d5003f8f1651eced0765f56406b3ff
--- /dev/null
+++ b/deploy_to_hf.sh
@@ -0,0 +1,70 @@
+#!/bin/bash
+# 快速部署脚本:将 SoulX-Singer 部署到 Hugging Face Space
+# 使用方法: ./deploy_to_hf.sh YOUR_USERNAME YOUR_SPACE_NAME
+
+set -e
+
+if [ $# -lt 2 ]; then
+ echo "用法: $0 "
+ echo "示例: $0 myusername soulx-singer-demo"
+ exit 1
+fi
+
+USERNAME=$1
+SPACE_NAME=$2
+SPACE_REPO="https://huggingface.co/spaces/${USERNAME}/${SPACE_NAME}"
+
+echo "🚀 开始部署到 Hugging Face Space..."
+echo "Space: ${USERNAME}/${SPACE_NAME}"
+echo ""
+
+# 检查是否已安装 huggingface_hub
+if ! command -v huggingface-cli &> /dev/null; then
+ echo "⚠️ 未检测到 huggingface-cli,正在安装..."
+ pip install -U huggingface_hub
+fi
+
+# 检查是否已登录
+if ! huggingface-cli whoami &> /dev/null; then
+ echo "🔐 请先登录 Hugging Face..."
+ huggingface-cli login
+fi
+
+# 创建 Space(如果不存在)
+echo "📦 检查 Space 是否存在..."
+if ! huggingface-cli repo info "${USERNAME}/${SPACE_NAME}" --repo-type space &> /dev/null; then
+ echo "✨ 创建新的 Space..."
+ huggingface-cli repo create "${SPACE_NAME}" --type space --sdk gradio
+else
+ echo "✅ Space 已存在"
+fi
+
+# 检查是否已初始化 Git
+if [ ! -d ".git" ]; then
+ echo "📝 初始化 Git 仓库..."
+ git init
+ git add .
+ git commit -m "Initial commit for HF Space deployment" || echo "⚠️ 没有新文件需要提交"
+fi
+
+# 检查远程仓库
+if git remote | grep -q "^origin$"; then
+ echo "🔄 更新远程仓库地址..."
+ git remote set-url origin "${SPACE_REPO}"
+else
+ echo "➕ 添加远程仓库..."
+ git remote add origin "${SPACE_REPO}"
+fi
+
+# 推送代码
+echo "📤 推送代码到 Hugging Face..."
+git push -u origin main || git push -u origin master
+
+echo ""
+echo "✅ 部署完成!"
+echo "🌐 Space 地址: ${SPACE_REPO}"
+echo ""
+echo "💡 提示:"
+echo " - Space 会自动开始构建,请查看 Logs 标签页"
+echo " - 首次运行会下载模型,可能需要 5-15 分钟"
+echo " - 建议在 Space Settings 中选择 GPU T4 Small 硬件"
diff --git a/example/audio/en_prompt.json b/example/audio/en_prompt.json
new file mode 100644
index 0000000000000000000000000000000000000000..e43f9fa60b7306d9ddee413a795dcfc5760469e2
--- /dev/null
+++ b/example/audio/en_prompt.json
@@ -0,0 +1,16 @@
+[
+ {
+ "index": "vocal_5220_10280",
+ "language": "English",
+ "time": [
+ 5220,
+ 10280
+ ],
+ "duration": "0.24 0.36 0.30 0.78 0.24 0.56 0.19 0.53 0.36 0.20 0.32 0.57 0.19 0.22",
+ "text": " Ooh Ooh I wish nothing nothing more more the best best ",
+ "phoneme": " en_UW1 en_UW1 en_AY1 en_W-IH1-SH en_N-AH1-TH-IH0-NG en_N-AH1-TH-IH0-NG en_M-AO1-R en_M-AO1-R en_DH-AH0 en_B-EH1-S-T en_B-EH1-S-T ",
+ "note_pitch": "0 63 65 0 65 67 68 62 62 64 67 67 65 0",
+ "note_type": "1 2 3 1 2 2 2 3 2 3 2 2 3 1",
+ "f0": "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 345.2 343.1 341.6 339.8 337.8 331.9 319.5 312.1 310.8 312.6 315.1 316.1 315.3 314.6 315.3 317.9 322.0 329.6 337.5 344.7 347.5 347.2 344.3 339.5 338.2 341.7 342.8 342.2 340.7 343.0 342.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 347.0 345.3 348.7 350.2 350.9 350.3 344.7 340.3 338.3 338.0 342.8 347.4 348.3 346.7 343.4 339.5 340.5 345.2 350.4 357.7 367.3 376.6 385.9 392.6 393.6 389.9 384.7 381.8 382.0 383.0 380.6 373.5 367.9 377.0 385.4 391.4 393.8 395.6 396.1 397.2 399.8 406.0 413.5 416.1 416.0 414.4 413.5 412.9 415.5 418.9 417.5 408.8 389.2 373.9 0.0 0.0 0.0 288.5 286.0 284.2 285.6 288.9 291.3 293.5 294.5 295.2 297.8 299.5 301.0 303.0 305.9 306.8 306.0 304.4 301.8 301.0 300.8 301.8 310.2 309.8 308.2 305.9 303.6 301.5 299.3 298.5 300.0 302.1 303.5 303.6 302.2 299.7 297.5 296.3 296.4 296.8 298.6 302.6 311.8 322.0 333.8 349.0 368.8 393.3 407.1 410.7 407.0 402.3 401.2 401.7 403.9 405.7 403.5 396.8 387.4 378.6 377.8 381.4 384.0 384.7 383.5 382.5 380.8 377.3 378.4 383.5 390.0 392.7 390.5 387.6 385.3 382.7 381.0 382.8 383.9 382.2 379.6 379.3 380.2 383.1 386.0 386.5 385.4 384.3 383.7 384.4 386.2 388.2 388.5 385.0 378.6 360.4 333.7 328.2 332.4 340.2 348.9 339.6 334.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0"
+ }
+]
\ No newline at end of file
diff --git a/example/audio/en_prompt.mp3 b/example/audio/en_prompt.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..644a4ddbffda903003c66e939a9608892a269688
Binary files /dev/null and b/example/audio/en_prompt.mp3 differ
diff --git a/example/audio/en_target.json b/example/audio/en_target.json
new file mode 100644
index 0000000000000000000000000000000000000000..13d075226c25096a28d4381d76b717d14c03324f
--- /dev/null
+++ b/example/audio/en_target.json
@@ -0,0 +1,16 @@
+[
+ {
+ "index": "vocal_0_6900",
+ "language": "English",
+ "time": [
+ 0,
+ 6900
+ ],
+ "duration": "0.16 0.24 0.32 0.15 0.17 0.24 0.15 0.44 0.29 0.32 0.24 0.32 0.22 0.18 0.24 0.25 1.01 0.26 0.48 0.29 0.79 0.14",
+ "text": " Who says you're you're not pretty pretty Who says you're you're not beautiful beautiful Who says says ",
+ "phoneme": " en_HH-UW1 en_S-EH1-Z en_Y-UH1-R en_Y-UH1-R en_N-AA1-T en_P-R-IH1-T-IY0 en_P-R-IH1-T-IY0 en_HH-UW1 en_S-EH1-Z en_Y-UH1-R en_Y-UH1-R en_N-AA1-T en_B-Y-UW1-T-AH0-F-AH0-L en_B-Y-UW1-T-AH0-F-AH0-L en_HH-UW1 en_S-EH1-Z en_S-EH1-Z ",
+ "note_pitch": "0 68 67 65 63 63 66 67 70 66 68 67 65 63 63 67 65 63 65 61 58 0",
+ "note_type": "1 2 2 2 3 2 2 1 3 1 2 2 2 3 2 2 3 1 2 2 3 1",
+ "f0": "0.0 0.0 382.7 387.7 385.9 379.8 376.0 380.9 390.1 403.2 415.3 423.6 421.6 402.6 385.2 381.1 0.0 0.0 425.8 419.0 409.6 397.8 392.2 389.0 388.5 391.4 389.1 381.4 375.9 0.0 0.0 0.0 0.0 359.0 354.7 353.8 353.7 354.7 353.1 351.1 350.4 349.0 348.9 346.3 337.4 328.0 312.8 303.1 298.4 296.0 298.9 302.0 306.3 307.9 307.3 307.5 307.3 302.9 301.8 0.0 0.0 0.0 0.0 0.0 343.7 364.3 375.9 368.5 358.1 359.1 365.9 378.4 393.1 406.0 412.5 410.9 407.0 404.1 403.5 403.4 401.5 399.4 397.7 395.4 394.4 394.8 395.5 396.5 397.5 400.8 407.9 415.1 417.8 453.1 472.2 481.0 482.3 481.9 480.8 478.7 477.4 476.8 474.8 467.5 446.0 390.4 382.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 374.3 375.5 370.9 370.7 373.1 378.3 392.2 407.6 418.5 423.9 423.2 415.9 395.4 0.0 0.0 0.0 421.5 416.2 405.3 391.0 383.1 380.8 383.0 388.3 388.8 378.3 371.7 0.0 0.0 0.0 371.4 365.1 362.7 358.5 353.0 352.0 353.5 356.1 356.4 353.6 348.3 341.1 330.6 317.7 303.8 293.3 296.5 297.7 301.4 305.3 308.8 308.8 308.2 308.2 306.3 305.6 285.0 269.8 265.6 280.0 304.4 331.2 351.0 357.9 364.2 370.6 381.1 392.9 399.0 399.1 395.0 389.5 379.9 363.0 338.9 318.5 305.6 300.3 299.6 296.3 292.2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 309.6 322.1 329.8 331.2 332.1 332.6 332.4 335.4 340.7 345.0 347.2 346.2 342.6 339.6 337.4 338.3 340.9 342.6 344.0 344.6 344.0 344.2 343.6 341.9 338.8 336.7 337.6 341.1 347.0 350.4 343.0 326.6 330.8 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 310.0 315.7 317.7 317.4 316.5 314.4 314.1 322.5 336.3 350.0 354.2 352.9 350.5 348.4 347.1 347.3 348.4 349.8 349.8 350.4 350.3 323.9 324.7 0.0 0.0 0.0 0.0 0.0 0.0 0.0 297.3 289.8 279.5 275.1 276.1 276.1 274.9 275.4 274.6 271.8 268.6 264.0 258.3 251.7 244.3 239.9 236.1 233.7 234.0 236.0 237.3 236.9 235.2 233.5 231.7 231.0 232.1 233.6 235.4 236.2 236.7 235.8 234.1 232.2 231.3 232.6 233.5 235.2 236.0 232.3 228.8 229.6 233.8 241.3 239.4 226.3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0"
+ }
+]
\ No newline at end of file
diff --git a/example/audio/en_target.mp3 b/example/audio/en_target.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..b6eb34fdfcadbbbdb1032a3d5eb950ad14de9f3c
Binary files /dev/null and b/example/audio/en_target.mp3 differ
diff --git a/example/audio/music.json b/example/audio/music.json
new file mode 100644
index 0000000000000000000000000000000000000000..bbb6eb7ada45c62ad235cac1f4483744f31f20dc
--- /dev/null
+++ b/example/audio/music.json
@@ -0,0 +1,16 @@
+[
+ {
+ "index": "vocal_240_51240",
+ "language": "Mandarin",
+ "time": [
+ 240,
+ 51240
+ ],
+ "duration": "0.21 0.18 0.34 0.38 0.26 0.22 0.50 0.20 0.33 0.13 0.60 0.22 0.38 0.52 0.30 0.12 1.82 0.98 0.20 0.26 0.38 0.36 0.58 0.54 0.26 0.50 1.46 3.03 0.24 0.24 0.29 0.25 0.20 0.14 0.80 0.22 0.54 0.30 0.24 0.16 0.58 0.28 0.38 1.73 0.79 0.24 0.30 0.34 0.30 0.30 0.34 0.21 0.21 0.36 0.34 0.23 1.85 1.90 0.23 0.39 0.68 0.50 0.31 0.43 0.76 0.38 2.00 1.87 0.68 0.72 0.56 0.62 0.80 0.40 0.42 1.68 1.79 0.70 0.66 0.54 0.24 0.48 0.68 0.40 2.34 0.14",
+ "text": " 只 是 因 为 为 在 人 群 群 中 多 看 了 你 一 眼 再 也 没 能 忘 掉 你 容 颜 梦 想 着 着 偶 偶 然 然 有 一 一 天 再 相 见 从 此 我 开 始 始 孤 孤 单 思 念 念 想 想 你 时 你 你 在 天 边 想 你 时 你 在 眼 前 前 想 你 时 你 你 在 脑 海 ",
+ "phoneme": " zh_zhi3 zh_shi4 zh_yin1 zh_wei4 zh_wei2 zh_zai4 zh_ren2 zh_qun2 zh_qun2 zh_zhong1 zh_duo1 zh_kan4 zh_le5 zh_ni3 zh_yi1 zh_yan3 zh_zai4 zh_ye3 zh_mei2 zh_neng2 zh_wang4 zh_diao4 zh_ni3 zh_rong2 zh_yan2 zh_meng4 zh_xiang3 zh_zhe5 zh_zhe5 zh_ou3 zh_ou3 zh_ran2 zh_ran2 zh_you3 zh_yi1 zh_yi1 zh_tian1 zh_zai4 zh_xiang1 zh_jian4 zh_cong2 zh_ci3 zh_wo3 zh_kai1 zh_shi3 zh_shi3 zh_gu1 zh_gu1 zh_dan1 zh_si1 zh_nian4 zh_nian4 zh_xiang3 zh_xiang3 zh_ni3 zh_shi2 zh_ni3 zh_ni3 zh_zai4 zh_tian1 zh_bian1 zh_xiang3 zh_ni3 zh_shi2 zh_ni3 zh_zai4 zh_yan3 zh_qian2 zh_qian2 zh_xiang3 zh_ni3 zh_shi2 zh_ni3 zh_ni3 zh_zai4 zh_nao3 zh_hai3 ",
+ "note_pitch": "0 64 64 64 66 68 66 66 66 64 64 64 64 66 64 60 61 0 63 63 63 64 66 63 61 59 56 0 68 66 66 68 68 66 66 64 64 0 64 66 61 61 61 64 0 62 63 63 64 64 66 65 66 61 58 59 56 0 69 71 66 68 68 71 66 64 61 0 66 61 68 66 64 63 61 59 0 71 66 68 68 71 66 64 61 0",
+ "note_type": "1 2 2 2 2 3 2 2 2 3 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 1 2 2 2 3 2 3 2 3 2 1 2 3 2 2 2 2 1 2 2 2 2 2 3 2 3 2 2 2 3 1 2 3 2 2 2 3 2 2 2 1 2 2 2 2 2 2 2 3 1 2 2 2 2 3 2 2 2 1",
+ "f0": "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 318.9 332.2 331.2 323.9 312.3 0.0 0.0 0.0 0.0 0.0 0.0 340.9 342.3 340.2 337.4 333.5 329.3 327.8 328.3 329.1 330.5 331.9 331.5 329.2 327.6 327.4 327.7 329.6 330.5 329.5 328.1 328.5 330.1 330.4 329.5 328.9 331.0 333.6 332.8 331.7 330.2 330.6 330.9 329.5 327.6 332.4 343.5 363.9 368.7 369.3 368.2 366.7 371.3 382.0 392.9 402.3 409.8 413.8 416.3 416.4 415.4 414.1 413.4 415.3 417.4 418.3 413.6 416.0 0.0 0.0 0.0 0.0 368.2 362.5 361.3 362.0 362.9 364.3 366.4 367.5 368.4 368.9 368.1 367.6 369.7 371.1 371.1 369.5 368.0 367.3 366.6 367.1 369.5 372.0 371.3 368.2 368.1 369.6 369.9 373.4 376.5 360.2 285.6 0.0 0.0 0.0 0.0 0.0 0.0 0.0 384.0 376.7 372.0 371.7 373.9 376.3 375.5 361.2 333.0 317.8 319.0 329.1 333.5 333.2 333.2 333.3 328.0 314.5 0.0 0.0 0.0 320.6 326.5 332.9 335.7 334.1 329.9 327.4 326.3 326.6 329.1 330.5 330.0 328.0 326.7 328.6 331.1 332.4 332.1 332.8 333.2 333.1 331.8 328.6 318.9 293.2 323.6 327.5 330.0 332.8 331.1 319.4 272.2 0.0 0.0 0.0 0.0 278.5 294.5 310.0 318.5 322.7 327.4 334.1 337.2 336.3 328.5 321.0 324.2 341.5 362.2 374.7 375.8 370.5 366.0 365.2 368.7 371.7 374.4 373.0 368.7 370.4 374.8 375.1 372.5 368.7 363.0 357.1 359.0 366.8 377.1 379.5 371.5 359.9 351.0 358.4 371.7 377.9 375.8 367.8 359.7 363.6 375.3 379.9 377.5 372.9 359.8 345.2 337.1 334.5 333.4 332.8 329.6 319.2 292.8 261.7 253.4 262.8 273.0 278.1 279.3 278.6 277.6 277.9 278.3 278.0 277.3 276.7 275.6 275.3 276.1 277.8 278.2 277.7 277.6 277.2 276.4 275.5 273.9 271.8 270.7 274.7 282.7 285.5 281.4 273.1 264.4 262.2 268.9 278.4 284.3 283.6 277.3 268.5 262.2 263.2 270.7 278.8 285.0 287.3 282.0 272.1 265.2 266.0 272.1 278.2 283.8 285.1 281.4 272.6 267.2 269.1 276.1 282.7 289.3 290.5 285.6 273.9 267.5 267.4 271.8 277.7 282.6 283.1 277.8 269.8 265.6 268.6 273.8 281.1 286.2 287.2 282.3 270.4 264.5 263.3 263.0 268.5 268.3 268.7 270.9 277.2 277.4 284.3 285.7 282.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 311.5 306.4 306.0 306.8 305.8 305.0 307.2 310.8 312.4 313.5 314.4 314.6 313.7 312.2 311.1 311.1 312.8 314.6 312.7 306.5 299.7 302.7 308.7 315.2 314.2 310.9 309.4 309.8 310.6 310.7 310.8 311.4 312.2 313.0 312.5 312.2 312.2 311.9 312.0 312.2 311.4 306.1 303.3 309.3 315.0 325.3 324.4 323.9 325.0 326.4 327.7 327.7 328.1 328.6 327.8 328.3 328.3 327.6 328.6 336.1 347.2 359.4 371.0 374.9 371.3 366.5 369.5 375.5 376.2 370.2 365.7 367.3 372.9 376.9 375.3 368.4 360.1 358.0 365.4 380.8 383.5 380.4 374.4 367.6 365.9 372.0 376.3 377.9 375.7 372.6 368.6 361.2 353.7 0.0 309.4 302.3 299.2 300.6 305.7 308.5 308.6 309.0 310.9 310.9 309.5 308.6 309.0 311.3 313.0 314.4 314.5 313.0 312.1 311.2 308.0 299.5 295.8 295.0 285.7 278.0 280.1 281.2 279.9 278.3 278.3 279.5 279.7 272.2 259.1 0.0 0.0 0.0 240.7 236.0 233.8 233.5 234.6 235.7 235.6 236.2 239.8 245.4 248.8 250.1 250.4 249.9 249.1 247.2 241.7 231.7 216.0 197.6 192.1 197.2 204.3 206.5 205.4 201.7 200.8 203.5 208.2 210.9 210.3 207.1 203.5 202.3 202.1 202.8 205.1 207.3 208.1 207.3 205.4 202.1 199.3 200.5 206.1 212.9 212.9 206.5 197.1 191.1 191.2 196.8 203.6 207.8 208.7 206.4 201.7 197.3 197.4 200.7 206.5 209.9 209.9 206.9 203.3 201.2 203.4 207.7 210.5 208.9 207.6 206.5 204.2 202.0 203.6 205.7 210.8 213.1 214.3 210.6 204.1 199.7 202.3 211.9 217.7 215.1 215.0 215.3 213.3 0.0 0.0 216.3 0.0 0.0 195.2 209.0 205.3 201.3 196.8 195.7 195.8 195.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 431.9 425.8 421.0 417.7 419.3 422.7 423.4 421.1 417.9 411.9 408.6 411.5 414.9 419.6 426.7 424.2 0.0 0.0 0.0 280.8 279.4 369.7 371.8 370.6 369.8 370.5 370.5 371.4 374.6 374.0 361.0 332.7 328.0 0.0 359.0 365.0 371.1 373.8 371.9 369.7 369.6 371.0 378.3 390.4 403.4 414.1 417.9 417.5 417.0 416.6 415.7 414.9 414.5 413.8 413.5 413.6 413.8 415.2 416.9 417.3 415.4 415.0 414.1 411.4 409.2 403.5 392.7 375.6 367.5 365.8 368.8 370.7 370.8 369.0 367.0 366.0 366.2 367.8 369.2 368.8 367.2 366.7 367.3 367.3 368.8 368.5 366.8 365.3 364.0 363.0 365.0 367.3 367.7 365.5 364.0 364.9 367.4 370.3 371.3 369.9 366.5 364.1 368.3 385.8 406.7 409.4 399.7 376.0 357.2 359.9 372.4 377.5 366.9 345.1 320.6 315.8 321.8 329.0 329.9 326.4 324.5 324.8 324.9 325.3 325.1 324.7 325.9 326.8 326.6 326.6 326.4 325.4 323.2 322.5 323.9 326.2 328.6 329.9 329.5 329.0 328.2 327.7 327.9 329.6 331.2 331.4 334.2 340.1 335.5 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 309.0 320.1 329.6 334.4 334.4 334.5 338.2 346.0 355.1 361.4 366.1 369.5 372.7 373.3 372.9 375.2 375.8 374.7 367.6 360.1 0.0 0.0 0.0 0.0 0.0 0.0 283.6 280.4 277.2 274.3 271.8 271.0 273.3 277.2 278.8 278.1 277.2 277.4 278.8 279.3 280.2 281.4 279.2 274.9 268.3 254.1 239.9 0.0 0.0 267.7 267.7 270.1 271.8 273.1 276.3 278.3 269.1 256.1 0.0 0.0 0.0 0.0 286.0 281.1 275.7 272.7 271.7 271.7 272.5 274.0 275.5 277.9 281.1 287.1 314.0 342.2 363.0 376.8 375.4 365.8 351.8 323.3 0.0 0.0 0.0 0.0 299.6 322.7 336.5 336.9 335.4 332.9 330.0 326.6 323.4 322.3 324.4 326.7 328.5 328.5 326.4 324.0 322.5 323.2 325.5 327.7 328.6 328.1 325.4 320.9 316.5 316.1 320.8 327.7 333.7 333.2 325.6 315.5 306.9 305.5 315.1 331.0 343.3 341.6 331.4 319.5 308.3 307.6 318.1 329.5 338.2 342.0 337.4 329.4 321.0 316.6 319.8 330.8 340.6 344.5 342.3 331.4 319.3 315.7 322.9 329.3 336.5 344.0 341.1 328.3 318.7 320.0 328.2 333.5 337.9 339.7 338.5 327.3 323.3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 285.3 284.6 284.3 288.2 295.1 305.6 312.2 315.5 287.8 0.0 0.0 0.0 0.0 0.0 309.8 312.4 313.8 312.8 310.8 310.7 312.5 313.2 310.7 305.7 305.9 308.1 309.9 311.3 311.0 309.8 309.7 310.6 311.1 311.1 311.0 311.9 313.7 316.4 319.5 311.6 281.3 283.4 306.6 0.0 0.0 0.0 334.4 326.7 322.6 321.6 321.7 324.5 332.0 336.7 333.1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 332.5 331.7 332.7 334.8 334.7 335.9 339.7 345.6 353.9 362.3 370.2 374.0 373.8 371.4 370.0 368.2 367.0 368.0 370.4 371.2 371.3 370.0 360.3 0.0 0.0 0.0 319.1 326.8 331.5 332.0 331.6 327.2 328.2 334.1 343.1 350.7 358.0 363.5 368.4 370.8 371.8 371.1 370.2 368.5 368.0 371.1 375.5 376.3 369.5 0.0 0.0 0.0 280.4 267.2 263.6 265.8 271.2 275.0 276.4 276.9 277.1 280.9 287.5 284.7 280.7 0.0 0.0 0.0 0.0 0.0 0.0 231.6 226.0 222.9 224.1 226.2 228.5 234.3 240.9 246.6 249.8 249.6 246.3 245.2 245.8 247.2 248.3 249.3 249.0 245.9 242.7 235.8 226.3 217.8 213.1 209.3 208.0 207.0 205.9 205.8 203.4 201.5 205.3 207.5 209.1 210.4 208.9 203.2 199.2 199.3 201.2 205.6 206.3 204.0 202.7 202.2 203.3 206.1 205.6 201.9 198.8 195.5 195.6 198.7 204.2 214.5 218.3 214.3 207.0 199.1 192.4 189.9 193.6 201.0 210.6 212.4 209.2 202.0 195.2 190.0 190.4 196.0 204.2 209.2 208.9 203.0 195.8 190.7 189.8 194.6 200.8 208.7 213.5 210.5 202.0 193.9 187.7 187.8 193.1 199.8 202.8 204.1 203.8 200.4 197.0 193.9 191.6 192.5 198.4 204.8 203.9 203.5 201.3 200.2 198.4 198.5 201.0 204.0 204.6 205.4 202.0 199.2 194.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 426.6 432.4 436.6 437.2 434.6 433.0 434.0 453.1 473.8 492.8 503.2 501.8 491.3 479.6 478.1 489.7 503.6 508.3 503.2 489.5 482.3 484.2 495.7 505.4 506.6 501.6 495.4 492.7 497.1 500.4 496.1 490.1 480.4 453.8 412.6 373.6 357.7 354.5 356.8 361.0 365.0 367.8 369.5 370.1 371.1 371.3 370.8 370.1 369.7 369.0 369.5 370.4 371.1 370.9 370.2 370.8 371.3 372.7 380.2 386.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 435.0 427.8 417.0 411.6 412.6 411.2 410.7 412.5 413.7 414.9 414.8 415.7 415.4 414.0 413.8 414.6 415.2 416.3 415.2 413.6 413.4 414.9 416.8 416.4 414.7 413.8 413.2 414.7 417.5 420.7 428.2 443.1 459.8 481.7 502.3 510.8 509.1 497.8 490.6 493.5 500.8 508.3 511.8 504.6 493.5 487.3 492.3 507.2 518.8 514.8 496.4 489.6 490.4 491.8 493.3 0.0 0.0 0.0 0.0 0.0 391.9 368.8 360.2 358.3 360.0 362.0 361.8 360.9 360.3 361.8 365.7 368.7 369.3 367.7 364.7 365.6 368.4 370.4 371.7 369.0 366.8 367.5 367.9 369.4 374.6 384.6 396.8 388.2 382.8 0.0 0.0 0.0 0.0 0.0 352.6 339.5 328.2 326.0 326.5 326.9 328.2 329.4 329.4 329.3 337.3 349.1 350.9 343.6 333.4 326.5 327.1 331.5 333.4 326.8 318.4 0.0 0.0 0.0 282.0 283.6 281.6 279.5 273.2 267.9 268.8 273.2 277.5 279.8 278.7 276.8 273.3 268.7 265.0 266.1 272.4 282.3 284.9 279.4 268.5 255.3 253.0 259.9 270.2 281.5 284.3 279.7 270.1 260.2 256.4 260.1 266.6 273.0 279.6 283.6 282.3 272.3 262.1 255.8 259.8 269.6 274.9 278.7 278.4 271.5 261.7 255.8 256.8 263.9 272.1 279.7 278.4 269.9 259.0 255.9 262.1 269.2 274.5 279.6 282.1 280.7 274.3 269.9 270.9 270.7 273.8 276.9 274.9 271.1 266.5 266.7 269.4 276.9 282.9 282.0 279.5 274.9 271.2 267.9 263.1 270.2 281.3 285.5 283.9 280.6 271.1 262.8 263.9 263.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 362.9 364.0 366.8 366.6 363.6 359.8 357.7 359.9 367.6 371.6 365.4 359.3 358.6 359.5 365.5 375.7 378.1 371.5 359.5 350.8 358.6 374.7 381.7 376.8 363.6 353.3 356.9 367.8 378.3 380.8 375.7 372.3 372.6 374.7 372.9 365.2 346.4 314.8 285.6 277.3 275.4 275.3 277.5 279.2 279.2 278.8 277.6 276.4 276.5 277.6 278.8 279.9 281.8 282.7 281.0 279.2 278.6 279.4 279.9 279.2 278.7 281.7 285.8 289.9 288.7 284.4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 428.0 417.9 410.0 409.5 411.2 414.0 417.5 417.9 418.1 419.3 419.0 417.8 417.1 417.2 416.6 415.8 415.0 415.1 414.5 410.9 401.5 388.2 371.5 359.2 363.0 371.7 375.7 370.6 359.6 353.4 358.7 367.0 374.5 378.2 373.7 362.8 358.0 361.6 368.5 373.7 377.7 375.8 368.8 361.7 360.5 366.7 372.5 375.6 380.1 384.9 383.5 376.6 0.0 0.0 0.0 0.0 0.0 331.7 320.0 316.8 320.7 325.1 327.0 327.4 326.6 326.4 325.5 325.1 326.2 327.6 328.5 328.4 327.7 328.7 329.9 329.8 329.7 327.7 326.5 327.8 328.9 329.3 329.8 329.2 328.8 329.4 330.4 331.0 331.0 330.5 329.2 328.2 328.3 328.7 329.4 327.5 322.9 314.8 299.3 290.8 292.3 299.0 305.4 309.2 312.1 313.9 316.8 320.4 327.6 333.1 338.0 341.2 342.1 338.4 331.0 0.0 0.0 0.0 0.0 0.0 0.0 311.7 292.5 281.8 276.1 275.1 276.6 277.6 277.4 274.0 264.8 249.6 239.1 238.8 243.0 245.7 245.0 241.9 237.9 234.6 237.1 242.7 250.4 252.5 248.3 240.7 231.6 228.3 234.0 240.6 247.0 250.6 248.5 242.1 233.6 226.8 227.2 234.0 240.7 246.9 248.9 244.4 237.6 230.8 228.8 236.4 246.9 250.2 249.2 245.4 239.9 232.5 224.7 226.5 238.5 252.0 257.4 255.9 248.4 237.4 230.5 231.6 238.8 247.6 252.6 253.9 253.0 248.9 241.8 236.4 234.1 236.5 246.2 257.7 255.4 248.2 238.5 234.8 237.8 244.4 250.5 253.8 251.7 246.7 240.0 238.3 244.8 251.6 256.4 256.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 484.0 484.5 487.5 490.6 487.2 479.7 478.9 483.8 494.1 505.7 510.3 501.1 485.4 472.4 472.3 495.6 515.8 517.1 503.8 486.5 475.1 479.9 491.4 507.9 509.9 505.0 496.2 485.7 484.5 490.8 495.4 497.7 496.6 489.9 469.1 424.3 387.2 373.0 367.3 364.4 364.6 366.2 366.6 368.8 371.9 373.1 371.4 368.9 367.1 367.0 367.7 368.8 370.1 369.8 369.2 369.8 370.0 369.0 369.0 371.0 375.2 385.3 391.6 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 425.2 418.5 409.4 409.5 408.1 406.3 408.4 410.3 412.8 413.8 412.6 411.2 410.8 410.4 410.2 410.5 412.2 413.6 413.4 412.6 412.6 413.4 415.6 417.9 418.5 417.0 415.0 415.0 416.8 421.5 437.9 455.0 472.6 493.0 502.2 496.7 484.6 480.6 487.2 498.6 506.3 506.2 499.4 489.7 489.5 496.2 503.8 507.2 501.8 491.1 489.9 501.1 520.5 533.7 525.8 0.0 0.0 0.0 0.0 0.0 421.8 365.2 351.0 352.0 358.0 362.6 365.6 365.4 363.2 362.5 362.2 362.4 364.1 365.2 364.9 363.5 363.4 366.3 368.5 370.0 369.7 368.5 366.9 365.7 366.3 368.6 370.8 372.2 370.2 373.1 377.0 376.5 371.0 346.8 326.1 325.2 328.4 331.4 334.0 332.2 327.2 324.4 331.3 349.7 363.2 362.0 346.1 327.4 321.8 329.8 340.7 348.4 349.9 346.1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 269.5 275.7 282.6 282.9 278.8 275.2 274.3 273.5 273.9 275.3 276.4 277.3 276.1 272.3 268.0 268.3 272.5 279.1 281.7 280.1 271.3 261.0 257.3 262.7 273.2 280.8 283.3 279.3 269.1 258.3 258.0 266.3 278.2 286.7 287.8 283.4 274.9 264.7 257.7 260.0 272.0 286.3 294.7 289.8 274.5 263.8 263.4 268.8 277.3 284.1 286.5 285.2 281.7 272.7 264.3 260.3 267.7 281.7 289.2 289.0 281.2 266.2 256.0 254.4 261.5 276.4 286.8 288.8 286.7 273.8 260.9 260.8 270.2 283.4 291.3 292.8 286.1 273.8 265.2 264.0 271.3 283.8 290.3 289.7 277.7 266.9 260.5 263.0 267.8 281.5 286.7 286.1 285.5 280.1 279.0 283.0 284.1 284.5 285.5 282.2 0.0 0.0 280.0 276.4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0"
+ }
+]
\ No newline at end of file
diff --git a/example/audio/music.mp3 b/example/audio/music.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..5004590440b19314d3dee6fb94bbee44df283aa6
--- /dev/null
+++ b/example/audio/music.mp3
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:04b35a7b9d03adc494c304af5c4413aa33a02a54a7110016d6e3b559843d90de
+size 1243961
diff --git a/example/audio/yue_target.json b/example/audio/yue_target.json
new file mode 100644
index 0000000000000000000000000000000000000000..2637c5b1908375013e94217f56c7a4a378b77f12
--- /dev/null
+++ b/example/audio/yue_target.json
@@ -0,0 +1,16 @@
+[
+ {
+ "index": "vocal_420_14370",
+ "language": "Cantonese",
+ "time": [
+ 420,
+ 14370
+ ],
+ "duration": "0.31 0.26 0.28 0.26 0.40 0.20 0.42 0.24 0.36 0.24 0.32 0.26 0.94 0.32 0.24 0.30 0.34 0.22 0.34 0.90 0.22 0.36 0.32 0.30 0.22 0.36 0.22 0.32 0.34 0.20 0.40 0.24 0.30 0.38 0.22 0.32 0.28 0.36 0.24 0.34 0.26 0.60",
+ "text": " 我 的 心 情 又 像 真 该 等 被 揭 开 嘴 巴 却 再 仰 千 台 人 潮 内 越 文 静 越 变 得 不 受 理 睬 睬 自 己 己 要 交 出 意 外",
+ "phoneme": " yue_ngo5 yue_dik1 yue_sam1 yue_cing4 yue_jau6 yue_zoeng6 yue_zan1 yue_goi1 yue_dang2 yue_bei6 yue_kit3 yue_hoi1 yue_zeoi2 yue_baa1 yue_koek3 yue_zoi3 yue_joeng5 yue_cin1 yue_toi4 yue_jan4 yue_ciu4 yue_noi6 yue_jyut6 yue_man4 yue_zing6 yue_jyut6 yue_bin3 yue_dak1 yue_bat1 yue_sau6 yue_lei5 yue_coi2 yue_coi2 yue_zi6 yue_gei2 yue_gei2 yue_jiu3 yue_gaau1 yue_ceot1 yue_ji3 yue_ngoi6",
+ "note_pitch": "0 52 57 59 55 57 59 62 60 58 54 57 59 59 57 55 54 53 57 51 50 54 57 58 54 57 59 61 64 59 54 54 57 59 51 56 58 57 56 56 55 52",
+ "note_type": "1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 2 2 3 2 2 2 2 2",
+ "f0": "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 74.8 74.8 82.4 85.7 81.0 76.2 78.8 111.9 129.0 146.8 160.6 175.0 182.1 172.9 163.9 190.3 214.7 218.4 221.5 223.5 220.2 209.1 173.8 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 254.5 248.7 245.1 243.3 239.8 238.8 241.2 244.0 245.7 245.6 239.8 216.8 191.2 0.0 0.0 170.9 179.4 189.0 192.6 193.3 192.8 193.5 194.1 194.1 194.4 195.4 197.4 199.7 202.2 205.8 210.6 212.3 214.4 216.5 219.4 221.9 222.4 222.4 222.8 222.9 217.1 189.6 175.4 0.0 0.0 255.5 251.2 246.5 247.3 248.9 249.1 249.4 251.9 253.6 250.2 247.1 246.6 239.1 193.8 191.0 0.0 295.9 301.1 302.8 301.5 296.6 287.7 286.4 290.4 294.2 297.1 297.3 294.7 287.9 273.4 221.0 262.3 265.3 259.7 255.9 254.7 255.1 256.2 257.4 259.3 260.8 261.3 249.3 209.1 194.0 236.7 224.9 210.2 202.6 197.9 201.4 210.6 220.6 230.1 237.9 242.4 242.7 241.0 234.7 220.3 190.9 179.0 185.6 182.3 178.3 177.1 179.0 181.3 182.4 184.5 187.6 185.9 172.3 161.7 167.7 0.0 0.0 206.9 210.4 213.0 214.2 215.6 217.3 217.1 194.7 181.9 184.5 182.3 171.8 155.9 161.4 167.7 195.8 235.3 245.0 245.8 241.9 237.0 232.3 231.3 234.9 241.8 250.9 253.2 248.9 238.0 226.0 220.5 224.9 236.6 250.9 259.6 262.0 259.1 252.5 246.5 241.8 236.0 228.8 216.3 210.5 211.1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 209.1 197.4 191.2 191.2 200.4 222.5 235.6 248.0 252.9 249.4 243.0 235.6 216.1 182.1 175.8 213.4 214.1 214.7 215.1 215.2 215.9 215.7 214.8 211.6 202.2 187.0 183.6 0.0 0.0 0.0 191.3 192.1 191.4 192.4 193.3 192.6 188.9 171.2 0.0 0.0 0.0 0.0 0.0 0.0 192.1 186.5 181.8 178.2 176.1 176.8 178.2 179.4 179.7 180.4 183.3 184.4 182.7 179.0 176.1 174.6 168.7 166.5 167.2 168.8 171.8 175.6 185.4 194.7 197.2 192.6 181.9 0.0 0.0 0.0 192.6 204.8 217.5 220.0 220.4 218.7 216.0 214.2 216.0 216.8 213.0 200.5 183.9 0.0 0.0 158.5 156.8 159.6 161.4 160.7 159.4 159.5 159.3 157.1 152.4 152.1 156.4 162.9 167.3 166.4 160.3 149.6 146.1 149.5 155.4 161.3 164.1 163.3 161.3 154.8 148.5 144.6 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 99.4 97.6 102.0 115.3 131.4 142.5 150.1 155.1 159.7 162.5 161.1 152.3 0.0 0.0 0.0 0.0 174.2 183.7 189.8 191.6 190.4 189.1 189.5 190.4 191.9 192.1 192.5 196.2 203.3 212.2 214.1 216.7 217.4 217.2 216.8 216.4 214.9 214.7 216.3 218.9 220.4 221.5 224.6 231.2 237.9 242.2 242.6 240.3 239.0 240.4 241.7 242.3 238.4 184.9 178.6 194.1 203.6 194.7 185.4 180.5 181.9 185.4 188.7 191.6 193.9 194.6 192.2 191.0 189.5 186.7 181.0 0.0 0.0 219.2 217.8 216.8 215.3 213.3 212.0 213.6 215.2 215.0 214.7 215.8 218.2 221.0 224.4 229.9 237.5 244.9 246.6 246.6 248.0 250.2 249.8 193.6 186.3 193.0 0.0 0.0 0.0 288.3 289.9 287.8 287.2 287.7 285.0 283.4 281.9 280.1 283.4 287.8 290.9 291.6 287.6 240.0 238.5 0.0 333.3 328.9 325.3 323.1 321.6 317.4 298.0 274.9 0.0 0.0 0.0 0.0 0.0 0.0 243.7 248.7 245.5 243.3 243.8 246.2 244.5 226.2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 197.1 189.0 184.3 182.5 184.8 188.1 188.8 190.3 190.4 189.5 188.7 188.0 189.8 188.2 185.0 182.2 179.7 178.1 178.7 188.8 208.5 217.2 220.0 218.6 212.8 191.9 167.3 0.0 0.0 0.0 183.9 201.1 210.5 210.3 209.5 206.7 206.5 212.0 221.1 235.9 250.8 253.0 247.7 238.4 229.7 229.1 235.1 244.7 251.7 253.5 251.4 246.7 242.4 234.6 209.5 180.1 173.1 0.0 0.0 0.0 147.9 147.5 155.4 159.1 160.2 161.4 162.0 161.7 159.0 152.4 139.6 121.5 126.4 0.0 0.0 197.2 200.2 196.8 195.8 200.3 203.2 203.8 202.5 203.7 212.8 220.0 225.8 231.6 234.1 231.7 228.3 225.4 226.0 229.7 234.3 237.3 238.2 236.9 232.9 227.1 220.3 215.1 211.3 205.9 210.1 216.0 217.6 218.4 218.9 219.3 218.5 217.1 216.9 217.7 216.4 212.1 196.3 171.7 171.3 210.9 203.0 194.9 194.8 199.1 205.4 210.5 214.9 219.6 224.8 225.1 221.4 212.4 198.5 0.0 0.0 204.9 204.1 208.9 212.7 212.8 213.2 214.9 214.4 208.4 189.0 159.7 160.4 193.1 198.9 196.0 192.8 193.0 195.3 195.2 194.5 193.6 193.9 192.5 192.3 192.2 183.0 164.8 150.0 147.3 150.7 155.9 160.8 163.6 164.8 161.8 156.6 155.0 160.5 166.2 167.3 165.1 162.0 154.3 142.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0"
+ }
+]
\ No newline at end of file
diff --git a/example/audio/yue_target.mp3 b/example/audio/yue_target.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..f224e2e330a81b2a006299e0f6f9f617bbeb69be
--- /dev/null
+++ b/example/audio/yue_target.mp3
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a699c2649eec48ed1e9a6caae2af918bf7d49e5e4ad39cf3cca0916942bc7db2
+size 353361
diff --git a/example/audio/zh_prompt.json b/example/audio/zh_prompt.json
new file mode 100644
index 0000000000000000000000000000000000000000..288d4baca913ff85a83b77fd692f7cc55be9770a
--- /dev/null
+++ b/example/audio/zh_prompt.json
@@ -0,0 +1,16 @@
+[
+ {
+ "index": "vocal_320_10687",
+ "language": "Mandarin",
+ "time": [
+ 320,
+ 10687
+ ],
+ "duration": "0.23 0.34 0.26 0.70 0.52 0.46 0.36 0.44 0.14 0.24 0.64 0.47 0.51 1.10 0.28 0.38 0.32 0.32 0.38 0.32 0.31 0.19 1.45",
+ "text": " 除 了 想 你 你 除 了 了 爱 你 你 我 什 么 什 么 都 愿 愿 意",
+ "phoneme": " zh_chu2 zh_le5 zh_xiang3 zh_ni3 zh_ni3 zh_chu2 zh_le5 zh_le5 zh_ai4 zh_ni3 zh_ni3 zh_wo3 zh_shen2 zh_me5 zh_shen2 zh_me5 zh_dou1 zh_yuan4 zh_yuan4 zh_yi4",
+ "note_pitch": "0 62 65 67 67 69 0 67 69 67 65 67 69 67 67 66 64 64 60 60 65 67 0",
+ "note_type": "1 2 2 2 2 3 1 2 2 3 2 2 3 1 2 2 2 2 2 2 2 3 2",
+ "f0": "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 294.1 288.2 290.6 294.6 295.7 292.8 291.5 294.4 295.4 294.7 293.5 292.2 294.4 295.8 293.2 297.7 320.1 338.1 348.3 348.7 344.4 342.6 346.2 354.8 356.9 353.1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 401.1 403.8 405.0 402.1 398.4 395.6 393.3 392.0 391.4 390.3 390.2 390.3 390.3 391.3 392.8 393.6 391.7 390.5 391.4 391.6 391.5 393.5 393.8 390.8 387.4 387.8 389.3 390.8 392.2 391.6 390.2 389.8 389.1 388.4 390.0 395.5 397.2 396.7 395.5 395.1 394.6 394.9 395.6 395.2 394.6 395.4 395.9 394.0 391.7 390.7 391.7 392.6 391.6 395.7 405.8 441.7 462.5 463.8 450.2 430.9 414.5 415.2 426.7 439.8 454.4 462.9 447.8 422.5 400.6 403.2 423.5 451.3 482.3 492.8 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 435.9 414.3 406.2 402.8 398.2 396.7 396.7 395.2 398.3 416.9 441.6 446.9 441.5 437.3 435.3 432.9 431.6 432.3 434.8 436.7 433.8 422.8 406.7 390.9 382.0 381.7 384.7 382.7 368.1 357.4 355.6 355.1 352.4 348.1 346.4 348.7 351.6 355.0 354.4 351.7 349.9 349.2 348.1 346.0 345.4 344.2 344.4 345.5 346.6 349.0 349.7 349.1 349.5 349.6 349.7 349.4 349.5 352.2 354.5 355.7 355.6 356.9 359.4 361.5 363.8 360.4 354.3 357.2 363.9 372.4 382.6 399.1 402.9 400.6 395.5 390.5 388.9 390.2 391.1 391.9 391.5 390.4 390.2 391.3 391.6 391.0 388.6 386.2 389.1 403.8 430.2 441.8 449.5 448.1 443.2 438.3 432.9 430.7 434.6 442.0 447.6 446.0 440.3 434.7 431.1 435.7 442.3 445.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 423.8 402.1 398.5 397.6 393.4 390.5 391.4 402.5 427.6 442.5 442.5 435.1 430.1 430.8 439.9 447.4 442.2 426.1 412.3 399.8 391.8 389.5 388.5 387.8 386.4 384.7 384.5 387.9 391.2 391.8 392.9 393.8 392.0 392.0 395.4 398.1 398.2 396.3 393.1 391.1 388.9 386.6 383.1 381.5 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 357.8 367.6 370.9 365.4 359.3 355.7 358.1 372.7 396.8 404.0 398.4 392.6 389.2 388.8 383.6 362.3 341.1 325.1 326.3 327.9 331.3 333.3 326.0 319.4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 356.4 353.5 349.0 346.5 338.4 328.4 323.7 326.2 334.2 338.4 331.6 309.8 280.6 256.4 252.8 256.1 258.3 256.9 257.6 261.1 259.7 258.7 258.6 260.2 262.2 262.4 262.9 264.2 263.4 260.2 256.4 253.2 238.7 223.3 231.6 257.7 258.1 258.4 258.8 258.0 256.8 255.8 254.0 255.1 258.1 261.9 263.7 262.6 256.7 253.1 250.2 246.7 258.7 294.4 327.1 342.6 346.0 344.1 341.2 342.3 345.2 350.1 364.7 382.5 396.1 396.1 389.2 381.8 381.1 387.2 397.0 399.3 390.7 374.3 360.1 350.6 346.6 347.4 350.7 354.4 354.3 351.7 349.8 348.4 346.9 347.0 348.4 349.5 351.0 352.3 353.6 353.3 350.5 348.3 345.5 344.3 344.4 347.0 350.6 352.0 351.0 350.8 350.1 347.6 345.7 347.0 350.3 351.7 350.7 348.5 346.9 347.7 349.0 349.1 348.5 346.8 346.3 348.4 349.0 349.2 351.1 349.6 348.3 350.5 351.1 348.0 347.6 349.1 351.3 356.0 361.3 360.6 354.0 341.0 316.2 302.9 302.1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0"
+ }
+]
\ No newline at end of file
diff --git a/example/audio/zh_prompt.mp3 b/example/audio/zh_prompt.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..bc306410a44cdd6f80b0519847424a7a8e0ac954
Binary files /dev/null and b/example/audio/zh_prompt.mp3 differ
diff --git a/example/audio/zh_target.json b/example/audio/zh_target.json
new file mode 100644
index 0000000000000000000000000000000000000000..49e9d1c722e22d3aafd37166bc74b7fb8d76ef0c
--- /dev/null
+++ b/example/audio/zh_target.json
@@ -0,0 +1,16 @@
+[
+ {
+ "index": "vocal_0_6710",
+ "language": "Mandarin",
+ "time": [
+ 0,
+ 6710
+ ],
+ "duration": "0.13 0.26 0.24 0.22 0.24 0.33 0.13 0.24 0.22 0.46 0.69 0.84 0.26 0.30 0.16 0.26 0.26 0.20 0.32 0.94",
+ "text": " 像 我 这 样 懦 懦 弱 的 人 人 凡 事 都 要 留 留 几 分",
+ "phoneme": " zh_xiang4 zh_wo3 zh_zhe4 zh_yang4 zh_nuo4 zh_nuo4 zh_ruo4 zh_de5 zh_ren2 zh_ren2 zh_fan2 zh_shi4 zh_dou1 zh_yao4 zh_liu2 zh_liu2 zh_ji3 zh_fen1",
+ "note_pitch": "0 50 53 55 53 56 54 53 50 51 53 0 51 53 55 53 54 56 51 53",
+ "note_type": "1 2 2 2 2 2 3 2 2 2 3 1 2 2 2 2 2 3 2 2",
+ "f0": "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 132.7 137.2 144.9 147.0 148.0 148.6 148.9 148.5 147.9 147.4 148.1 149.1 154.3 166.4 173.1 175.0 176.3 175.9 172.7 173.6 175.9 172.9 159.1 165.8 0.0 214.2 213.4 210.2 201.5 198.1 197.1 197.2 197.8 200.8 206.4 206.1 200.5 189.9 180.6 172.0 170.8 171.4 176.2 180.9 182.3 182.2 181.0 180.5 183.4 192.6 211.8 220.6 223.7 219.3 211.9 207.0 203.6 202.6 204.2 204.4 204.1 202.1 198.0 192.9 185.9 177.9 174.1 174.6 174.5 173.8 173.6 172.4 168.3 168.3 172.7 173.2 171.9 170.7 170.2 169.9 170.6 173.1 172.4 164.2 148.2 147.8 152.3 148.5 143.8 145.6 149.2 149.9 150.1 152.5 153.6 154.7 156.1 155.0 152.4 152.1 153.7 155.3 156.4 156.8 157.3 157.7 157.1 156.8 157.8 158.9 157.9 157.5 157.1 157.0 159.1 162.0 167.7 172.1 174.9 176.2 174.5 172.0 170.9 171.3 172.5 173.1 173.5 173.1 174.1 174.6 175.2 176.7 177.2 177.3 176.9 175.9 174.1 172.4 174.1 174.8 171.8 172.1 176.5 177.3 176.0 179.4 179.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 137.5 138.5 145.8 151.0 153.3 154.3 156.2 158.5 161.0 158.1 148.6 151.3 157.0 163.5 172.7 176.4 176.1 175.8 175.6 174.7 171.4 169.3 161.7 194.2 199.3 199.7 201.0 202.7 201.4 199.4 198.5 196.6 194.0 190.6 186.1 181.9 179.4 177.4 177.1 176.3 175.8 175.5 174.8 173.6 172.4 170.8 165.0 160.6 175.9 179.3 179.8 180.4 180.3 178.8 178.6 181.5 185.0 190.7 198.3 206.3 210.6 210.9 207.4 203.5 203.4 204.6 203.5 195.6 182.5 0.0 0.0 0.0 0.0 144.7 144.1 146.6 150.2 151.9 153.4 155.0 156.0 155.9 155.4 155.2 153.5 147.0 144.8 0.0 0.0 0.0 0.0 0.0 0.0 181.1 178.9 178.0 177.5 176.0 173.2 172.9 172.9 174.1 176.2 177.3 178.7 178.8 176.0 175.2 175.1 176.3 178.1 177.6 177.4 177.9 177.6 177.3 177.3 177.5 176.6 175.7 176.6 177.5 177.2 175.9 174.8 173.5 174.0 175.7 177.4 177.8 174.7"
+ }
+]
\ No newline at end of file
diff --git a/example/audio/zh_target.mp3 b/example/audio/zh_target.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..ff3fe3082b7bba02e0b2297644f2c15b7f2beb6c
Binary files /dev/null and b/example/audio/zh_target.mp3 differ
diff --git a/example/infer.sh b/example/infer.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a39661e7cf9d6897ea504650a3617bf79c1d0716
--- /dev/null
+++ b/example/infer.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+
+script_dir=$(dirname "$(realpath "$0")")
+root_dir=$(dirname "$script_dir")
+
+cd $root_dir || exit
+export PYTHONPATH=$root_dir:$PYTHONPATH
+
+model_path=pretrained_models/SoulX-Singer/model.pt
+config=soulxsinger/config/soulxsinger.yaml
+prompt_wav_path=example/audio/zh_prompt.mp3
+prompt_metadata_path=example/audio/zh_prompt.json
+target_metadata_path=example/audio/music.json
+phoneset_path=soulxsinger/utils/phoneme/phone_set.json
+save_dir=example/generated/music
+control=score # melody or score
+
+python -m cli.inference \
+ --device cuda \
+ --model_path $model_path \
+ --config $config \
+ --prompt_wav_path $prompt_wav_path \
+ --prompt_metadata_path $prompt_metadata_path \
+ --target_metadata_path $target_metadata_path \
+ --phoneset_path $phoneset_path \
+ --save_dir $save_dir \
+ --auto_shift \
+ --pitch_shift 0
\ No newline at end of file
diff --git a/example/preprocess.sh b/example/preprocess.sh
new file mode 100644
index 0000000000000000000000000000000000000000..bd6b4cea5bb5f4e20cdf5d5725135881a68e4e49
--- /dev/null
+++ b/example/preprocess.sh
@@ -0,0 +1,41 @@
+#!/bin/bash
+
+script_dir=$(dirname "$(realpath "$0")")
+root_dir=$(dirname "$script_dir")
+
+cd $root_dir || exit
+export PYTHONPATH=$root_dir:$PYTHONPATH
+
+device=cuda
+
+
+####### Run Prompt Annotation #######
+audio_path=example/audio/zh_prompt.mp3
+save_dir=example/transcriptions/zh_prompt
+language=Mandarin
+vocal_sep=False
+max_merge_duration=30000
+
+python -m preprocess.pipeline \
+ --audio_path $audio_path \
+ --save_dir $save_dir \
+ --language $language \
+ --device $device \
+ --vocal_sep $vocal_sep \
+ --max_merge_duration $max_merge_duration
+
+
+####### Run Target Annotation #######
+audio_path=example/audio/music.mp3
+save_dir=example/transcriptions/music
+language=Mandarin
+vocal_sep=True
+max_merge_duration=60000
+
+python -m preprocess.pipeline \
+ --audio_path $audio_path \
+ --save_dir $save_dir \
+ --language $language \
+ --device $device \
+ --vocal_sep $vocal_sep \
+ --max_merge_duration $max_merge_duration
\ No newline at end of file
diff --git a/preprocess/README.md b/preprocess/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..17e9bc4457f1f8cf4b6f1791c0975424bb445a1a
--- /dev/null
+++ b/preprocess/README.md
@@ -0,0 +1,155 @@
+# 🎵 SoulX-Singer-Preprocess
+
+This part offers a comprehensive **singing transcription and editing toolkit** for real-world music audio. It provides the pipeline from vocal extraction to high-level annotation optimized for SVS dataset construction. By integrating state-of-the-art models, it transforms raw audio into structured singing data and supports the **customizable creation and editing of lyric-aligned MIDI scores**.
+
+
+## ✨ Features
+
+The toolkit includes the following core modules:
+
+- 🎤 **Clean Dry Vocal Extraction**
+ Extracts the lead vocal track from polyphonic music audio and dereverberation.
+
+- 📝 **Lyrics Transcription**
+ Automatically transcribes lyrics from clean vocal.
+
+- 🎶 **Note Transcription**
+ Converts singing voice into note-level representations for SVS.
+
+- 🎼 **MIDI Editor**
+ Supports customizable creation and editing of MIDI scores integrated with lyrics.
+
+
+## 🔧 Python Environment
+
+Before running the pipeline, set up the Python environment as follows:
+
+1. **Install Conda** (if not already installed): https://docs.conda.io/en/latest/miniconda.html
+
+2. **Activate or create a conda environment** (recommended Python 3.10):
+
+ - If you already have the `soulxsinger` environment:
+
+ ```bash
+ conda activate soulxsinger
+ ```
+
+ - Otherwise, create it first:
+
+ ```bash
+ conda create -n soulxsinger -y python=3.10
+ conda activate soulxsinger
+ ```
+
+3. **Install dependencies** from the `preprocess` directory:
+
+ ```bash
+ cd preprocess
+ pip install -r requirements.txt
+ ```
+
+## 📁 Data Preparation
+
+Before running the pipeline, prepare the following inputs:
+
+- **Prompt audio**
+ Reference audio that provides timbre and style
+
+- **Target audio**
+ Original vocal or music audio to be processed and transcribed.
+
+Configure the corresponding parameters in:
+
+```
+example/preprocess.sh
+```
+
+Typical configuration includes:
+- Input / output paths
+- Module enable switches
+
+## 🚀 Usage
+
+After configuring `preprocess.sh`, run the transcription pipeline with:
+
+```bash
+bash example/preprocess.sh
+```
+
+The script will automatically execute the following steps:
+
+1. **Vocal separation and dereverberation**
+2. **F0 extraction and voice activity detection (VAD)**
+3. **Lyrics transcription**
+4. **Note transcription**
+
+---
+
+After the pipeline completes, you will obtain **SoulX-Singer–style metadata** that can be directly used for Singing Voice Synthesis (SVS).
+
+**Output paths:**
+- The final metadata (**JSON file**) is written **in the same directory as your input audio**, with the **same filename** (e.g. `audio.mp3` → `audio.json`)
+- All **intermediate results** (separated vocal and accompaniment, F0, VAD outputs, etc.) are also saved under the configured **`save_dir`**.
+
+⚠️ **Important Note**
+
+Transcription errors—especially in **lyrics** and **note annotations**—can significantly affect the final SVS quality. We **strongly recommend manually reviewing and correcting** the generated metadata before inference.
+
+To support this, we provide a **MIDI Editor** for editing lyrics, phoneme alignment, note pitches, and durations. The workflow is:
+
+**Export metadata to MIDI** → edit in the MIDI Editor → **Import edited MIDI back to metadata** for SVS.
+
+---
+
+#### Step 1: Metadata → MIDI (for editing)
+
+Convert SoulX-Singer metadata to a MIDI file so you can open it in the MIDI Editor:
+
+```bash
+preprocess_root=example/transcriptions/music
+
+python -m preprocess.tools.midi_parser \
+ --meta2midi \
+ --meta "${preprocess_root}/metadata.json" \
+ --midi "${preprocess_root}/vocal.mid"
+```
+
+#### Step 2: Edit in the MIDI Editor
+
+Open the MIDI Editor (see [MIDI Editor Tutorial](tools/midi_editor/README.md)), load `vocal.mid`, and correct lyrics, pitches, or durations as needed. Save the result as e.g. `vocal_edited.mid`.
+
+#### Step 3: MIDI → Metadata (for SoulX-Singer inference)
+
+Convert the edited MIDI back into SoulX-Singer-style metadata (and cut wavs) for SVS:
+
+```bash
+python -m preprocess.tools.midi_parser \
+ --midi2meta \
+ --midi "${preprocess_root}/vocal_edited.mid" \
+ --meta "${preprocess_root}/edit_metadata.json" \
+ --vocal "${preprocess_root}/vocal.wav" \
+```
+
+Use `edit_metadata.json` (and the wavs under `edit_cut_wavs`) as the target metadata in your inference pipeline.
+
+
+## 🔗 References & Dependencies
+
+This project builds upon the following excellent open-source works:
+
+### 🎧 Vocal Separation & Dereverberation
+- [Music Source Separation Training](https://github.com/ZFTurbo/Music-Source-Separation-Training)
+- [Lead Vocal Separation](https://huggingface.co/becruily/mel-band-roformer-karaoke)
+- [Vocal Dereverberation](https://huggingface.co/anvuew/dereverb_mel_band_roformer)
+
+### 🎼 F0 Extraction
+- [RMVPE](https://github.com/Dream-High/RMVPE)
+
+### 📝 Lyrics Transcription (ASR)
+- [Paraformer](https://modelscope.cn/models/iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch)
+- [Parakeet-tdt-0.6b-v2](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2)
+
+### 🎶 Note Transcription
+- [ROSVOT](https://github.com/RickyL-2000/ROSVOT)
+
+We sincerely thank the authors of these repositories for their exceptional open-source contributions, which have been fundamental to the development of this toolkit.
diff --git a/preprocess/pipeline.py b/preprocess/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3a69c76a039f1c860882b2fe16e1327cfcb9a2b
--- /dev/null
+++ b/preprocess/pipeline.py
@@ -0,0 +1,146 @@
+import json
+import shutil
+import soundfile as sf
+from pathlib import Path
+import librosa
+
+from preprocess.utils import convert_metadata, merge_short_segments
+
+from preprocess.tools import (
+ F0Extractor,
+ VocalDetector,
+ VocalSeparator,
+ NoteTranscriber,
+ LyricTranscriber,
+)
+
+
+class PreprocessPipeline:
+ def __init__(self, device: str, language: str, save_dir: str, vocal_sep: bool = True, max_merge_duration: int = 60000):
+ self.device = device
+ self.language = language
+ self.save_dir = save_dir
+ self.vocal_sep = vocal_sep
+ self.max_merge_duration = max_merge_duration
+
+ if vocal_sep:
+ self.vocal_separator = VocalSeparator(
+ sep_model_path="pretrained_models/SoulX-Singer-Preprocess/mel-band-roformer-karaoke/mel_band_roformer_karaoke_becruily.ckpt",
+ sep_config_path="pretrained_models/SoulX-Singer-Preprocess/mel-band-roformer-karaoke/config_karaoke_becruily.yaml",
+ der_model_path="pretrained_models/SoulX-Singer-Preprocess/dereverb_mel_band_roformer/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt",
+ der_config_path="pretrained_models/SoulX-Singer-Preprocess/dereverb_mel_band_roformer/dereverb_mel_band_roformer_anvuew.yaml",
+ device=device
+ )
+ else:
+ self.vocal_separator = None
+ self.f0_extractor = F0Extractor(
+ model_path="pretrained_models/SoulX-Singer-Preprocess/rmvpe/rmvpe.pt",
+ device=device,
+ )
+ self.vocal_detector = VocalDetector(
+ cut_wavs_output_dir= f"{save_dir}/cut_wavs",
+ )
+ self.lyric_transcriber = LyricTranscriber(
+ zh_model_path="pretrained_models/SoulX-Singer-Preprocess/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+ en_model_path="pretrained_models/SoulX-Singer-Preprocess/parakeet-tdt-0.6b-v2/parakeet-tdt-0.6b-v2.nemo",
+ device=device
+ )
+ self.note_transcriber = NoteTranscriber(
+ rosvot_model_path="pretrained_models/SoulX-Singer-Preprocess/rosvot/rosvot/model.pt",
+ rwbd_model_path="pretrained_models/SoulX-Singer-Preprocess/rosvot/rwbd/model.pt",
+ device=device
+ )
+
+ def run(
+ self,
+ audio_path: str,
+ vocal_sep: bool = True,
+ max_merge_duration: int = 60000,
+ language: str = "Mandarin"
+ ) -> None:
+ vocal_sep = self.vocal_sep if vocal_sep is None else vocal_sep
+ max_merge_duration = self.max_merge_duration if max_merge_duration is None else max_merge_duration
+ language = self.language if language is None else language
+ output_dir = Path(self.save_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ if vocal_sep:
+ # Perform vocal/accompaniment separation
+ sep = self.vocal_separator.process(audio_path)
+ vocal = sep.vocals_dereverbed.T
+ acc = sep.accompaniment.T
+ sample_rate = sep.sample_rate
+
+ vocal_path = output_dir / "vocal.wav"
+ acc_path = output_dir / "acc.wav"
+ sf.write(vocal_path, vocal, sample_rate)
+ sf.write(acc_path, acc, sample_rate)
+ else:
+ # Use the original audio as vocal source (no separation)
+ vocal, sample_rate = librosa.load(audio_path, sr=None, mono=True)
+ vocal_path = output_dir / "vocal.wav"
+ sf.write(vocal_path, vocal, sample_rate)
+
+ vocal_f0 = self.f0_extractor.process(str(vocal_path))
+ segments = self.vocal_detector.process(str(vocal_path), f0=vocal_f0)
+
+ metadata = []
+ for seg in segments:
+ self.f0_extractor.process(seg["wav_fn"], f0_path=seg["wav_fn"].replace(".wav", "_f0.npy"))
+ words, durs = self.lyric_transcriber.process(
+ seg["wav_fn"], language
+ )
+ seg["words"] = words
+ seg["word_durs"] = durs
+ seg["language"] = language
+ metadata.append(
+ self.note_transcriber.process(seg, segment_info=seg)
+ )
+
+ merged = merge_short_segments(
+ vocal,
+ sample_rate,
+ metadata,
+ output_dir / "long_cut_wavs",
+ max_duration_ms=max_merge_duration,
+ )
+
+ final_metadata = []
+
+ for item in merged:
+ self.f0_extractor.process(item.wav_fn, f0_path=item.wav_fn.replace(".wav", "_f0.npy"))
+ final_metadata.append(convert_metadata(item))
+
+ with open(output_dir / "metadata.json", "w", encoding="utf-8") as f:
+ json.dump(final_metadata, f, ensure_ascii=False, indent=2)
+
+ shutil.copy(output_dir / "metadata.json", audio_path.replace(".wav", ".json").replace(".mp3", ".json").replace(".flac", ".json"))
+
+
+def main(args):
+ pipeline = PreprocessPipeline(
+ device=args.device,
+ language=args.language,
+ save_dir=args.save_dir,
+ vocal_sep=args.vocal_sep,
+ max_merge_duration=args.max_merge_duration,
+ )
+ pipeline.run(
+ audio_path=args.audio_path,
+ language=args.language
+ )
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--audio_path", type=str, required=True, help="Path to the input audio file")
+ parser.add_argument("--save_dir", type=str, required=True, help="Directory to save the output files")
+ parser.add_argument("--language", type=str, default="Mandarin", help="Language of the audio")
+ parser.add_argument("--device", type=str, default="cuda:0", help="Device to run the models on")
+ parser.add_argument("--vocal_sep", type=bool, default=True, help="Whether to perform vocal separation")
+ parser.add_argument("--max_merge_duration", type=int, default=60000, help="Maximum merged segment duration in milliseconds")
+ args = parser.parse_args()
+
+ main(args)
diff --git a/preprocess/requirements.txt b/preprocess/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..09f8081d927da98ab99c44ef1fbea06cd611fe54
--- /dev/null
+++ b/preprocess/requirements.txt
@@ -0,0 +1,33 @@
+beartype==0.22.9
+einops==0.8.2
+funasr==1.3.0
+g2p_en==2.1.0
+g2pM==0.1.2.5
+librosa==0.11.0
+loralib==0.1.2
+matplotlib==3.10.8
+mido==1.3.3
+ml_collections==1.1.0
+nemo_toolkit==2.6.1
+nltk==3.9.2
+numba==0.63.1
+numpy==2.2.6
+omegaconf==2.3.0
+packaging==24.2
+praat-parselmouth==0.4.7
+pretty_midi==0.2.11
+pyloudnorm==0.2.0
+pyworld==0.3.5
+rotary_embedding_torch==0.8.9
+sageattention==1.0.6
+scikit_learn==1.7.2
+scipy==1.15.3
+six==1.17.0
+scikit_image==0.25.2
+soundfile==0.13.1
+ToJyutping==3.2.0
+torch==2.10.0
+torchaudio==2.10.0
+tqdm==4.67.1
+wandb==0.24.2
+webrtcvad==2.0.10
diff --git a/preprocess/tools/__init__.py b/preprocess/tools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bacfc0747151660f7e92e40075128d8eb675296
--- /dev/null
+++ b/preprocess/tools/__init__.py
@@ -0,0 +1,53 @@
+"""Preprocess tools.
+
+This package provides a thin, stable import surface for common preprocess components.
+
+Examples:
+ from preprocess.tools import (
+ F0Extractor,
+ PitchExtractor,
+ VocalDetectionModel,
+ VocalSeparationModel,
+ VocalExtractionModel,
+ NoteTranscriptionModel,
+ LyricTranscriptionModel,
+ )
+
+Note:
+ Keep these imports lightweight. If a tool pulls heavy dependencies at import time,
+ consider switching to lazy imports.
+"""
+
+from __future__ import annotations
+
+# Core tools
+from .f0_extraction import F0Extractor
+from .vocal_detection import VocalDetector
+
+# Some tools may live outside this package in different layouts across branches.
+# Keep the public surface stable while avoiding hard import failures.
+try:
+ from .vocal_separation.model import VocalSeparator # type: ignore
+except Exception: # pragma: no cover
+ VocalSeparator = None # type: ignore
+
+try:
+ from .note_transcription.model import NoteTranscriber # type: ignore
+except Exception: # pragma: no cover
+ NoteTranscriber = None # type: ignore
+try:
+ from .lyric_transcription import LyricTranscriber
+except Exception: # pragma: no cover
+ LyricTranscriber = None # type: ignore
+
+__all__ = [
+ "F0Extractor",
+ "VocalDetector",
+]
+
+if VocalSeparator is not None:
+ __all__.append("VocalSeparator")
+if LyricTranscriber is not None:
+ __all__.append("LyricTranscriber")
+if NoteTranscriber is not None:
+ __all__.append("NoteTranscriber")
diff --git a/preprocess/tools/f0_extraction.py b/preprocess/tools/f0_extraction.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ec2cf7a991c2780b60f7f3817749c1d2a7af22e
--- /dev/null
+++ b/preprocess/tools/f0_extraction.py
@@ -0,0 +1,527 @@
+# https://github.com/Dream-High/RMVPE
+import math
+import time
+import librosa
+import numpy as np
+from librosa.filters import mel
+from scipy.interpolate import interp1d
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BiGRU(nn.Module):
+ def __init__(self, input_features, hidden_features, num_layers):
+ super(BiGRU, self).__init__()
+ self.gru = nn.GRU(
+ input_features,
+ hidden_features,
+ num_layers=num_layers,
+ batch_first=True,
+ bidirectional=True,
+ )
+
+ def forward(self, x):
+ return self.gru(x)[0]
+
+
+class ConvBlockRes(nn.Module):
+ def __init__(self, in_channels, out_channels, momentum=0.01):
+ super(ConvBlockRes, self).__init__()
+ self.conv = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(3, 3),
+ stride=(1, 1),
+ padding=(1, 1),
+ bias=False,
+ ),
+ nn.BatchNorm2d(out_channels, momentum=momentum),
+ nn.ReLU(),
+ nn.Conv2d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=(3, 3),
+ stride=(1, 1),
+ padding=(1, 1),
+ bias=False,
+ ),
+ nn.BatchNorm2d(out_channels, momentum=momentum),
+ nn.ReLU(),
+ )
+ if in_channels != out_channels:
+ self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
+
+ def forward(self, x):
+ if not hasattr(self, "shortcut"):
+ return self.conv(x) + x
+ else:
+ return self.conv(x) + self.shortcut(x)
+
+
+class ResEncoderBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
+ super(ResEncoderBlock, self).__init__()
+ self.n_blocks = n_blocks
+ self.conv = nn.ModuleList()
+ self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
+ for i in range(n_blocks - 1):
+ self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
+ self.kernel_size = kernel_size
+ if self.kernel_size is not None:
+ self.pool = nn.AvgPool2d(kernel_size=kernel_size)
+
+ def forward(self, x):
+ for conv in self.conv:
+ x = conv(x)
+ if self.kernel_size is not None:
+ return x, self.pool(x)
+ else:
+ return x
+
+
+class Encoder(nn.Module):
+ def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
+ super(Encoder, self).__init__()
+ self.n_encoders = n_encoders
+ self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
+ self.layers = nn.ModuleList()
+ self.latent_channels = []
+ for i in range(self.n_encoders):
+ self.layers.append(
+ ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum)
+ )
+ self.latent_channels.append([out_channels, in_size])
+ in_channels = out_channels
+ out_channels *= 2
+ in_size //= 2
+ self.out_size = in_size
+ self.out_channel = out_channels
+
+ def forward(self, x):
+ concat_tensors = []
+ x = self.bn(x)
+ for layer in self.layers:
+ t, x = layer(x)
+ concat_tensors.append(t)
+ return x, concat_tensors
+
+
+class Intermediate(nn.Module):
+ def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
+ super(Intermediate, self).__init__()
+ self.n_inters = n_inters
+ self.layers = nn.ModuleList()
+ self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
+ for i in range(self.n_inters - 1):
+ self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))
+
+ def forward(self, x):
+ for layer in self.layers:
+ x = layer(x)
+ return x
+
+
+class ResDecoderBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
+ super(ResDecoderBlock, self).__init__()
+ out_padding = (0, 1) if stride == (1, 2) else (1, 1)
+ self.n_blocks = n_blocks
+ self.conv1 = nn.Sequential(
+ nn.ConvTranspose2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(3, 3),
+ stride=stride,
+ padding=(1, 1),
+ output_padding=out_padding,
+ bias=False,
+ ),
+ nn.BatchNorm2d(out_channels, momentum=momentum),
+ nn.ReLU(),
+ )
+ self.conv2 = nn.ModuleList()
+ self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
+ for i in range(n_blocks - 1):
+ self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
+
+ def forward(self, x, concat_tensor):
+ x = self.conv1(x)
+ x = torch.cat((x, concat_tensor), dim=1)
+ for conv2 in self.conv2:
+ x = conv2(x)
+ return x
+
+
+class Decoder(nn.Module):
+ def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
+ super(Decoder, self).__init__()
+ self.layers = nn.ModuleList()
+ self.n_decoders = n_decoders
+ for i in range(self.n_decoders):
+ out_channels = in_channels // 2
+ self.layers.append(
+ ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
+ )
+ in_channels = out_channels
+
+ def forward(self, x, concat_tensors):
+ for i, layer in enumerate(self.layers):
+ x = layer(x, concat_tensors[-1 - i])
+ return x
+
+
+class DeepUnet(nn.Module):
+ def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
+ super(DeepUnet, self).__init__()
+ self.encoder = Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels)
+ self.intermediate = Intermediate(
+ self.encoder.out_channel // 2,
+ self.encoder.out_channel,
+ inter_layers,
+ n_blocks,
+ )
+ self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
+
+ def forward(self, x):
+ x, concat_tensors = self.encoder(x)
+ x = self.intermediate(x)
+ x = self.decoder(x, concat_tensors)
+ return x
+
+
+class E2E(nn.Module):
+ def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
+ super(E2E, self).__init__()
+ self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
+ self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
+ if n_gru:
+ self.fc = nn.Sequential(
+ BiGRU(3 * 128, 256, n_gru),
+ nn.Linear(512, 360),
+ nn.Dropout(0.25),
+ nn.Sigmoid(),
+ )
+ else:
+ self.fc = nn.Sequential(
+ nn.Linear(3 * 128, 360),
+ nn.Dropout(0.25),
+ nn.Sigmoid()
+ )
+
+ def forward(self, mel):
+ mel = mel.transpose(-1, -2).unsqueeze(1)
+ x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
+ x = self.fc(x)
+ return x
+
+
+
+class MelSpectrogram(torch.nn.Module):
+ def __init__(self, is_half, n_mel_channels, sampling_rate, win_length, hop_length,
+ n_fft=None, mel_fmin=0, mel_fmax=None, clamp=1e-5):
+ super().__init__()
+ n_fft = win_length if n_fft is None else n_fft
+ self.hann_window = {}
+ mel_basis = mel(
+ sr=sampling_rate,
+ n_fft=n_fft,
+ n_mels=n_mel_channels,
+ fmin=mel_fmin,
+ fmax=mel_fmax,
+ htk=True,
+ )
+ mel_basis = torch.from_numpy(mel_basis).float()
+ self.register_buffer("mel_basis", mel_basis)
+ self.n_fft = win_length if n_fft is None else n_fft
+ self.hop_length = hop_length
+ self.win_length = win_length
+ self.sampling_rate = sampling_rate
+ self.n_mel_channels = n_mel_channels
+ self.clamp = clamp
+ self.is_half = is_half
+
+ def forward(self, audio, keyshift=0, speed=1, center=True):
+ factor = 2 ** (keyshift / 12)
+ n_fft_new = int(np.round(self.n_fft * factor))
+ win_length_new = int(np.round(self.win_length * factor))
+ hop_length_new = int(np.round(self.hop_length * speed))
+
+ keyshift_key = str(keyshift) + "_" + str(audio.device)
+ if keyshift_key not in self.hann_window:
+ self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
+
+ fft = torch.stft(
+ audio,
+ n_fft=n_fft_new,
+ hop_length=hop_length_new,
+ win_length=win_length_new,
+ window=self.hann_window[keyshift_key],
+ center=center,
+ return_complex=True,
+ )
+ magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
+
+ if keyshift != 0:
+ size = self.n_fft // 2 + 1
+ resize = magnitude.size(1)
+ if resize < size:
+ magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
+ magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
+
+ mel_output = torch.matmul(self.mel_basis, magnitude)
+ if self.is_half:
+ mel_output = mel_output.half()
+ log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
+ return log_mel_spec
+
+
+
+class RMVPE:
+ def __init__(self, model_path: str, is_half, device=None):
+ self.is_half = is_half
+ if device is None:
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ self.device = torch.device(device) if isinstance(device, str) else device
+
+ self.mel_extractor = MelSpectrogram(
+ is_half=is_half,
+ n_mel_channels=128,
+ sampling_rate=16000,
+ win_length=1024,
+ hop_length=160,
+ n_fft=None,
+ mel_fmin=30,
+ mel_fmax=8000
+ ).to(self.device)
+
+ model = E2E(n_blocks=4, n_gru=1, kernel_size=(2, 2))
+ ckpt = torch.load(model_path, map_location=self.device)
+ model.load_state_dict(ckpt)
+ model.eval()
+
+ if is_half:
+ model = model.half()
+ else:
+ model = model.float()
+
+ self.model = model.to(self.device)
+
+ cents_mapping = 20 * np.arange(360) + 1997.3794084376191
+ self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368
+
+ def mel2hidden(self, mel):
+ with torch.no_grad():
+ n_frames = mel.shape[-1]
+ n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
+ if n_pad > 0:
+ mel = F.pad(mel, (0, n_pad), mode="constant")
+ mel = mel.half() if self.is_half else mel.float()
+ hidden = self.model(mel)
+ return hidden[:, :n_frames]
+
+ def decode(self, hidden, thred=0.03):
+ cents_pred = self.to_local_average_cents(hidden, thred=thred)
+ f0 = 10 * (2 ** (cents_pred / 1200))
+ f0[f0 == 10] = 0
+ return f0
+
+ def infer_from_audio(self, audio, thred=0.03):
+ if not torch.is_tensor(audio):
+ audio = torch.from_numpy(audio)
+
+ mel = self.mel_extractor(audio.float().to(self.device).unsqueeze(0), center=True)
+ hidden = self.mel2hidden(mel)
+ hidden = hidden.squeeze(0).cpu().numpy()
+
+ if self.is_half:
+ hidden = hidden.astype("float32")
+
+ f0 = self.decode(hidden, thred=thred)
+ return f0
+
+ def to_local_average_cents(self, salience, thred=0.05):
+ center = np.argmax(salience, axis=1)
+ salience = np.pad(salience, ((0, 0), (4, 4)))
+ center += 4
+
+ todo_salience = []
+ todo_cents_mapping = []
+ starts = center - 4
+ ends = center + 5
+
+ for idx in range(salience.shape[0]):
+ todo_salience.append(salience[:, starts[idx]:ends[idx]][idx])
+ todo_cents_mapping.append(self.cents_mapping[starts[idx]:ends[idx]])
+
+ todo_salience = np.array(todo_salience)
+ todo_cents_mapping = np.array(todo_cents_mapping)
+ product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
+ weight_sum = np.sum(todo_salience, 1)
+ devided = product_sum / weight_sum
+
+ maxx = np.max(salience, axis=1)
+ devided[maxx <= thred] = 0
+
+ return devided
+
+class F0Extractor:
+ """Extract frame-level f0 from singing voice.
+
+ Wrapper around an RMVPE network that:
+ 1) loads the checkpoint once in ``__init__``
+ 2) exposes a simple :py:meth:`process` API and optionally saves ``*_f0.npy``.
+ """
+ def __init__(
+ self,
+ model_path: str,
+ device: str = "cpu",
+ *,
+ is_half: bool = False,
+ input_sr: int = 16000,
+ target_sr: int = 24000,
+ hop_size: int = 480,
+ max_duration: float = 300,
+ thred: float = 0.03,
+ verbose: bool = True,
+ ):
+ """Initialize the f0 extractor.
+
+ Args:
+ model_path: Path to RMVPE checkpoint.
+ device: Torch device string, e.g. ``"cuda:0"`` / ``"cpu"``.
+ is_half: Whether to run the model in fp16.
+ input_sr: Input resample rate used by RMVPE frontend.
+ target_sr: Target sample rate for the output f0 grid.
+ hop_size: Target hop size for the output f0 grid.
+ max_duration: Max duration (seconds) for interpolation grid.
+ thred: Voicing threshold used when decoding salience.
+ verbose: Whether to print verbose logs.
+ """
+ self.model_path = model_path
+ self.input_sr = input_sr
+ self.target_sr = target_sr
+ self.hop_size = hop_size
+ self.max_duration = max_duration
+ self.thred = thred
+
+ self.verbose = verbose
+
+ self.model = RMVPE(model_path, is_half=is_half, device=device)
+
+ if self.verbose:
+ print(
+ "[f0 extraction] init success:",
+ f"device={device}",
+ f"model_path={model_path}",
+ f"is_half={is_half}",
+ f"input_sr={input_sr}",
+ f"target_sr={target_sr}",
+ f"hop_size={hop_size}",
+ f"thred={thred}",
+ )
+
+ @staticmethod
+ def interpolate_f0(
+ f0_16k: np.ndarray,
+ original_length: int,
+ original_sr: int,
+ *,
+ target_sr: int = 48000,
+ hop_size: int = 256,
+ max_duration: float = 20.0,
+ ) -> np.ndarray:
+ """Interpolate f0 from RMVPE's 16k hop grid to target mel hop grid."""
+ mel_target_sr = target_sr
+ mel_hop_size = hop_size
+ mel_max_duration = max_duration
+
+ batch_max_length = int(mel_max_duration * mel_target_sr / mel_hop_size)
+ duration_in_seconds = original_length / original_sr
+ effective_target_length = int(duration_in_seconds * mel_target_sr)
+ original_frames = math.ceil(effective_target_length / mel_hop_size)
+ target_frames = min(original_frames, batch_max_length)
+
+ rmvpe_hop = 160
+ t_16k = np.arange(len(f0_16k)) * (rmvpe_hop / 16000.0)
+ t_target = np.arange(target_frames) * (mel_hop_size / float(mel_target_sr))
+
+ if len(f0_16k) > 0:
+ f_interp = interp1d(
+ t_16k,
+ f0_16k,
+ kind="linear",
+ bounds_error=False,
+ fill_value=0.0,
+ assume_sorted=True,
+ )
+ f0 = f_interp(t_target)
+ else:
+ f0 = np.zeros(target_frames)
+
+ if len(f0) != target_frames:
+ f0 = (
+ f0[:target_frames]
+ if len(f0) > target_frames
+ else np.pad(f0, (0, target_frames - len(f0)), "constant")
+ )
+
+ return f0
+
+ def process(self, audio_path: str, *, f0_path: str | None = None, verbose: Optional[bool] = None) -> np.ndarray:
+ """Run f0 extraction for a single wav.
+
+ Args:
+ audio_path: Path to the input wav file.
+ f0_path: if is not None, save the f0 data to this path.
+ verbose: Override instance-level verbose flag for this call.
+
+ Returns:
+ np.ndarray: shape ``[T]``, f0 in Hz (0 for unvoiced).
+ """
+ verbose = self.verbose if verbose is None else verbose
+ if verbose:
+ print(f"[f0 extraction] process: start: {audio_path}")
+ t0 = time.time()
+
+ audio, _ = librosa.load(audio_path, sr=self.input_sr)
+ f0_16k = self.model.infer_from_audio(audio, thred=self.thred)
+ f0 = self.interpolate_f0(
+ f0_16k,
+ original_length=audio.shape[-1],
+ original_sr=self.input_sr,
+ target_sr=self.target_sr,
+ hop_size=self.hop_size,
+ max_duration=self.max_duration,
+ )
+
+ if verbose:
+ dt = time.time() - t0
+ voiced_ratio = float(np.mean(f0 > 0)) if len(f0) else 0.0
+ print(
+ "[f0 extraction] process: done:",
+ f"frames={len(f0)}",
+ f"voiced_ratio={voiced_ratio:.3f}",
+ f"time={dt:.3f}s",
+ )
+ if f0_path is not None:
+ np.save(f0_path, f0)
+
+ return f0
+
+
+if __name__ == "__main__":
+ model_path = (
+ "pretrained_models/rmvpe/rmvpe.pt"
+ )
+ audio_path = "./outputs/transcription/test.wav"
+
+ pe = F0Extractor(
+ model_path,
+ device="cuda",
+ )
+ f0 = pe.process(audio_path)
diff --git a/preprocess/tools/g2p.py b/preprocess/tools/g2p.py
new file mode 100644
index 0000000000000000000000000000000000000000..5861d06b3cdec638cd6088605d1cd0f6f392f709
--- /dev/null
+++ b/preprocess/tools/g2p.py
@@ -0,0 +1,72 @@
+import re
+
+import ToJyutping
+from g2pM import G2pM
+from g2p_en import G2p as G2pE
+
+_EN_WORD_RE = re.compile(r"^[A-Za-z]+(?:'[A-Za-z]+)*$")
+_ZH_WORD_RE = re.compile(r"[\u4e00-\u9fff]")
+
+EN_FLAG = "en_"
+YUE_FLAG = "yue_"
+ZH_FLAG = "zh_"
+
+g2p_zh = G2pM()
+g2p_en = G2pE()
+
+
+def is_chinese_char(word: str) -> bool:
+ if len(word) != 1:
+ return False
+ return bool(_ZH_WORD_RE.fullmatch(word))
+
+def is_english_word(word: str) -> bool:
+ if not word:
+ return False
+ return bool(_EN_WORD_RE.fullmatch(word))
+
+def g2p_cantonese(sent):
+ return ToJyutping.get_jyutping_list(sent) # with tone
+
+def g2p_mandarin(sent):
+ return g2p_zh(sent, tone=True, char_split=False)
+
+def g2p_english(word):
+ return g2p_en(word)
+
+def g2p_transform(words, lang):
+
+ zh_words = []
+ transformed_words = [0] * len(words)
+
+ for idx, w in enumerate(words):
+ if w == "":
+ transformed_words[idx] = w
+ continue
+
+ w = w.replace("?", "").replace(".", "").replace("!", "").replace(",", "")
+
+ if is_chinese_char(w):
+ zh_words.append([idx, w])
+ else:
+ if is_english_word(w):
+ w = EN_FLAG + "-".join(g2p_english(w.lower()))
+ else:
+ w = ""
+ transformed_words[idx] = w
+
+ sent = "".join([k[1] for k in zh_words])
+
+ # zh (zh and yue) transformer to g2p
+ if len(sent) > 0:
+ if lang == "Cantonese":
+ g2pm_rst = g2p_cantonese(sent) # with tone
+ g2pm_rst = [YUE_FLAG + k[1] for k in g2pm_rst]
+ else:
+ g2pm_rst = g2p_mandarin(sent)
+ g2pm_rst = [ZH_FLAG + k for k in g2pm_rst]
+ for p, w in zip([k[0] for k in zh_words], g2pm_rst):
+ transformed_words[p] = w
+
+ return transformed_words
+
diff --git a/preprocess/tools/lyric_transcription.py b/preprocess/tools/lyric_transcription.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b5d2da69a8c50e0c8d2adbf65df76b3043f6d4
--- /dev/null
+++ b/preprocess/tools/lyric_transcription.py
@@ -0,0 +1,279 @@
+# https://modelscope.cn/models/iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary
+# https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2
+import os
+import re
+import time
+from typing import Any, Dict, List, Tuple
+
+import librosa
+import numpy as np
+from funasr import AutoModel
+
+
+def _build_words_with_gaps(raw_words, raw_timestamps, wav_fn: str):
+ words, word_durs = [], []
+ prev = 0.0
+ for w, t in zip(raw_words, raw_timestamps):
+ s, e = float(t[0]), float(t[1])
+ if s > prev:
+ words.append("")
+ word_durs.append(s - prev)
+ words.append(w)
+ word_durs.append(e - s)
+ prev = e
+
+ wav_len = librosa.get_duration(filename=wav_fn)
+ if wav_len > prev:
+ if len(words) == 0:
+ words.append("")
+ word_durs.append(wav_len)
+ return words, word_durs
+ if words[-1] != "":
+ words.append("")
+ word_durs.append(wav_len - prev)
+ else:
+ word_durs[-1] += wav_len - prev
+
+ return words, word_durs
+
+def _word_dur_post_process(words, word_durs, f0):
+ """Post-process word durations using f0 to better place silences.
+ """
+ # f0 time grid parameters
+ sr = 24000 # f0 sample rate
+ hop_length = 480 # f0 hop length
+
+ # Convert word durations (seconds) to frame boundaries on the f0 grid.
+ boundaries = np.cumsum([
+ 0,
+ *[
+ int(dur * sr / hop_length)
+ for dur in word_durs
+ ],
+ ]).tolist()
+
+ sil_tolerance = 5 # tolerance frames for silence detection
+ ext_tolerance = 5 # tolerance frames for vocal extension
+
+ new_words: list[str] = []
+ new_word_durs: list[float] = []
+ if words:
+ new_words.append(words[0])
+ new_word_durs.append(word_durs[0])
+
+ for i in range(1, len(words)):
+ word = words[i]
+ if word == "":
+ start_frame = boundaries[i]
+ end_frame = boundaries[i + 1]
+
+ num_frames = end_frame - start_frame
+ frame_idx = start_frame
+
+ # Find first region with at least 5 consecutive "unvoiced" frames.
+ unvoiced_count = 0
+ while frame_idx < end_frame:
+ if f0[frame_idx] <= 1: # unvoiced
+ unvoiced_count += 1
+ if unvoiced_count >= sil_tolerance:
+ frame_idx -= sil_tolerance - 1 # back to the last voiced frame
+ break
+ else:
+ unvoiced_count = 0
+ frame_idx += 1
+
+ voice_frames = frame_idx - start_frame
+
+ if voice_frames >= int(num_frames * 0.9): # over 90% voiced
+ # Treat the whole "" as silence and merge into previous word.
+ new_word_durs[-1] += word_durs[i]
+ elif voice_frames >= ext_tolerance: # over 5 frames voiced
+ # Split the "" into two parts: leading silence and tail kept as "".
+ dur = voice_frames * hop_length / sr
+ new_word_durs[-1] += dur
+ new_words.append("")
+ new_word_durs.append(word_durs[i] - dur)
+ else:
+ # Too short to adjust, keep as-is.
+ new_words.append(word)
+ new_word_durs.append(word_durs[i])
+ else:
+ new_words.append(word)
+ new_word_durs.append(word_durs[i])
+
+ return new_words, new_word_durs
+
+
+class _ASRZhModel:
+ """Mandarin/Cantonese ASR wrapper."""
+
+ def __init__(self, model_path: str, device: str):
+ self.model = AutoModel(
+ model=model_path,
+ disable_update=True,
+ device=device,
+ )
+
+ def process(self, wav_fn):
+ out = self.model.generate(wav_fn, output_timestamp=True)[0]
+ raw_words = out["text"].replace("@", "").split(" ")
+ raw_timestamps = [[t[0] / 1000, t[1] / 1000] for t in out["timestamp"]]
+ words, word_durs = _build_words_with_gaps(raw_words, raw_timestamps, wav_fn)
+
+ if os.path.exists(wav_fn.replace(".wav", "_f0.npy")):
+ words, word_durs = _word_dur_post_process(
+ words, word_durs, np.load(wav_fn.replace(".wav", "_f0.npy"))
+ )
+
+ return words, word_durs
+
+
+class _ASREnModel:
+ """English ASR wrapper for NeMo Parakeet-TDT."""
+
+ def __init__(self, model_path: str, device: str):
+ try:
+ import nemo.collections.asr as nemo_asr # type: ignore
+ except Exception as e: # pragma: no cover
+ raise ImportError(
+ "NeMo (nemo_toolkit) is required for ASR English but is not available in this Python env. "
+ "Install it in the active environment, then retry."
+ ) from e
+
+ self.model = nemo_asr.models.ASRModel.restore_from(
+ restore_path=model_path,
+ map_location=device,
+ )
+ self.model.eval()
+
+ @staticmethod
+ def _clean_word(word: str) -> str:
+ return re.sub(r"[\?\.,:]", "", word).strip()
+
+ @staticmethod
+ def _extract_word_segments(output: Any) -> List[Dict[str, Any]]:
+ ts = getattr(output, "timestamp", None)
+ if not ts or not isinstance(ts, dict):
+ return []
+ word_ts = ts.get("word")
+ return word_ts if isinstance(word_ts, list) else []
+
+ def process(self, wav_fn: str) -> Tuple[List[str], List[float]]:
+ outputs = self.model.transcribe(
+ [wav_fn],
+ timestamps=True,
+ batch_size=1,
+ num_workers=0,
+ )
+ output = outputs[0] if outputs else None
+
+ raw_words: List[str] = []
+ raw_timestamps: List[List[float]] = []
+ if output is not None:
+ for w in self._extract_word_segments(output):
+ s, e = float(w.get("start", 0.0)), float(w.get("end", 0.0))
+ word = self._clean_word(str(w.get("word", "")))
+ if word:
+ raw_words.append(word)
+ raw_timestamps.append([s, e])
+
+ words, durs = _build_words_with_gaps(raw_words, raw_timestamps, wav_fn)
+
+ if os.path.exists(wav_fn.replace(".wav", "_f0.npy")):
+ words, durs = _word_dur_post_process(
+ words, durs, np.load(wav_fn.replace(".wav", "_f0.npy"))
+ )
+
+ return words, durs
+
+
+class LyricTranscriber:
+ """Transcribe lyrics from singing voice segment
+ """
+
+ def __init__(
+ self,
+ zh_model_path: str,
+ en_model_path: str,
+ device: str = "cuda",
+ *,
+ verbose: bool = True,
+ ):
+ """Initialize lyric transcriber.
+
+ Args:
+ zh_model_path (str): Path to the Chinese model file.
+ en_model_path (str): Path to the English model file.
+ device (str): Device to use for tensor operations.
+ verbose (bool): Whether to print verbose logs.
+ """
+ self.verbose = verbose
+ self.device = device
+ self.zh_model_path = zh_model_path
+ self.en_model_path = en_model_path
+
+ if self.verbose:
+ print(
+ "[lyric transcription] init: start:",
+ f"device={device}",
+ f"model_path={zh_model_path}",
+ )
+
+ # Always initialize Chinese ASR.
+ self.zh_model = _ASRZhModel(device=device, model_path=zh_model_path)
+
+ # English ASR will be lazily initialized on first English request to avoid long waiting cost when importing NeMo
+ self.en_model = None
+
+ if self.verbose:
+ print("[lyric transcription] init: success")
+
+ def process(self, wav_fn, language: str | None = "Mandarin", *, verbose: bool | None = None):
+ """ Lyric transcriber process
+
+ Args:
+ wav_fn (str): Path to the audio file.
+ language (str | None): Language of the audio. Defaults to "Mandarin". Supports "Mandarin", "Cantonese" and "English".
+ verbose (bool | None): Whether to print verbose logs. Defaults to None.
+ """
+ v = self.verbose if verbose is None else verbose
+ if language not in {"Mandarin", "Cantonese", "English"}:
+ raise ValueError(f"Unsupported language: {language}, should be one of ['Mandarin', 'Cantonese', 'English']")
+ if v:
+ print(f"[lyric transcription] process: start: wav_fn={wav_fn} language={language}")
+ t0 = time.time()
+
+ lang = (language or "auto").lower()
+ if lang in {"english"}:
+ if self.en_model is None:
+ # Lazy-load NeMo model only when English is actually used.
+ if v:
+ print("[lyric transcription] init English ASR, please make sure NeMo is installed")
+ self.en_model = _ASREnModel(model_path=self.en_model_path, device=self.device)
+ out = self.en_model.process(wav_fn)
+ else:
+ out = self.zh_model.process(wav_fn)
+
+ if v:
+ words, durs = out
+ n_words = len(words) if isinstance(words, list) else 0
+ dur_sum = float(sum(durs)) if isinstance(durs, list) else 0.0
+ dt = time.time() - t0
+ print(
+ "[lyric transcription] process: done:",
+ f"n_words={n_words}",
+ f"dur_sum={dur_sum:.3f}s",
+ f"time={dt:.3f}s",
+ )
+
+ return out
+
+
+if __name__ == "__main__":
+ m = LyricTranscriber(
+ zh_model_path="pretrained_models/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+ en_model_path="pretrained_models/parakeet-tdt-0.6b-v2/parakeet-tdt-0.6b-v2.nemo",
+ device="cuda"
+ )
+ print(m.process("example/test/asr_zh.wav", language="Mandarin"))
+ print(m.process("example/test/asr_en.wav", language="English"))
\ No newline at end of file
diff --git a/preprocess/tools/midi_parser.py b/preprocess/tools/midi_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c0ed3b79776855da4bf4c3ccf6766c0bd4a88d8
--- /dev/null
+++ b/preprocess/tools/midi_parser.py
@@ -0,0 +1,669 @@
+"""
+SoulX-Singer MIDI <-> metadata converter.
+
+Converts between SoulX-Singer-style metadata JSON (with note_text, note_dur,
+note_pitch, note_type per segment) and standard MIDI files. Uses an internal
+Note dataclass (start_s, note_dur, note_text, note_pitch, note_type) as the
+intermediate representation.
+"""
+import os
+import json
+import shutil
+from dataclasses import dataclass
+from typing import Any, List, Tuple, Union
+
+import librosa
+import mido
+from soundfile import write
+
+from .f0_extraction import F0Extractor
+from .g2p import g2p_transform
+
+
+# Audio and segmenting constants (used by _edit_data_to_meta)
+SAMPLE_RATE = 44100
+DEFAULT_LANGUAGE = "Mandarin"
+MAX_GAP_SEC = 5.0 # gap (sec) above which we start a new segment
+MAX_SEGMENT_DUR_SUM_SEC = 60.0 # max cumulative note duration per segment (sec)
+MIN_GAP_THRESHOLD_SEC = 0.001 # ignore gaps smaller than this
+LONG_SILENCE_THRESHOLD_SEC = 0.05 # treat as separate if gap larger
+MAX_LEADING_SP_DUR_SEC = 2.0 # cap leading silence in a segment to this (sec)
+DEFAULT_RMVPE_MODEL_PATH = "pretrained_models/SoulX-Singer-Preprocess/rmvpe/rmvpe.pt"
+
+
+@dataclass
+class Note:
+ """Single note: text, duration (seconds), pitch (MIDI), type. start_s is absolute start time in seconds (for ordering / MIDI)."""
+ start_s: float
+ note_dur: float
+ note_text: str
+ note_pitch: int
+ note_type: int
+
+ @property
+ def end_s(self) -> float:
+ return self.start_s + self.note_dur
+
+
+
+def remove_duplicate_segments(meta_data: List[dict]) -> None:
+ """Merge consecutive identical notes (same text, pitch, type) within each segment. Mutates meta_data in place."""
+ for idx, segment in enumerate(meta_data):
+ texts = segment["note_text"]
+ durs = segment["note_dur"]
+ pitches = segment["note_pitch"]
+ types = segment["note_type"]
+ new_texts = []
+ new_durs = []
+ new_pitches = []
+ new_types = []
+ for i in range(len(texts)):
+ if i == 0:
+ new_texts.append(texts[i])
+ new_durs.append(durs[i])
+ new_pitches.append(pitches[i])
+ new_types.append(types[i])
+ continue
+ t, d, p, ty = texts[i], durs[i], pitches[i], types[i]
+ if t == "" and texts[i - 1] == "":
+ new_durs[-1] += d
+ continue
+ if t == texts[i - 1] and p == pitches[i - 1] and ty == types[i - 1]:
+ new_durs[-1] += d
+ else:
+ new_texts.append(t)
+ new_durs.append(d)
+ new_pitches.append(p)
+ new_types.append(ty)
+ meta_data[idx]["note_text"] = new_texts
+ meta_data[idx]["note_dur"] = new_durs
+ meta_data[idx]["note_pitch"] = new_pitches
+ meta_data[idx]["note_type"] = new_types
+
+def meta2notes(meta_path: str) -> List[Note]:
+ """Parse SoulX-Singer metadata JSON into a flat list of Note (absolute start_s)."""
+ with open(meta_path, "r", encoding="utf-8") as f:
+ segments = json.load(f)
+ if not isinstance(segments, list):
+ raise ValueError(f"Metadata must be a list of segments, got {type(segments).__name__}")
+ if not segments:
+ raise ValueError("Metadata has no segments.")
+
+ notes: List[Note] = []
+ for seg in segments:
+ offset_s = seg["time"][0] / 1000
+ words = [str(x).replace("", "") for i, x in enumerate(seg["text"].split())]
+ word_durs = [float(x) for x in seg["duration"].split()]
+ pitches = [int(x) for x in seg["note_pitch"].split()]
+ types = [int(x) if words[i] != "" else 1 for i, x in enumerate(seg["note_type"].split())]
+ if len(words) != len(word_durs) or len(word_durs) != len(pitches) or len(pitches) != len(types):
+ raise ValueError(
+ f"Length mismatch in segment {seg.get('item_name', '?')}: "
+ "note_text, note_dur, note_pitch, note_type must have same length"
+ )
+ current_s = offset_s
+ for text, dur, pitch, type_ in zip(words, word_durs, pitches, types):
+ notes.append(
+ Note(
+ start_s=current_s,
+ note_dur=float(dur),
+ note_text=str(text),
+ note_pitch=int(pitch),
+ note_type=int(type_),
+ )
+ )
+ current_s += float(dur)
+ return notes
+
+def _append_segment_to_meta(
+ meta_path_str: str,
+ cut_wavs_output_dir: str,
+ vocal_file: str,
+ audio_data: Any,
+ meta_data: List[dict],
+ note_start: List[float],
+ note_end: List[float],
+ note_text: List[Any],
+ note_pitch: List[Any],
+ note_type: List[Any],
+ note_dur: List[float],
+ end_time_ms_override: float | None = None,
+) -> None:
+ """Write one segment wav and append one segment dict to meta_data. Caller clears note_* lists after."""
+ base_name = os.path.splitext(os.path.basename(meta_path_str))[0]
+ item_name = f"{base_name}_{len(meta_data)}"
+ wav_fn = os.path.join(cut_wavs_output_dir, f"{item_name}.wav")
+ start_ms = int(note_start[0] * 1000)
+ end_ms = (
+ int(end_time_ms_override)
+ if end_time_ms_override is not None
+ else int(note_end[-1] * 1000)
+ )
+ start_sample = int(note_start[0] * SAMPLE_RATE)
+ end_sample = int(note_end[-1] * SAMPLE_RATE)
+ write(wav_fn, audio_data[start_sample:end_sample], SAMPLE_RATE)
+ meta_data.append({
+ "item_name": item_name,
+ "wav_fn": wav_fn,
+ "origin_wav_fn": vocal_file,
+ "start_time_ms": start_ms,
+ "end_time_ms": end_ms,
+ "language": DEFAULT_LANGUAGE,
+ "note_text": list(note_text),
+ "note_pitch": list(note_pitch),
+ "note_type": list(note_type),
+ "note_dur": list(note_dur),
+ })
+
+
+def convert_meta(meta_data: List[dict], rmvpe_model_path, device="cuda"):
+ pitch_extractor = F0Extractor(rmvpe_model_path, device=device, verbose=False)
+ converted_data = []
+
+ for item in meta_data:
+ wav_fn = item.get("wav_fn")
+ if not wav_fn or not os.path.isfile(wav_fn):
+ raise FileNotFoundError(f"Segment wav file not found: {wav_fn}")
+ f0 = pitch_extractor.process(wav_fn)
+ converted_item = {
+ "index": item.get("item_name"),
+ "language": item.get("language"),
+ "time": [item.get("start_time_ms", 0), item.get("end_time_ms", sum(item["note_dur"]) * 1000)],
+ "duration": " ".join(str(round(x, 2)) for x in item.get("note_dur", [])),
+ "text": " ".join(item.get("note_text", [])),
+ "phoneme": " ".join(g2p_transform(item.get("note_text", []), DEFAULT_LANGUAGE)),
+ "note_pitch": " ".join(str(x) for x in item.get("note_pitch", [])),
+ "note_type": " ".join(str(x) for x in item.get("note_type", [])),
+ "f0": " ".join(str(round(float(x), 1)) for x in f0),
+ }
+ converted_data.append(converted_item)
+
+ return converted_data
+
+
+def _edit_data_to_meta(
+ meta_path_str: str,
+ edit_data: List[dict],
+ vocal_file: str,
+ rmvpe_model_path: str | None = None,
+ device: str = "cuda",
+) -> None:
+ """Write SoulX-Singer metadata JSON from edit_data (list of {start, end, note_text, note_pitch, note_type})."""
+ # Use a fixed temporary directory for cut wavs
+ cut_wavs_output_dir = os.path.join(os.path.dirname(vocal_file), "cut_wavs_tmp")
+ os.makedirs(cut_wavs_output_dir, exist_ok=True)
+
+ note_text: List[Any] = []
+ note_pitch: List[Any] = []
+ note_type: List[Any] = []
+ note_dur: List[float] = []
+ note_start: List[float] = []
+ note_end: List[float] = []
+ prev_end = 0.0
+ meta_data: List[dict] = []
+ audio_data, _ = librosa.load(vocal_file, sr=SAMPLE_RATE, mono=True)
+ dur_sum = 0.0
+
+ for entry in edit_data:
+ start = float(entry["start"])
+ end = float(entry["end"])
+ text = entry["note_text"]
+ pitch = entry["note_pitch"]
+ type_ = entry["note_type"]
+
+ if text == "" or pitch == "" or type_ == "":
+ note_text.append("")
+ note_pitch.append(0)
+ note_type.append(1)
+ note_dur.append(end - start)
+ note_start.append(start)
+ note_end.append(end)
+ prev_end = end
+ dur_sum += end - start
+ continue
+
+ if (
+ len(note_text) > 0
+ and note_text[-1] == ""
+ and note_dur[-1] > MAX_LEADING_SP_DUR_SEC
+ ):
+ cut_time = note_dur[-1] - MAX_LEADING_SP_DUR_SEC
+ note_dur[-1] = MAX_LEADING_SP_DUR_SEC
+ end_ms_override = note_end[-1] * 1000 - cut_time * 1000
+ _append_segment_to_meta(
+ meta_path_str,
+ cut_wavs_output_dir,
+ vocal_file,
+ audio_data,
+ meta_data,
+ note_start,
+ note_end,
+ note_text,
+ note_pitch,
+ note_type,
+ note_dur,
+ end_time_ms_override=end_ms_override,
+ )
+ note_text = []
+ note_pitch = []
+ note_type = []
+ note_dur = []
+ note_start = []
+ note_end = []
+ prev_end = start
+ dur_sum = 0.0
+
+ gap_from_prev = start - prev_end
+ gap_from_last_note = (start - note_end[-1]) if note_end else 0.0
+ if (
+ gap_from_prev >= MAX_GAP_SEC
+ or gap_from_last_note >= MAX_GAP_SEC
+ or dur_sum >= MAX_SEGMENT_DUR_SUM_SEC
+ ):
+ if len(note_text) > 0:
+ _append_segment_to_meta(
+ meta_path_str,
+ cut_wavs_output_dir,
+ vocal_file,
+ audio_data,
+ meta_data,
+ note_start,
+ note_end,
+ note_text,
+ note_pitch,
+ note_type,
+ note_dur,
+ )
+ note_text = []
+ note_pitch = []
+ note_type = []
+ note_dur = []
+ note_start = []
+ note_end = []
+ prev_end = start
+ dur_sum = 0.0
+
+ if start - prev_end > MIN_GAP_THRESHOLD_SEC:
+ if start - prev_end > LONG_SILENCE_THRESHOLD_SEC or len(note_text) == 0:
+ note_text.append("")
+ note_pitch.append(0)
+ note_type.append(1)
+ note_dur.append(start - prev_end)
+ note_start.append(prev_end)
+ note_end.append(start)
+ else:
+ if len(note_dur) > 0:
+ note_dur[-1] += start - prev_end
+ note_end[-1] = start
+
+ prev_end = end
+ note_text.append(text)
+ note_pitch.append(int(pitch))
+ note_type.append(int(type_))
+ note_dur.append(end - start)
+ note_start.append(start)
+ note_end.append(end)
+ dur_sum += end - start
+
+ if len(note_text) > 0:
+ _append_segment_to_meta(
+ meta_path_str,
+ cut_wavs_output_dir,
+ vocal_file,
+ audio_data,
+ meta_data,
+ note_start,
+ note_end,
+ note_text,
+ note_pitch,
+ note_type,
+ note_dur,
+ )
+
+ remove_duplicate_segments(meta_data)
+
+ _rmvpe_path = rmvpe_model_path or DEFAULT_RMVPE_MODEL_PATH
+ converted_data = convert_meta(meta_data, _rmvpe_path, device)
+
+ with open(meta_path_str, "w", encoding="utf-8") as f:
+ json.dump(converted_data, f, ensure_ascii=False, indent=2)
+
+ # Clean up temporary cut wavs directory
+ try:
+ shutil.rmtree(cut_wavs_output_dir, ignore_errors=True)
+ except Exception:
+ pass
+
+
+def notes2meta(
+ notes: List[Note],
+ meta_path: str,
+ vocal_file: str,
+ rmvpe_model_path: str | None = None,
+ device: str = "cuda",
+) -> None:
+ """Write SoulX-Singer metadata JSON from a list of Note (segmenting + wav cuts)."""
+ edit_data = [
+ {
+ "start": n.start_s,
+ "end": n.end_s,
+ "note_text": n.note_text,
+ "note_pitch": str(n.note_pitch),
+ "note_type": str(n.note_type),
+ }
+ for n in notes
+ ]
+ _edit_data_to_meta(
+ str(meta_path),
+ edit_data,
+ vocal_file,
+ rmvpe_model_path=rmvpe_model_path,
+ device=device,
+ )
+
+
+@dataclass(frozen=True)
+class MidiDefaults:
+ ticks_per_beat: int = 500
+ tempo: int = 500000 # microseconds per beat (120 BPM)
+ time_signature: Tuple[int, int] = (4, 4)
+ velocity: int = 64
+
+
+def _seconds_to_ticks(seconds: float, ticks_per_beat: int, tempo: int) -> int:
+ return int(round(seconds * ticks_per_beat * 1_000_000 / tempo))
+
+
+def notes2midi(
+ notes: List[Note],
+ midi_path: str,
+ defaults: MidiDefaults | None = None,
+) -> None:
+ """Write MIDI file from a list of Note."""
+ defaults = defaults or MidiDefaults()
+ if not notes:
+ raise ValueError("Empty note list.")
+
+ events: List[Tuple[int, int, Union[mido.Message, mido.MetaMessage]]] = []
+ for n in notes:
+ start_s = n.start_s
+ end_s = n.end_s
+ if end_s <= start_s:
+ continue
+
+ start_ticks = _seconds_to_ticks(
+ start_s, defaults.ticks_per_beat, defaults.tempo
+ )
+ end_ticks = _seconds_to_ticks(
+ end_s, defaults.ticks_per_beat, defaults.tempo
+ )
+ if end_ticks <= start_ticks:
+ end_ticks = start_ticks + 1
+
+ lyric = n.note_text
+ try:
+ lyric = lyric.encode("utf-8").decode("latin1")
+ except (UnicodeEncodeError, UnicodeDecodeError):
+ pass
+ if n.note_type == 3:
+ lyric = "-"
+
+ events.append(
+ (start_ticks, 1, mido.MetaMessage("lyrics", text=lyric, time=0))
+ )
+ events.append(
+ (
+ start_ticks,
+ 2,
+ mido.Message(
+ "note_on",
+ note=n.note_pitch,
+ velocity=defaults.velocity,
+ time=0,
+ ),
+ )
+ )
+ events.append(
+ (
+ end_ticks,
+ 0,
+ mido.Message("note_off", note=n.note_pitch, velocity=0, time=0),
+ )
+ )
+
+ events.sort(key=lambda x: (x[0], x[1]))
+
+ mid = mido.MidiFile(ticks_per_beat=defaults.ticks_per_beat)
+ track = mido.MidiTrack()
+ mid.tracks.append(track)
+
+ track.append(mido.MetaMessage("set_tempo", tempo=defaults.tempo, time=0))
+ track.append(
+ mido.MetaMessage(
+ "time_signature",
+ numerator=defaults.time_signature[0],
+ denominator=defaults.time_signature[1],
+ time=0,
+ )
+ )
+
+ last_tick = 0
+ for tick, _, msg in events:
+ msg.time = max(0, tick - last_tick)
+ track.append(msg)
+ last_tick = tick
+
+ track.append(mido.MetaMessage("end_of_track", time=0))
+ mid.save(midi_path)
+
+
+def midi2notes(midi_path: str) -> List[Note]:
+ """Parse MIDI file into a list of Note. Merges all tracks; tempo from last set_tempo event."""
+ mid = mido.MidiFile(midi_path)
+ ticks_per_beat = mid.ticks_per_beat
+ tempo = 500000
+
+ raw_notes: List[dict] = []
+ lyrics: List[Tuple[int, str]] = []
+
+ for track in mid.tracks:
+ abs_ticks = 0
+ active = {}
+ for msg in track:
+ abs_ticks += msg.time
+ if msg.type == "set_tempo":
+ tempo = msg.tempo
+ elif msg.type == "lyrics":
+ text = msg.text
+ try:
+ text = text.encode("latin1").decode("utf-8")
+ except Exception:
+ pass
+ lyrics.append((abs_ticks, text))
+ elif msg.type == "note_on":
+ key = (msg.channel, msg.note)
+ if msg.velocity > 0:
+ active[key] = (abs_ticks, msg.velocity)
+ else:
+ if key in active:
+ start_ticks, vel = active.pop(key)
+ raw_notes.append(
+ {
+ "midi": msg.note,
+ "start_ticks": start_ticks,
+ "duration_ticks": abs_ticks - start_ticks,
+ "velocity": vel,
+ "lyric": "",
+ }
+ )
+ elif msg.type == "note_off":
+ key = (msg.channel, msg.note)
+ if key in active:
+ start_ticks, vel = active.pop(key)
+ raw_notes.append(
+ {
+ "midi": msg.note,
+ "start_ticks": start_ticks,
+ "duration_ticks": abs_ticks - start_ticks,
+ "velocity": vel,
+ "lyric": "",
+ }
+ )
+
+ if not raw_notes:
+ raise ValueError("No notes found in MIDI file")
+
+ for n in raw_notes:
+ n["end_ticks"] = n["start_ticks"] + n["duration_ticks"]
+
+ raw_notes.sort(key=lambda n: n["start_ticks"])
+ lyrics.sort(key=lambda x: x[0])
+
+ trimmed = []
+ for note in raw_notes:
+ while trimmed:
+ prev = trimmed[-1]
+ if note["start_ticks"] < prev["end_ticks"]:
+ prev["end_ticks"] = note["start_ticks"]
+ prev["duration_ticks"] = prev["end_ticks"] - prev["start_ticks"]
+ if prev["duration_ticks"] <= 0:
+ trimmed.pop()
+ continue
+ break
+ trimmed.append(note)
+ raw_notes = trimmed
+
+ tolerance = ticks_per_beat // 100
+ lyric_idx = 0
+ for note in raw_notes:
+ while lyric_idx < len(lyrics) and lyrics[lyric_idx][0] < note["start_ticks"] - tolerance:
+ lyric_idx += 1
+ if lyric_idx < len(lyrics):
+ lyric_ticks, lyric_text = lyrics[lyric_idx]
+ if abs(lyric_ticks - note["start_ticks"]) <= tolerance:
+ note["lyric"] = lyric_text
+ lyric_idx += 1
+
+ def ticks_to_seconds(ticks: int) -> float:
+ return (ticks / ticks_per_beat) * (tempo / 1_000_000)
+
+ result: List[Note] = []
+ prev_end_s = 0.0
+ for idx, n in enumerate(raw_notes):
+ start_s = ticks_to_seconds(n["start_ticks"])
+ end_s = ticks_to_seconds(n["end_ticks"])
+ if prev_end_s > start_s:
+ start_s = prev_end_s
+ dur_s = end_s - start_s
+ if dur_s <= 0:
+ continue
+
+ lyric = n.get("lyric", "")
+ if not lyric:
+ tp = 2
+ text = "啦"
+ elif lyric == "":
+ tp = 1
+ text = ""
+ elif lyric == "-":
+ tp = 3
+ text = raw_notes[idx - 1].get("lyric", "-") if idx > 0 else "-"
+ else:
+ tp = 2
+ text = lyric
+
+ result.append(
+ Note(
+ start_s=start_s,
+ note_dur=dur_s,
+ note_text=text,
+ note_pitch=n["midi"],
+ note_type=tp,
+ )
+ )
+ prev_end_s = end_s
+
+ return result
+
+
+def meta2midi(meta_path: str, midi_path: str, defaults: MidiDefaults | None = None) -> None:
+ """Convert SoulX-Singer metadata JSON to MIDI file (meta -> List[Note] -> midi)."""
+ notes = meta2notes(meta_path)
+ notes2midi(notes, midi_path, defaults)
+ print(f"Saved MIDI to {midi_path}")
+
+
+def midi2meta(
+ midi_path: str,
+ meta_path: str,
+ vocal_file: str,
+ rmvpe_model_path: str | None = None,
+ device: str = "cuda",
+) -> None:
+ """Convert MIDI file to SoulX-Singer metadata JSON (midi -> List[Note] -> meta)."""
+ meta_dir = os.path.dirname(meta_path)
+ if meta_dir:
+ os.makedirs(meta_dir, exist_ok=True)
+ # cut_wavs will be written to a fixed temporary directory inside _edit_data_to_meta
+ notes = midi2notes(midi_path)
+ notes2meta(
+ notes,
+ meta_path,
+ vocal_file,
+ rmvpe_model_path=rmvpe_model_path,
+ device=device,
+ )
+ print(f"Saved Meta to {meta_path}")
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ description="Convert SoulX-Singer metadata JSON <-> MIDI."
+ )
+ parser.add_argument("--meta", type=str, help="Path to metadata JSON")
+ parser.add_argument("--midi", type=str, help="Path to MIDI file")
+ parser.add_argument("--vocal", type=str, help="Path to vocal wav (for midi2meta)")
+ parser.add_argument(
+ "--meta2midi",
+ action="store_true",
+ help="Convert meta -> midi (requires --meta and --midi)",
+ )
+ parser.add_argument(
+ "--midi2meta",
+ action="store_true",
+ help="Convert midi -> meta (requires --midi, --meta, --vocal, --cut_wavs_dir)",
+ )
+ parser.add_argument(
+ "--rmvpe_model_path",
+ type=str,
+ help="Path to RMVPE model",
+ default="pretrained_models/SoulX-Singer-Preprocess/rmvpe/rmvpe.pt",
+ )
+ parser.add_argument(
+ "--device",
+ type=str,
+ help="Device to use for RMVPE",
+ default="cuda",
+ )
+ args = parser.parse_args()
+
+ if args.meta2midi:
+ if not args.meta or not args.midi:
+ parser.error("--meta2midi requires --meta and --midi")
+ meta2midi(args.meta, args.midi)
+ elif args.midi2meta:
+ if not args.midi or not args.meta or not args.vocal:
+ parser.error(
+ "--midi2meta requires --midi, --meta, --vocal"
+ )
+ midi2meta(
+ args.midi,
+ args.meta,
+ args.vocal,
+ rmvpe_model_path=args.rmvpe_model_path,
+ device=args.device,
+ )
+ else:
+ parser.print_help()
\ No newline at end of file
diff --git a/preprocess/tools/note_transcription/__init__.py b/preprocess/tools/note_transcription/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/preprocess/tools/note_transcription/model.py b/preprocess/tools/note_transcription/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a20e2c931db3d23e7cf93a0d56bc35883a246c2
--- /dev/null
+++ b/preprocess/tools/note_transcription/model.py
@@ -0,0 +1,522 @@
+# https://github.com/RickyL-2000/ROSVOT
+import math
+import sys
+import traceback
+import json
+import time
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import librosa
+import numpy as np
+import torch
+import matplotlib.pyplot as plt
+
+from .utils.os_utils import safe_path
+from .utils.commons.hparams import set_hparams
+from .utils.commons.ckpt_utils import load_ckpt
+from .utils.commons.dataset_utils import pad_or_cut_xd
+from .utils.audio.mel import MelNet
+from .utils.audio.pitch_utils import (
+ norm_interp_f0,
+ denorm_f0,
+ f0_to_coarse,
+ boundary2Interval,
+ save_midi,
+ midi_to_hz,
+)
+from .utils.rosvot_utils import (
+ get_mel_len,
+ align_word,
+ regulate_real_note_itv,
+ regulate_ill_slur,
+ bd_to_durs,
+)
+from .modules.pe.rmvpe import RMVPE
+from .modules.rosvot.rosvot import MidiExtractor, WordbdExtractor
+
+
+@torch.no_grad()
+def infer_sample(
+ item: Dict[str, Any],
+ hparams: Dict[str, Any],
+ models: Dict[str, Any],
+ device: torch.device,
+ *,
+ save_dir: Optional[str] = None,
+ apply_rwbd: Optional[bool] = None,
+ # outputs
+ save_plot: bool = False,
+ no_save_midi: bool = True,
+ no_save_npy: bool = True,
+ verbose: bool = False,
+) -> Dict[str, Any]:
+ if "item_name" not in item or "wav_fn" not in item:
+ raise ValueError('item must contain keys: "item_name" and "wav_fn"')
+
+ item_name = item["item_name"]
+ wav_src = item["wav_fn"]
+
+ # Decide RWBD usage
+ if apply_rwbd is None:
+ apply_rwbd_ = ("word_durs" not in item)
+ else:
+ apply_rwbd_ = bool(apply_rwbd)
+
+ # Models
+ model = models["model"]
+ mel_net = models["mel_net"]
+ pe = models.get("pe")
+ wbd_predictor = models.get("wbd_predictor")
+
+ if wbd_predictor is None and apply_rwbd_:
+ raise ValueError("apply_rwbd is True but wbd_predictor model is not provided in models")
+
+ # ---- Prepare Data ----
+ if isinstance(wav_src, str):
+ wav, _ = librosa.core.load(wav_src, sr=hparams["audio_sample_rate"])
+ else:
+ wav = wav_src
+ if not isinstance(wav, np.ndarray):
+ wav = np.asarray(wav)
+ wav = wav.astype(np.float32)
+
+ # Calculate timestamps and alignment lengths
+ wav_len_samples = wav.shape[-1]
+ mel_len = get_mel_len(wav_len_samples, hparams["hop_size"])
+
+ # Word boundary preparation
+ mel2word = None
+ word_durs_filtered = None
+
+ if not apply_rwbd_:
+ if "word_durs" not in item:
+ raise ValueError('apply_rwbd=False but item has no "word_durs"')
+
+ wd_raw = list(item["word_durs"])
+ min_word_dur = hparams.get("min_word_dur", 20) / 1000
+ word_durs_filtered = []
+
+ for i, wd in enumerate(wd_raw):
+ if wd < min_word_dur:
+ if i == 0 and len(wd_raw) > 1:
+ wd_raw[i + 1] += wd
+ elif len(word_durs_filtered) > 0:
+ word_durs_filtered[-1] += wd
+ else:
+ word_durs_filtered.append(wd)
+
+ mel2word, _ = align_word(word_durs_filtered, mel_len, hparams["hop_size"], hparams["audio_sample_rate"])
+ mel2word = np.asarray(mel2word)
+ if mel2word.size > 0 and mel2word[0] == 0:
+ mel2word = mel2word + 1
+
+ mel2word_len = int(np.sum(mel2word > 0))
+ real_len = min(mel_len, mel2word_len)
+ else:
+ real_len = min(mel_len, hparams["max_frames"])
+
+ T = math.ceil(min(real_len, hparams["max_frames"]) / hparams["frames_multiple"]) * hparams["frames_multiple"]
+
+ # ---- Input Tensors & Padding ----
+ target_samples = T * hparams["hop_size"]
+ wav_t = torch.from_numpy(wav).float().to(device).unsqueeze(0) # [1, L]
+ if wav_t.shape[-1] < target_samples:
+ wav_t = pad_or_cut_xd(wav_t, target_samples, 1)
+
+ # ---- Pitch Extraction ----
+ if pe is not None:
+ f0s, uvs = pe.get_pitch_batch(
+ wav_t,
+ sample_rate=hparams["audio_sample_rate"],
+ hop_size=hparams["hop_size"],
+ lengths=[real_len],
+ fmax=hparams["f0_max"],
+ fmin=hparams["f0_min"],
+ )
+ f0_1d, uv_1d = norm_interp_f0(f0s[0][:T])
+ f0_t = pad_or_cut_xd(torch.FloatTensor(f0_1d).to(device), T, 0).unsqueeze(0)
+ uv_t = pad_or_cut_xd(torch.FloatTensor(uv_1d).to(device), T, 0).long().unsqueeze(0)
+ pitch_coarse = f0_to_coarse(denorm_f0(f0_t, uv_t)).to(device)
+ f0_np = denorm_f0(f0_t, uv_t)[0].detach().cpu().numpy()[:real_len]
+ else:
+ f0_t = uv_t = pitch_coarse = None
+ f0_np = None
+
+ # ---- Mel Extraction ----
+ mel = mel_net(wav_t) # [1, T_padded, C]
+ mel = pad_or_cut_xd(mel, T, 1)
+
+ # Construct non-padding mask
+ mel_nonpadding_mask = torch.zeros(1, T, device=device)
+ mel_nonpadding_mask[:, :real_len] = 1.0
+
+ # Apply mask to mel (zero out padding)
+ mel = (mel.transpose(1, 2) * mel_nonpadding_mask.unsqueeze(1)).transpose(1, 2)
+ # Re-calculate non_padding bool mask
+ mel_nonpadding = mel.abs().sum(-1) > 0
+
+ # ---- Word Boundary ----
+ word_durs_used = None
+ if apply_rwbd_:
+ mel_input = mel[:, :, : hparams.get("wbd_use_mel_bins", 80)]
+ wbd_outputs = wbd_predictor(
+ mel=mel_input,
+ pitch=pitch_coarse,
+ uv=uv_t,
+ non_padding=mel_nonpadding,
+ train=False,
+ )
+ word_bd = wbd_outputs["word_bd_pred"] # [1, T]
+ else:
+ # Construct word_bd from provided durs
+ mel2word_t = pad_or_cut_xd(torch.LongTensor(mel2word).to(device), T, 0)
+ word_bd = torch.zeros_like(mel2word_t)
+ # Vectorized check
+ word_bd[1:] = (mel2word_t[1:] != mel2word_t[:-1]).long()
+ word_bd[real_len:] = 0
+ word_bd = word_bd.unsqueeze(0) # [1, T]
+
+ word_durs_used = np.array(word_durs_filtered)
+
+ # ---- Main Inference ----
+ mel_input = mel[:, :, : hparams.get("use_mel_bins", 80)]
+ outputs = model(
+ mel=mel_input,
+ word_bd=word_bd,
+ pitch=pitch_coarse,
+ uv=uv_t,
+ non_padding=mel_nonpadding,
+ train=False,
+ )
+
+ note_lengths = outputs["note_lengths"].detach().cpu().numpy()
+ note_bd_pred = outputs["note_bd_pred"][0].detach().cpu().numpy()[:real_len]
+ note_pred = outputs["note_pred"][0].detach().cpu().numpy()[: note_lengths[0]]
+ note_bd_logits = torch.sigmoid(outputs["note_bd_logits"])[0].detach().cpu().numpy()[:real_len]
+
+ if note_pred.shape == (0,):
+ if verbose:
+ print(f"skip {item_name}: no notes detected")
+ return {
+ "item_name": item_name,
+ "pitches": [],
+ "note_durs": [],
+ "note2words": None,
+ }
+
+ # ---- Post-Processing & Regulation ----
+ note_itv_pred = boundary2Interval(note_bd_pred)
+ note2words = None
+
+ if apply_rwbd_:
+ word_bd_np = outputs['word_bd_pred'][0].detach().cpu().numpy()[:real_len]
+ word_durs_derived = np.array(bd_to_durs(word_bd_np)) * hparams['hop_size'] / hparams['audio_sample_rate']
+ word_durs_for_reg = word_durs_derived
+ word_bd_for_reg = word_bd_np
+ else:
+ word_bd_for_reg = word_bd[0].detach().cpu().numpy()[:real_len]
+ word_durs_for_reg = word_durs_used
+
+ should_regulate = hparams.get("infer_regulate_real_note_itv", True) and (not apply_rwbd_)
+
+ if should_regulate and (word_durs_for_reg is not None):
+ try:
+ note_itv_pred_secs, note2words = regulate_real_note_itv(
+ note_itv_pred,
+ note_bd_pred,
+ word_bd_for_reg,
+ word_durs_for_reg,
+ hparams["hop_size"],
+ hparams["audio_sample_rate"],
+ )
+ note_pred, note_itv_pred_secs, note2words = regulate_ill_slur(note_pred, note_itv_pred_secs, note2words)
+ except Exception as err:
+ if verbose:
+ _, exc_value, exc_tb = sys.exc_info()
+ tb = traceback.extract_tb(exc_tb)[-1]
+ print(f"postprocess failed: {err}: {exc_value} in {tb[0]}:{tb[1]} '{tb[2]}' in {tb[3]}")
+ # Fallback
+ note_itv_pred_secs = note_itv_pred * hparams["hop_size"] / hparams["audio_sample_rate"]
+ note2words = None
+ else:
+ note_itv_pred_secs = note_itv_pred * hparams["hop_size"] / hparams["audio_sample_rate"]
+
+ # ---- Output ----
+ note_durs = [float((itv[1] - itv[0])) for itv in note_itv_pred_secs]
+
+ out = {
+ "item_name": item_name,
+ "pitches": note_pred.tolist(),
+ "note_durs": note_durs,
+ "note2words": note2words.tolist() if note2words is not None else None,
+ }
+
+ # ---- Saving ----
+ if save_dir is not None:
+ save_dir_path = Path(save_dir)
+ save_dir_path.mkdir(parents=True, exist_ok=True)
+ fn = str(item_name)
+
+ if not no_save_midi:
+ save_midi(note_pred, note_itv_pred_secs, safe_path(save_dir_path / "midi" / f"{fn}.mid"))
+
+ if not no_save_npy:
+ np.save(safe_path(save_dir_path / "npy" / f"[note]{fn}.npy"), out, allow_pickle=True)
+
+ if save_plot:
+ fig = plt.figure()
+ if f0_np is not None:
+ plt.plot(f0_np, color="red", label="f0")
+
+ midi_pred = np.zeros(note_bd_pred.shape[0], dtype=np.float32)
+ itvs = np.round(note_itv_pred_secs * hparams["audio_sample_rate"] / hparams["hop_size"]).astype(int)
+ for i, itv in enumerate(itvs):
+ midi_pred[itv[0] : itv[1]] = note_pred[i]
+ plt.plot(midi_to_hz(midi_pred), color="blue", label="pred midi")
+ plt.plot(note_bd_logits * 100, color="green", label="note bd logits x100")
+ plt.legend()
+ plt.tight_layout()
+ plt.savefig(safe_path(save_dir_path / "plot" / f"[MIDI]{fn}.png"), format="png")
+ plt.close(fig)
+
+ return out
+
+
+def load_rosvot_models(ckpt, config="", wbd_ckpt="", wbd_config="", device="cuda:0", verbose=False, thr=0.85):
+ """
+ Load models once to reuse across multiple items.
+ """
+ dev = torch.device(device)
+
+ # 1. Hparams
+ config_path = Path(ckpt).with_name("config.yaml") if config == "" else config
+ pe_ckpt = Path(ckpt).parent.parent / "rmvpe/model.pt"
+ hparams = set_hparams(
+ config=config_path,
+ print_hparams=verbose,
+ hparams_str=f"note_bd_threshold={thr}",
+ )
+
+ # 2. Main Model
+ model = MidiExtractor(hparams)
+ load_ckpt(model, ckpt, verbose=verbose)
+ model.eval().to(dev)
+
+ # 3. MelNet
+ mel_net = MelNet(hparams)
+ mel_net.to(dev)
+
+ # 4. Pitch Extractor
+ pe = None
+ if hparams.get("use_pitch_embed", False):
+ pe = RMVPE(pe_ckpt, device=dev)
+
+ # 5. Word Boundary Predictor (optional but we load if ckpt provided or needed)
+ wbd_predictor = None
+ if wbd_ckpt:
+ wbd_config_path = Path(wbd_ckpt).with_name("config.yaml") if wbd_config == "" else wbd_config
+ wbd_hparams = set_hparams(
+ config=wbd_config_path,
+ print_hparams=False,
+ hparams_str="",
+ )
+ hparams.update({
+ "wbd_use_mel_bins": wbd_hparams["use_mel_bins"],
+ "min_word_dur": wbd_hparams["min_word_dur"],
+ })
+ wbd_predictor = WordbdExtractor(wbd_hparams)
+ load_ckpt(wbd_predictor, wbd_ckpt, verbose=verbose)
+ wbd_predictor.eval().to(dev)
+
+ models = {
+ "model": model,
+ "mel_net": mel_net,
+ "pe": pe,
+ "wbd_predictor": wbd_predictor
+ }
+ return hparams, models
+
+
+class NoteTranscriber:
+ """Note transcription wrapper based on ROSVOT.
+
+ Loads ROSVOT and optional RWBD models once in ``__init__`` and
+ exposes a :py:meth:`process` API that turns an item dict into
+ aligned note metadata for downstream SVS.
+ """
+
+ def __init__(
+ self,
+ rosvot_model_path: str,
+ rwbd_model_path: str,
+ *,
+ rosvot_config_path: str = "",
+ rwbd_config_path: str = "",
+ device: str = "cuda:0",
+ thr: float = 0.85,
+ verbose: bool = True,
+ ):
+ """Initialize the note transcriber.
+
+ Args:
+ ckpt: Path to the main ROSVOT checkpoint.
+ config: Optional config YAML path for ROSVOT.
+ wbd_ckpt: Optional word-boundary checkpoint path.
+ wbd_config: Optional config YAML path for RWBD.
+ device: Torch device string, e.g. ``"cuda:0"`` / ``"cpu"``.
+ thr: Note boundary threshold.
+ verbose: Whether to print verbose logs.
+ """
+ self.verbose = verbose
+ self.device = torch.device(device)
+ self.hparams, self.models = load_rosvot_models(
+ ckpt=rosvot_model_path,
+ config=rosvot_config_path,
+ wbd_ckpt=rwbd_model_path,
+ wbd_config=rwbd_config_path,
+ device=device,
+ verbose=verbose,
+ thr=thr,
+ )
+
+ if self.verbose:
+ print(
+ "[note transcription] init success:",
+ f"device={self.device}",
+ f"rosvot_model_path={rosvot_model_path}",
+ f"rwbd_model_path={rwbd_model_path if rwbd_model_path else 'None'}",
+ f"thr={thr}",
+ )
+
+ def process(
+ self,
+ item: Dict[str, Any],
+ *,
+ segment_info: Optional[Dict[str, Any]] = None,
+ save_dir: Optional[str] = None,
+ apply_rwbd: Optional[bool] = None,
+ save_plot: bool = False,
+ no_save_midi: bool = True,
+ no_save_npy: bool = True,
+ verbose: Optional[bool] = None,
+ ) -> Dict[str, Any]:
+ """Run ROSVOT on a single item and post-process outputs.
+
+ Args:
+ item: Input metadata dict with at least ``item_name`` and ``wav_fn``.
+ segment_info: Optional segment metadata for sliced audio.
+ save_dir: Optional directory for debug artifacts (plots, midis).
+ apply_rwbd: Whether to run RWBD-based word boundary refinement.
+ save_plot: Whether to save diagnostic plots.
+ no_save_midi: If True, skip saving midi.
+ no_save_npy: If True, skip saving numpy intermediates.
+ verbose: Override instance-level verbose flag for this call.
+
+ Returns:
+ Dict with aligned note information for downstream SVS.
+ """
+ v = self.verbose if verbose is None else verbose
+ if v:
+ item_name = item.get("item_name", "")
+ wav_fn = item.get("wav_fn", "")
+ print(f"[note transcription] process: start: item_name={item_name} wav_fn={wav_fn}")
+ t0 = time.time()
+
+ rosvot_out = infer_sample(
+ item,
+ self.hparams,
+ self.models,
+ device=self.device,
+ save_dir=save_dir,
+ apply_rwbd=apply_rwbd,
+ save_plot=save_plot,
+ no_save_midi=no_save_midi,
+ no_save_npy=no_save_npy,
+ verbose=v,
+ )
+
+ out = self.post_process(
+ metadata=item,
+ segment_info=segment_info,
+ rosvot_out=rosvot_out,
+ )
+
+ if v:
+ dt = time.time() - t0
+ print(
+ "[note transcription] process: done:",
+ f"item_name={out.get('item_name','')}",
+ f"n_notes={len(out.get('note_pitch', []) or [])}",
+ f"time={dt:.3f}s",
+ )
+
+ return out
+
+ @staticmethod
+ def _normalize_note2words(note2words: list[int]) -> list[int]:
+ if not note2words:
+ return []
+ normalized = [note2words[0]]
+ for idx in range(1, len(note2words)):
+ if note2words[idx] < normalized[-1]:
+ normalized.append(normalized[-1])
+ else:
+ normalized.append(note2words[idx])
+ return normalized
+
+ @staticmethod
+ def _build_ep_types(note2words: list[int], align_words: list[str]) -> list[int]:
+ ep_types: list[int] = []
+ prev = -1
+ for i, w in zip(note2words, align_words):
+ if w == "":
+ ep_types.append(1)
+ else:
+ ep_types.append(2 if i != prev else 3)
+ prev = i
+ return ep_types
+
+ def post_process(
+ self,
+ *,
+ metadata: Dict[str, Any],
+ segment_info: Dict[str, Any],
+ rosvot_out: Dict[str, Any],
+ ) -> Dict[str, Any]:
+ """Build aligned note metadata using ROSVOT outputs."""
+ note2words_raw = rosvot_out.get("note2words") or []
+ note2words = self._normalize_note2words(note2words_raw)
+ align_words = [
+ metadata["words"][idx - 1]
+ for idx in note2words_raw
+ if 0 < idx <= len(metadata["words"])
+ ]
+ ep_types = self._build_ep_types(note2words, align_words) if align_words else []
+
+ return {
+ "item_name": rosvot_out.get("item_name", "") if not segment_info else segment_info["item_name"],
+ "wav_fn": metadata.get("wav_fn", "") if not segment_info else segment_info["wav_fn"],
+ "origin_wav_fn": metadata.get("origin_wav_fn", "") if not segment_info else segment_info["origin_wav_fn"],
+ "start_time_ms": "" if not segment_info else segment_info["start_time_ms"],
+ "end_time_ms": "" if not segment_info else segment_info["end_time_ms"],
+ "language": metadata.get("language", ""),
+ "note_text": align_words,
+ "note_dur": rosvot_out.get("note_durs", []),
+ "note_type": ep_types,
+ "note_pitch": rosvot_out.get("pitches", []),
+ }
+
+if __name__ == "__main__":
+
+ items = json.load(open("example/test/rosvot_input.json", "r"))
+ item = items[0]
+
+ m = NoteTranscriber(
+ rosvot_model_path="pretrained_models/rosvot/rosvot/model.pt",
+ rwbd_model_path="pretrained_models/rosvot/rwbd/model.pt",
+ device="cuda"
+ )
+ out = m.process(item)
+
+ print(out)
\ No newline at end of file
diff --git a/preprocess/tools/note_transcription/modules/__init__.py b/preprocess/tools/note_transcription/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a55c2276c9c1af3679908bcfb25306eedd074e1
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/__init__.py
@@ -0,0 +1 @@
+"""ROSVOT model submodules."""
diff --git a/preprocess/tools/note_transcription/modules/commons/__init__.py b/preprocess/tools/note_transcription/modules/commons/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef5ce84798fc4c79e6ec30d0dd4c703f01ecc40e
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/commons/__init__.py
@@ -0,0 +1 @@
+"""Common ROSVOT layers and utilities."""
diff --git a/preprocess/tools/note_transcription/modules/commons/conformer/__init__.py b/preprocess/tools/note_transcription/modules/commons/conformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba9229ac0dd0cb6d9ee036b4b0f8a41a173d61ef
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/commons/conformer/__init__.py
@@ -0,0 +1 @@
+"""Conformer layers for ROSVOT."""
diff --git a/preprocess/tools/note_transcription/modules/commons/conformer/conformer.py b/preprocess/tools/note_transcription/modules/commons/conformer/conformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..23d837a478efa8eca32b563499184186d03fc04d
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/commons/conformer/conformer.py
@@ -0,0 +1,96 @@
+from torch import nn
+from .espnet_positional_embedding import RelPositionalEncoding, ScaledPositionalEncoding, PositionalEncoding
+from .espnet_transformer_attn import RelPositionMultiHeadedAttention, MultiHeadedAttention
+from .layers import Swish, ConvolutionModule, EncoderLayer, MultiLayeredConv1d
+from ..layers import Embedding
+
+
+class ConformerLayers(nn.Module):
+ def __init__(self, hidden_size, num_layers, kernel_size=9, dropout=0.0, num_heads=4,
+ use_last_norm=True, save_hidden=False):
+ super().__init__()
+ self.use_last_norm = use_last_norm
+ self.layers = nn.ModuleList()
+ positionwise_layer = MultiLayeredConv1d
+ positionwise_layer_args = (hidden_size, hidden_size * 4, 1, dropout)
+ self.pos_embed = RelPositionalEncoding(hidden_size, dropout)
+ self.encoder_layers = nn.ModuleList([EncoderLayer(
+ hidden_size,
+ RelPositionMultiHeadedAttention(num_heads, hidden_size, 0.0),
+ positionwise_layer(*positionwise_layer_args),
+ positionwise_layer(*positionwise_layer_args),
+ ConvolutionModule(hidden_size, kernel_size, Swish()),
+ dropout,
+ ) for _ in range(num_layers)])
+ if self.use_last_norm:
+ self.layer_norm = nn.LayerNorm(hidden_size)
+ else:
+ self.layer_norm = nn.Linear(hidden_size, hidden_size)
+ self.save_hidden = save_hidden
+ if save_hidden:
+ self.hiddens = []
+
+ def forward(self, x, padding_mask=None):
+ """
+
+ :param x: [B, T, H]
+ :param padding_mask: [B, T]
+ :return: [B, T, H]
+ """
+ self.hiddens = []
+ nonpadding_mask = x.abs().sum(-1) > 0
+ x = self.pos_embed(x)
+ for l in self.encoder_layers:
+ x, mask = l(x, nonpadding_mask[:, None, :])
+ if self.save_hidden:
+ self.hiddens.append(x[0])
+ x = x[0]
+ x = self.layer_norm(x) * nonpadding_mask.float()[:, :, None]
+ return x
+
+class FastConformerLayers(ConformerLayers):
+ def __init__(self, hidden_size, num_layers, kernel_size=9, dropout=0.0, num_heads=4,
+ use_last_norm=True, save_hidden=False):
+ super(ConformerLayers, self).__init__()
+ self.use_last_norm = use_last_norm
+ self.layers = nn.ModuleList()
+ positionwise_layer = MultiLayeredConv1d
+ positionwise_layer_args = (hidden_size, hidden_size * 4, 1, dropout)
+ self.pos_embed = PositionalEncoding(hidden_size, dropout)
+ self.encoder_layers = nn.ModuleList([EncoderLayer(
+ hidden_size,
+ MultiHeadedAttention(num_heads, hidden_size, 0.0, flash=True),
+ positionwise_layer(*positionwise_layer_args),
+ positionwise_layer(*positionwise_layer_args),
+ ConvolutionModule(hidden_size, kernel_size, Swish()),
+ dropout,
+ ) for _ in range(num_layers)])
+ if self.use_last_norm:
+ self.layer_norm = nn.LayerNorm(hidden_size)
+ else:
+ self.layer_norm = nn.Linear(hidden_size, hidden_size)
+ self.save_hidden = save_hidden
+ if save_hidden:
+ self.hiddens = []
+
+class ConformerEncoder(ConformerLayers):
+ def __init__(self, hidden_size, dict_size, num_layers=None):
+ conformer_enc_kernel_size = 9
+ super().__init__(hidden_size, num_layers, conformer_enc_kernel_size)
+ self.embed = Embedding(dict_size, hidden_size, padding_idx=0)
+
+ def forward(self, x):
+ """
+
+ :param src_tokens: [B, T]
+ :return: [B x T x C]
+ """
+ x = self.embed(x) # [B, T, H]
+ x = super(ConformerEncoder, self).forward(x)
+ return x
+
+
+class ConformerDecoder(ConformerLayers):
+ def __init__(self, hidden_size, num_layers):
+ conformer_dec_kernel_size = 9
+ super().__init__(hidden_size, num_layers, conformer_dec_kernel_size)
diff --git a/preprocess/tools/note_transcription/modules/commons/conformer/espnet_positional_embedding.py b/preprocess/tools/note_transcription/modules/commons/conformer/espnet_positional_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..89b9b5549cc779d1ea67f052b1c99cad92365503
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/commons/conformer/espnet_positional_embedding.py
@@ -0,0 +1,113 @@
+import math
+import torch
+
+
+class PositionalEncoding(torch.nn.Module):
+ """Positional encoding.
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+ reverse (bool): Whether to reverse the input position.
+ """
+
+ def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
+ """Construct an PositionalEncoding object."""
+ super(PositionalEncoding, self).__init__()
+ self.d_model = d_model
+ self.reverse = reverse
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+ def extend_pe(self, x):
+ """Reset the positional encodings."""
+ if self.pe is not None:
+ if self.pe.size(1) >= x.size(1):
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ pe = torch.zeros(x.size(1), self.d_model)
+ if self.reverse:
+ position = torch.arange(
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
+ ).unsqueeze(1)
+ else:
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
+ * -(math.log(10000.0) / self.d_model)
+ )
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+ def forward(self, x: torch.Tensor):
+ """Add positional encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ """
+ self.extend_pe(x)
+ x = x * self.xscale + self.pe[:, : x.size(1)]
+ return self.dropout(x)
+
+
+class ScaledPositionalEncoding(PositionalEncoding):
+ """Scaled positional encoding module.
+ See Sec. 3.2 https://arxiv.org/abs/1809.08895
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+ """
+
+ def __init__(self, d_model, dropout_rate, max_len=5000):
+ """Initialize class."""
+ super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
+ self.alpha = torch.nn.Parameter(torch.tensor(1.0))
+
+ def reset_parameters(self):
+ """Reset parameters."""
+ self.alpha.data = torch.tensor(1.0)
+
+ def forward(self, x):
+ """Add positional encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ """
+ self.extend_pe(x)
+ x = x + self.alpha * self.pe[:, : x.size(1)]
+ return self.dropout(x)
+
+
+class RelPositionalEncoding(PositionalEncoding):
+ """Relative positional encoding module.
+ See : Appendix B in https://arxiv.org/abs/1901.02860
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+ """
+
+ def __init__(self, d_model, dropout_rate, max_len=5000):
+ """Initialize class."""
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
+
+ def forward(self, x):
+ """Compute positional encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
+ """
+ self.extend_pe(x)
+ x = x * self.xscale
+ pos_emb = self.pe[:, : x.size(1)]
+ return self.dropout(x), self.dropout(pos_emb)
\ No newline at end of file
diff --git a/preprocess/tools/note_transcription/modules/commons/conformer/espnet_transformer_attn.py b/preprocess/tools/note_transcription/modules/commons/conformer/espnet_transformer_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..7355f19b73040b335fc8c73ab1865a129a2da569
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/commons/conformer/espnet_transformer_attn.py
@@ -0,0 +1,198 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Shigeki Karita
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Multi-Head Attention layer definition."""
+
+from packaging import version
+import math
+
+import numpy
+import torch
+from torch import nn
+
+
+class MultiHeadedAttention(nn.Module):
+ """Multi-Head Attention layer.
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+ """
+
+ def __init__(self, n_head, n_feat, dropout_rate, flash=False):
+ """Construct an MultiHeadedAttention object."""
+ super(MultiHeadedAttention, self).__init__()
+ assert n_feat % n_head == 0
+ # We assume d_v always equals d_k
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ self.linear_k = nn.Linear(n_feat, n_feat)
+ self.linear_v = nn.Linear(n_feat, n_feat)
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.attn = None
+ self.dropout = nn.Dropout(p=dropout_rate)
+ self.dropout_rate = dropout_rate
+ self.flash = flash
+
+ def forward_qkv(self, query, key, value):
+ """Transform query, key and value.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ Returns:
+ torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+ torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+ torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+ """
+ n_batch = query.size(0)
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
+
+ return q, k, v
+
+ def forward_attention(self, value, scores, mask):
+ """Compute attention context vector.
+ Args:
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+ """
+ n_batch = value.size(0)
+ if mask is not None:
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+ min_value = float(
+ numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+ )
+ scores = scores.masked_fill(mask, min_value)
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0
+ ) # (batch, head, time1, time2)
+ else:
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(self.attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(self, query, key, value, mask):
+ """Compute scaled dot product attention.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ if version.parse(torch.__version__) >= version.parse("2.0") and self.flash:
+ n_batch = value.size(0)
+ x = torch.nn.functional.scaled_dot_product_attention(
+ q, k, v, attn_mask=mask.unsqueeze(1) if mask is not None else None, dropout_p=self.dropout_rate)
+ x = (
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+ ) # (batch, time1, d_model)
+ return self.linear_out(x)
+ else:
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+ return self.forward_attention(v, scores, mask)
+
+
+class RelPositionMultiHeadedAttention(MultiHeadedAttention):
+ """Multi-Head Attention layer with relative position encoding.
+ Paper: https://arxiv.org/abs/1901.02860
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+ """
+
+ def __init__(self, n_head, n_feat, dropout_rate):
+ """Construct an RelPositionMultiHeadedAttention object."""
+ super().__init__(n_head, n_feat, dropout_rate)
+ # linear transformation for positional ecoding
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+ def rel_shift(self, x, zero_triu=False):
+ """Compute relative positinal encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, size).
+ zero_triu (bool): If true, return the lower triangular part of the matrix.
+ Returns:
+ torch.Tensor: Output tensor.
+ """
+ zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+ x_padded = torch.cat([zero_pad, x], dim=-1)
+
+ x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+ x = x_padded[:, :, 1:].view_as(x)
+
+ if zero_triu:
+ ones = torch.ones((x.size(2), x.size(3)))
+ x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+
+ return x
+
+ def forward(self, query, key, value, pos_emb, mask):
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ pos_emb (torch.Tensor): Positional embedding tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
+
+ n_batch_pos = pos_emb.size(0)
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
+
+ # (batch, head, time1, d_k)
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+ # (batch, head, time1, d_k)
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ # (batch, head, time1, time2)
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+ # compute matrix b and matrix d
+ # (batch, head, time1, time2)
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+ matrix_bd = self.rel_shift(matrix_bd)
+
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
+ self.d_k
+ ) # (batch, head, time1, time2)
+
+ return self.forward_attention(v, scores, mask)
diff --git a/preprocess/tools/note_transcription/modules/commons/conformer/layers.py b/preprocess/tools/note_transcription/modules/commons/conformer/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..6226c43a8d1f337c5e714fd3c840ce6e4509c7a4
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/commons/conformer/layers.py
@@ -0,0 +1,260 @@
+from torch import nn
+import torch
+
+from ..layers import LayerNorm
+
+
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Conformer model.
+ Args:
+ channels (int): The number of channels of conv layers.
+ kernel_size (int): Kernerl size of conv layers.
+ """
+
+ def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
+ """Construct an ConvolutionModule object."""
+ super(ConvolutionModule, self).__init__()
+ # kernerl_size should be a odd number for 'SAME' padding
+ assert (kernel_size - 1) % 2 == 0
+
+ self.pointwise_conv1 = nn.Conv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ self.depthwise_conv = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ groups=channels,
+ bias=bias,
+ )
+ self.norm = nn.BatchNorm1d(channels)
+ self.pointwise_conv2 = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ self.activation = activation
+
+ def forward(self, x):
+ """Compute convolution module.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, channels).
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, channels).
+ """
+ # exchange the temporal dimension and the feature dimension
+ x = x.transpose(1, 2)
+
+ # GLU mechanism
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
+
+ # 1D Depthwise Conv
+ x = self.depthwise_conv(x)
+ x = self.activation(self.norm(x))
+
+ x = self.pointwise_conv2(x)
+
+ return x.transpose(1, 2)
+
+
+class MultiLayeredConv1d(torch.nn.Module):
+ """Multi-layered conv1d for Transformer block.
+ This is a module of multi-leyered conv1d designed
+ to replace positionwise feed-forward network
+ in Transforner block, which is introduced in
+ `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
+ .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
+ https://arxiv.org/pdf/1905.09263.pdf
+ """
+
+ def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
+ """Initialize MultiLayeredConv1d module.
+ Args:
+ in_chans (int): Number of input channels.
+ hidden_chans (int): Number of hidden channels.
+ kernel_size (int): Kernel size of conv1d.
+ dropout_rate (float): Dropout rate.
+ """
+ super(MultiLayeredConv1d, self).__init__()
+ self.w_1 = torch.nn.Conv1d(
+ in_chans,
+ hidden_chans,
+ kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ )
+ self.w_2 = torch.nn.Conv1d(
+ hidden_chans,
+ in_chans,
+ kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ )
+ self.dropout = torch.nn.Dropout(dropout_rate)
+
+ def forward(self, x):
+ """Calculate forward propagation.
+ Args:
+ x (torch.Tensor): Batch of input tensors (B, T, in_chans).
+ Returns:
+ torch.Tensor: Batch of output tensors (B, T, hidden_chans).
+ """
+ x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
+ return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
+
+
+class Swish(torch.nn.Module):
+ """Construct an Swish object."""
+
+ def forward(self, x):
+ """Return Swich activation function."""
+ return x * torch.sigmoid(x)
+
+
+class EncoderLayer(nn.Module):
+ """Encoder layer module.
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
+ can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+ can be used as the argument.
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+ can be used as the argument.
+ conv_module (torch.nn.Module): Convolution module instance.
+ `ConvlutionModule` instance can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool): Whether to use layer_norm before the first block.
+ concat_after (bool): Whether to concat attention layer's input and output.
+ if True, additional linear will be applied.
+ i.e. x -> x + linear(concat(x, att(x)))
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
+ """
+
+ def __init__(
+ self,
+ size,
+ self_attn,
+ feed_forward,
+ feed_forward_macaron,
+ conv_module,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ ):
+ """Construct an EncoderLayer object."""
+ super(EncoderLayer, self).__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.feed_forward_macaron = feed_forward_macaron
+ self.conv_module = conv_module
+ self.norm_ff = LayerNorm(size) # for the FNN module
+ self.norm_mha = LayerNorm(size) # for the MHA module
+ if feed_forward_macaron is not None:
+ self.norm_ff_macaron = LayerNorm(size)
+ self.ff_scale = 0.5
+ else:
+ self.ff_scale = 1.0
+ if self.conv_module is not None:
+ self.norm_conv = LayerNorm(size) # for the CNN module
+ self.norm_final = LayerNorm(size) # for the final output of the block
+ self.dropout = nn.Dropout(dropout_rate)
+ self.size = size
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ if self.concat_after:
+ self.concat_linear = nn.Linear(size + size, size)
+
+ def forward(self, x_input, mask, cache=None):
+ """Compute encoded features.
+ Args:
+ x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
+ - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
+ - w/o pos emb: Tensor (#batch, time, size).
+ mask (torch.Tensor): Mask tensor for the input (#batch, time).
+ cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time).
+ """
+ if isinstance(x_input, tuple):
+ x, pos_emb = x_input[0], x_input[1]
+ else:
+ x, pos_emb = x_input, None
+
+ # whether to use macaron style
+ if self.feed_forward_macaron is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm_ff_macaron(x)
+ x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
+ if not self.normalize_before:
+ x = self.norm_ff_macaron(x)
+
+ # multi-headed self-attention module
+ residual = x
+ if self.normalize_before:
+ x = self.norm_mha(x)
+
+ if cache is None:
+ x_q = x
+ else:
+ assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
+ x_q = x[:, -1:, :]
+ residual = residual[:, -1:, :]
+ mask = None if mask is None else mask[:, -1:, :]
+
+ if pos_emb is not None:
+ x_att = self.self_attn(x_q, x, x, pos_emb, mask)
+ else:
+ x_att = self.self_attn(x_q, x, x, mask)
+
+ if self.concat_after:
+ x_concat = torch.cat((x, x_att), dim=-1)
+ x = residual + self.concat_linear(x_concat)
+ else:
+ x = residual + self.dropout(x_att)
+ if not self.normalize_before:
+ x = self.norm_mha(x)
+
+ # convolution module
+ if self.conv_module is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm_conv(x)
+ x = residual + self.dropout(self.conv_module(x))
+ if not self.normalize_before:
+ x = self.norm_conv(x)
+
+ # feed forward module
+ residual = x
+ if self.normalize_before:
+ x = self.norm_ff(x)
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm_ff(x)
+
+ if self.conv_module is not None:
+ x = self.norm_final(x)
+
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+
+ if pos_emb is not None:
+ return (x, pos_emb), mask
+
+ return x, mask
diff --git a/preprocess/tools/note_transcription/modules/commons/conv.py b/preprocess/tools/note_transcription/modules/commons/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..04e16c62cd60b7d765de13165050919e5e8fa40f
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/commons/conv.py
@@ -0,0 +1,175 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .layers import LayerNorm, Embedding
+
+class LambdaLayer(nn.Module):
+ def __init__(self, lambd):
+ super(LambdaLayer, self).__init__()
+ self.lambd = lambd
+
+ def forward(self, x):
+ return self.lambd(x)
+
+def init_weights_func(m):
+ classname = m.__class__.__name__
+ if classname.find("Conv1d") != -1:
+ torch.nn.init.xavier_uniform_(m.weight)
+
+def get_norm_builder(norm_type, channels, ln_eps=1e-6):
+ if norm_type == 'bn':
+ norm_builder = lambda: nn.BatchNorm1d(channels)
+ elif norm_type == 'in':
+ norm_builder = lambda: nn.InstanceNorm1d(channels, affine=True)
+ elif norm_type == 'gn':
+ norm_builder = lambda: nn.GroupNorm(8, channels)
+ elif norm_type == 'ln':
+ norm_builder = lambda: LayerNorm(channels, dim=1, eps=ln_eps)
+ else:
+ norm_builder = lambda: nn.Identity()
+ return norm_builder
+
+def get_act_builder(act_type):
+ if act_type == 'gelu':
+ act_builder = lambda: nn.GELU()
+ elif act_type == 'relu':
+ act_builder = lambda: nn.ReLU(inplace=True)
+ elif act_type == 'leakyrelu':
+ act_builder = lambda: nn.LeakyReLU(negative_slope=0.01, inplace=True)
+ elif act_type == 'swish':
+ act_builder = lambda: nn.SiLU(inplace=True)
+ else:
+ act_builder = lambda: nn.Identity()
+ return act_builder
+
+class ResidualBlock(nn.Module):
+ """Implements conv->PReLU->norm n-times"""
+
+ def __init__(self, channels, kernel_size, dilation, n=2, norm_type='bn', dropout=0.0,
+ c_multiple=2, ln_eps=1e-12, act_type='gelu'):
+ super(ResidualBlock, self).__init__()
+
+ norm_builder = get_norm_builder(norm_type, channels, ln_eps)
+ act_builder = get_act_builder(act_type)
+
+ self.blocks = [
+ nn.Sequential(
+ norm_builder(),
+ nn.Conv1d(channels, c_multiple * channels, kernel_size, dilation=dilation,
+ padding=(dilation * (kernel_size - 1)) // 2),
+ LambdaLayer(lambda x: x * kernel_size ** -0.5),
+ act_builder(),
+ nn.Conv1d(c_multiple * channels, channels, 1, dilation=dilation),
+ )
+ for i in range(n)
+ ]
+
+ self.blocks = nn.ModuleList(self.blocks)
+ self.dropout = dropout
+
+ def forward(self, x):
+ nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
+ for b in self.blocks:
+ x_ = b(x)
+ if self.dropout > 0 and self.training:
+ x_ = F.dropout(x_, self.dropout, training=self.training)
+ x = x + x_
+ x = x * nonpadding
+ return x
+
+
+class ConvBlocks(nn.Module):
+ """Decodes the expanded phoneme encoding into spectrograms"""
+
+ def __init__(self, hidden_size, out_dims, dilations, kernel_size,
+ norm_type='ln', layers_in_block=2, c_multiple=2,
+ dropout=0.0, ln_eps=1e-5,
+ init_weights=True, is_BTC=True, num_layers=None, post_net_kernel=3, act_type='gelu'):
+ super(ConvBlocks, self).__init__()
+ self.is_BTC = is_BTC
+ if num_layers is not None:
+ dilations = [1] * num_layers
+ self.res_blocks = nn.Sequential(
+ *[ResidualBlock(hidden_size, kernel_size, d,
+ n=layers_in_block, norm_type=norm_type, c_multiple=c_multiple,
+ dropout=dropout, ln_eps=ln_eps, act_type=act_type)
+ for d in dilations],
+ )
+ norm = get_norm_builder(norm_type, hidden_size, ln_eps)()
+ self.last_norm = norm
+ self.post_net1 = nn.Conv1d(hidden_size, out_dims, kernel_size=post_net_kernel,
+ padding=post_net_kernel // 2)
+ if init_weights:
+ self.apply(init_weights_func)
+
+ def forward(self, x, nonpadding=None):
+ """
+
+ :param x: [B, T, H]
+ :return: [B, T, H]
+ """
+ if self.is_BTC:
+ x = x.transpose(1, 2)
+ if nonpadding is None:
+ nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
+ elif self.is_BTC:
+ nonpadding = nonpadding.transpose(1, 2)
+ x = self.res_blocks(x) * nonpadding
+ x = self.last_norm(x) * nonpadding
+ x = self.post_net1(x) * nonpadding
+ if self.is_BTC:
+ x = x.transpose(1, 2)
+ return x
+
+
+class TextConvEncoder(ConvBlocks):
+ def __init__(self, dict_size, hidden_size, out_dims, dilations, kernel_size,
+ norm_type='ln', layers_in_block=2, c_multiple=2,
+ dropout=0.0, ln_eps=1e-5, init_weights=True, num_layers=None, post_net_kernel=3):
+ super().__init__(hidden_size, out_dims, dilations, kernel_size,
+ norm_type, layers_in_block, c_multiple,
+ dropout, ln_eps, init_weights, num_layers=num_layers,
+ post_net_kernel=post_net_kernel)
+ self.embed_tokens = Embedding(dict_size, hidden_size, 0)
+ self.embed_scale = math.sqrt(hidden_size)
+
+ def forward(self, txt_tokens):
+ """
+
+ :param txt_tokens: [B, T]
+ :return: {
+ 'encoder_out': [B x T x C]
+ }
+ """
+ x = self.embed_scale * self.embed_tokens(txt_tokens)
+ return super().forward(x)
+
+
+class ConditionalConvBlocks(ConvBlocks):
+ def __init__(self, hidden_size, c_cond, c_out, dilations, kernel_size,
+ norm_type='ln', layers_in_block=2, c_multiple=2,
+ dropout=0.0, ln_eps=1e-5, init_weights=True, is_BTC=True, num_layers=None):
+ super().__init__(hidden_size, c_out, dilations, kernel_size,
+ norm_type, layers_in_block, c_multiple,
+ dropout, ln_eps, init_weights, is_BTC=False, num_layers=num_layers)
+ self.g_prenet = nn.Conv1d(c_cond, hidden_size, 3, padding=1)
+ self.is_BTC_ = is_BTC
+ if init_weights:
+ self.g_prenet.apply(init_weights_func)
+
+ def forward(self, x, cond, nonpadding=None):
+ if self.is_BTC_:
+ x = x.transpose(1, 2)
+ cond = cond.transpose(1, 2)
+ if nonpadding is not None:
+ nonpadding = nonpadding.transpose(1, 2)
+ if nonpadding is None:
+ nonpadding = x.abs().sum(1)[:, None]
+ x = x + self.g_prenet(cond)
+ x = x * nonpadding
+ x = super(ConditionalConvBlocks, self).forward(x) # input needs to be BTC
+ if self.is_BTC_:
+ x = x.transpose(1, 2)
+ return x
diff --git a/preprocess/tools/note_transcription/modules/commons/layers.py b/preprocess/tools/note_transcription/modules/commons/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..e516e2cde01d3c002fb9511f2a24e5ed277c29e4
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/commons/layers.py
@@ -0,0 +1,85 @@
+import torch
+from torch import nn
+from torch.autograd import Function
+
+class LayerNorm(torch.nn.LayerNorm):
+ """Layer normalization module.
+ :param int nout: output dim size
+ :param int dim: dimension to be normalized
+ """
+
+ def __init__(self, nout, dim=-1, eps=1e-5):
+ """Construct an LayerNorm object."""
+ super(LayerNorm, self).__init__(nout, eps=eps)
+ self.dim = dim
+
+ def forward(self, x):
+ """Apply layer normalization.
+ :param torch.Tensor x: input tensor
+ :return: layer normalized tensor
+ :rtype torch.Tensor
+ """
+ if self.dim == -1:
+ return super(LayerNorm, self).forward(x)
+ return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
+
+
+class Reshape(nn.Module):
+ def __init__(self, *args):
+ super(Reshape, self).__init__()
+ self.shape = args
+
+ def forward(self, x):
+ return x.view(self.shape)
+
+
+class Permute(nn.Module):
+ def __init__(self, *args):
+ super(Permute, self).__init__()
+ self.args = args
+
+ def forward(self, x):
+ return x.permute(self.args)
+
+
+def Linear(in_features, out_features, bias=True, init_type='xavier'):
+ m = nn.Linear(in_features, out_features, bias)
+ if init_type == 'xavier':
+ nn.init.xavier_uniform_(m.weight)
+ elif init_type == 'kaiming':
+ nn.init.kaiming_normal_(m.weight, mode='fan_in')
+ if bias:
+ nn.init.constant_(m.bias, 0.)
+ return m
+
+
+def Embedding(num_embeddings, embedding_dim, padding_idx=None, init_type='normal'):
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
+ if init_type == 'normal':
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
+ elif init_type == 'kaiming':
+ nn.init.kaiming_normal_(m.weight, mode='fan_in')
+ if padding_idx is not None:
+ nn.init.constant_(m.weight[padding_idx], 0)
+ return m
+
+
+class GradientReverseFunction(Function):
+ @staticmethod
+ def forward(ctx, input, coeff=1.):
+ ctx.coeff = coeff
+ output = input * 1.0
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output.neg() * ctx.coeff, None
+
+
+class GRL(nn.Module):
+ def __init__(self):
+ super(GRL, self).__init__()
+
+ def forward(self, *input):
+ return GradientReverseFunction.apply(*input)
+
diff --git a/preprocess/tools/note_transcription/modules/commons/rel_transformer.py b/preprocess/tools/note_transcription/modules/commons/rel_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..809aa4c0e4d33399c7689ceacf531cfdcf562fc0
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/commons/rel_transformer.py
@@ -0,0 +1,378 @@
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .layers import Embedding
+
+
+def convert_pad_shape(pad_shape):
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def shift_1d(x):
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+ return x
+
+
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+class Encoder(nn.Module):
+ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.,
+ window_size=None, block_length=None, pre_ln=False, **kwargs):
+ super().__init__()
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+ self.block_length = block_length
+ self.pre_ln = pre_ln
+
+ self.drop = nn.Dropout(p_dropout)
+ self.attn_layers = nn.ModuleList()
+ self.norm_layers_1 = nn.ModuleList()
+ self.ffn_layers = nn.ModuleList()
+ self.norm_layers_2 = nn.ModuleList()
+ for i in range(self.n_layers):
+ self.attn_layers.append(
+ MultiHeadAttention(hidden_channels, hidden_channels, n_heads, window_size=window_size,
+ p_dropout=p_dropout, block_length=block_length))
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
+ self.ffn_layers.append(
+ FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
+ if pre_ln:
+ self.last_ln = LayerNorm(hidden_channels)
+
+ def forward(self, x, x_mask):
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+ for i in range(self.n_layers):
+ x = x * x_mask
+ x_ = x
+ if self.pre_ln:
+ x = self.norm_layers_1[i](x)
+ y = self.attn_layers[i](x, x, attn_mask)
+ y = self.drop(y)
+ x = x_ + y
+ if not self.pre_ln:
+ x = self.norm_layers_1[i](x)
+
+ x_ = x
+ if self.pre_ln:
+ x = self.norm_layers_2[i](x)
+ y = self.ffn_layers[i](x, x_mask)
+ y = self.drop(y)
+ x = x_ + y
+ if not self.pre_ln:
+ x = self.norm_layers_2[i](x)
+ if self.pre_ln:
+ x = self.last_ln(x)
+ x = x * x_mask
+ return x
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(self, channels, out_channels, n_heads, window_size=None, heads_share=True, p_dropout=0.,
+ block_length=None, proximal_bias=False, proximal_init=False):
+ super().__init__()
+ assert channels % n_heads == 0
+
+ self.channels = channels
+ self.out_channels = out_channels
+ self.n_heads = n_heads
+ self.window_size = window_size
+ self.heads_share = heads_share
+ self.block_length = block_length
+ self.proximal_bias = proximal_bias
+ self.p_dropout = p_dropout
+ self.attn = None
+
+ self.k_channels = channels // n_heads
+ self.conv_q = nn.Conv1d(channels, channels, 1)
+ self.conv_k = nn.Conv1d(channels, channels, 1)
+ self.conv_v = nn.Conv1d(channels, channels, 1)
+ if window_size is not None:
+ n_heads_rel = 1 if heads_share else n_heads
+ rel_stddev = self.k_channels ** -0.5
+ self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
+ self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
+ self.drop = nn.Dropout(p_dropout)
+
+ nn.init.xavier_uniform_(self.conv_q.weight)
+ nn.init.xavier_uniform_(self.conv_k.weight)
+ if proximal_init:
+ self.conv_k.weight.data.copy_(self.conv_q.weight.data)
+ self.conv_k.bias.data.copy_(self.conv_q.bias.data)
+ nn.init.xavier_uniform_(self.conv_v.weight)
+
+ def forward(self, x, c, attn_mask=None):
+ q = self.conv_q(x)
+ k = self.conv_k(c)
+ v = self.conv_v(c)
+
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+ x = self.conv_o(x)
+ return x
+
+ def attention(self, query, key, value, mask=None):
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
+ b, d, t_s, t_t = (*key.size(), query.size(2))
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
+ if self.window_size is not None:
+ assert t_s == t_t, "Relative attention is only available for self-attention."
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+ rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
+ rel_logits = self._relative_position_to_absolute_position(rel_logits)
+ scores_local = rel_logits / math.sqrt(self.k_channels)
+ scores = scores + scores_local
+ if self.proximal_bias:
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
+ if mask is not None:
+ scores = scores.masked_fill(mask == 0, -1e4)
+ if self.block_length is not None:
+ block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
+ scores = scores * block_mask + -1e4 * (1 - block_mask)
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
+ p_attn = self.drop(p_attn)
+ output = torch.matmul(p_attn, value)
+ if self.window_size is not None:
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
+ value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
+ output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
+ return output, p_attn
+
+ def _matmul_with_relative_values(self, x, y):
+ """
+ x: [b, h, l, m]
+ y: [h or 1, m, d]
+ ret: [b, h, l, d]
+ """
+ ret = torch.matmul(x, y.unsqueeze(0))
+ return ret
+
+ def _matmul_with_relative_keys(self, x, y):
+ """
+ x: [b, h, l, d]
+ y: [h or 1, m, d]
+ ret: [b, h, l, m]
+ """
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+ return ret
+
+ def _get_relative_embeddings(self, relative_embeddings, length):
+ max_relative_position = 2 * self.window_size + 1
+ # Pad first before slice to avoid using cond ops.
+ pad_length = max(length - (self.window_size + 1), 0)
+ slice_start_position = max((self.window_size + 1) - length, 0)
+ slice_end_position = slice_start_position + 2 * length - 1
+ if pad_length > 0:
+ padded_relative_embeddings = F.pad(
+ relative_embeddings,
+ convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
+ else:
+ padded_relative_embeddings = relative_embeddings
+ used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
+ return used_relative_embeddings
+
+ def _relative_position_to_absolute_position(self, x):
+ """
+ x: [b, h, l, 2*l-1]
+ ret: [b, h, l, l]
+ """
+ batch, heads, length, _ = x.size()
+ # Concat columns of pad to shift from relative to absolute indexing.
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
+ x_flat = x.view([batch, heads, length * 2 * length])
+ x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
+
+ # Reshape and slice out the padded elements.
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
+ return x_final
+
+ def _absolute_position_to_relative_position(self, x):
+ """
+ x: [b, h, l, l]
+ ret: [b, h, l, 2*l-1]
+ """
+ batch, heads, length, _ = x.size()
+ # padd along column
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
+ x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
+ # add 0's in the beginning that will skew the elements after reshape
+ x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+ return x_final
+
+ def _attention_bias_proximal(self, length):
+ """Bias for self-attention to encourage attention to close positions.
+ Args:
+ length: an integer scalar.
+ Returns:
+ a Tensor with shape [1, 1, length, length]
+ """
+ r = torch.arange(length, dtype=torch.float32)
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+ def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.activation = activation
+
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, 1)
+ self.drop = nn.Dropout(p_dropout)
+
+ def forward(self, x, x_mask):
+ x = self.conv_1(x * x_mask)
+ if self.activation == "gelu":
+ x = x * torch.sigmoid(1.702 * x)
+ else:
+ x = torch.relu(x)
+ x = self.drop(x)
+ x = self.conv_2(x * x_mask)
+ return x * x_mask
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, channels, eps=1e-4):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.gamma = nn.Parameter(torch.ones(channels))
+ self.beta = nn.Parameter(torch.zeros(channels))
+
+ def forward(self, x):
+ n_dims = len(x.shape)
+ mean = torch.mean(x, 1, keepdim=True)
+ variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
+
+ x = (x - mean) * torch.rsqrt(variance + self.eps)
+
+ shape = [1, -1] + [1] * (n_dims - 2)
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
+ return x
+
+
+class ConvReluNorm(nn.Module):
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+ assert n_layers > 1, "Number of layers should be larger than 0."
+
+ self.conv_layers = nn.ModuleList()
+ self.norm_layers = nn.ModuleList()
+ self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.relu_drop = nn.Sequential(
+ nn.ReLU(),
+ nn.Dropout(p_dropout))
+ for _ in range(n_layers - 1):
+ self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(self, x, x_mask):
+ x_org = x
+ for i in range(self.n_layers):
+ x = self.conv_layers[i](x * x_mask)
+ x = self.norm_layers[i](x)
+ x = self.relu_drop(x)
+ x = x_org + self.proj(x)
+ return x * x_mask
+
+
+class RelTransformerEncoder(nn.Module):
+ def __init__(self,
+ n_vocab,
+ out_channels,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout=0.0,
+ window_size=4,
+ block_length=None,
+ prenet=True,
+ pre_ln=True,
+ ):
+
+ super().__init__()
+
+ self.n_vocab = n_vocab
+ self.out_channels = out_channels
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+ self.block_length = block_length
+ self.prenet = prenet
+ if n_vocab > 0:
+ self.emb = Embedding(n_vocab, hidden_channels, padding_idx=0)
+
+ if prenet:
+ self.pre = ConvReluNorm(hidden_channels, hidden_channels, hidden_channels,
+ kernel_size=5, n_layers=3, p_dropout=0)
+ self.encoder = Encoder(
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout,
+ window_size=window_size,
+ block_length=block_length,
+ pre_ln=pre_ln,
+ )
+
+ def forward(self, x, x_mask=None):
+ if self.n_vocab > 0:
+ x_lengths = (x > 0).long().sum(-1)
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
+ else:
+ x_lengths = (x.abs().sum(-1) > 0).long().sum(-1)
+ x = torch.transpose(x, 1, -1) # [b, h, t]
+ x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+
+ if self.prenet:
+ x = self.pre(x, x_mask)
+ x = self.encoder(x, x_mask)
+ return x.transpose(1, 2)
diff --git a/preprocess/tools/note_transcription/modules/commons/rnn.py b/preprocess/tools/note_transcription/modules/commons/rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..205c2c76b8fda2de920bc59228a5eec0a20119a9
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/commons/rnn.py
@@ -0,0 +1,261 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class PreNet(nn.Module):
+ def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
+ super().__init__()
+ self.fc1 = nn.Linear(in_dims, fc1_dims)
+ self.fc2 = nn.Linear(fc1_dims, fc2_dims)
+ self.p = dropout
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = F.relu(x)
+ x = F.dropout(x, self.p, training=self.training)
+ x = self.fc2(x)
+ x = F.relu(x)
+ x = F.dropout(x, self.p, training=self.training)
+ return x
+
+
+class HighwayNetwork(nn.Module):
+ def __init__(self, size):
+ super().__init__()
+ self.W1 = nn.Linear(size, size)
+ self.W2 = nn.Linear(size, size)
+ self.W1.bias.data.fill_(0.)
+
+ def forward(self, x):
+ x1 = self.W1(x)
+ x2 = self.W2(x)
+ g = torch.sigmoid(x2)
+ y = g * F.relu(x1) + (1. - g) * x
+ return y
+
+
+class BatchNormConv(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel, relu=True):
+ super().__init__()
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
+ self.bnorm = nn.BatchNorm1d(out_channels)
+ self.relu = relu
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = F.relu(x) if self.relu is True else x
+ return self.bnorm(x)
+
+
+class ConvNorm(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
+ padding=None, dilation=1, bias=True, w_init_gain='linear'):
+ super(ConvNorm, self).__init__()
+ if padding is None:
+ assert (kernel_size % 2 == 1)
+ padding = int(dilation * (kernel_size - 1) / 2)
+
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
+ kernel_size=kernel_size, stride=stride,
+ padding=padding, dilation=dilation,
+ bias=bias)
+
+ torch.nn.init.xavier_uniform_(
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
+
+ def forward(self, signal):
+ conv_signal = self.conv(signal)
+ return conv_signal
+
+
+class CBHG(nn.Module):
+ def __init__(self, K, in_channels, channels, proj_channels, num_highways):
+ super().__init__()
+
+ # List of all rnns to call `flatten_parameters()` on
+ self._to_flatten = []
+
+ self.bank_kernels = [i for i in range(1, K + 1)]
+ self.conv1d_bank = nn.ModuleList()
+ for k in self.bank_kernels:
+ conv = BatchNormConv(in_channels, channels, k)
+ self.conv1d_bank.append(conv)
+
+ self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
+
+ self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
+ self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
+
+ # Fix the highway input if necessary
+ if proj_channels[-1] != channels:
+ self.highway_mismatch = True
+ self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
+ else:
+ self.highway_mismatch = False
+
+ self.highways = nn.ModuleList()
+ for i in range(num_highways):
+ hn = HighwayNetwork(channels)
+ self.highways.append(hn)
+
+ self.rnn = nn.GRU(channels, channels, batch_first=True, bidirectional=True)
+ self._to_flatten.append(self.rnn)
+
+ # Avoid fragmentation of RNN parameters and associated warning
+ self._flatten_parameters()
+
+ def forward(self, x):
+ # Although we `_flatten_parameters()` on init, when using DataParallel
+ # the model gets replicated, making it no longer guaranteed that the
+ # weights are contiguous in GPU memory. Hence, we must call it again
+ self._flatten_parameters()
+
+ # Save these for later
+ residual = x
+ seq_len = x.size(-1)
+ conv_bank = []
+
+ # Convolution Bank
+ for conv in self.conv1d_bank:
+ c = conv(x) # Convolution
+ conv_bank.append(c[:, :, :seq_len])
+
+ # Stack along the channel axis
+ conv_bank = torch.cat(conv_bank, dim=1)
+
+ # dump the last padding to fit residual
+ x = self.maxpool(conv_bank)[:, :, :seq_len]
+
+ # Conv1d projections
+ x = self.conv_project1(x)
+ x = self.conv_project2(x)
+
+ # Residual Connect
+ x = x + residual
+
+ # Through the highways
+ x = x.transpose(1, 2)
+ if self.highway_mismatch is True:
+ x = self.pre_highway(x)
+ for h in self.highways:
+ x = h(x)
+
+ # And then the RNN
+ x, _ = self.rnn(x)
+ return x
+
+ def _flatten_parameters(self):
+ """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
+ to improve efficiency and avoid PyTorch yelling at us."""
+ [m.flatten_parameters() for m in self._to_flatten]
+
+
+class TacotronEncoder(nn.Module):
+ def __init__(self, embed_dims, num_chars, cbhg_channels, K, num_highways, dropout):
+ super().__init__()
+ self.embedding = nn.Embedding(num_chars, embed_dims)
+ self.pre_net = PreNet(embed_dims, embed_dims, embed_dims, dropout=dropout)
+ self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
+ proj_channels=[cbhg_channels, cbhg_channels],
+ num_highways=num_highways)
+ self.proj_out = nn.Linear(cbhg_channels * 2, cbhg_channels)
+
+ def forward(self, x):
+ x = self.embedding(x)
+ x = self.pre_net(x)
+ x.transpose_(1, 2)
+ x = self.cbhg(x)
+ x = self.proj_out(x)
+ return x
+
+
+class RNNEncoder(nn.Module):
+ def __init__(self, num_chars, embedding_dim, n_convolutions=3, kernel_size=5):
+ super(RNNEncoder, self).__init__()
+ self.embedding = nn.Embedding(num_chars, embedding_dim, padding_idx=0)
+ convolutions = []
+ for _ in range(n_convolutions):
+ conv_layer = nn.Sequential(
+ ConvNorm(embedding_dim,
+ embedding_dim,
+ kernel_size=kernel_size, stride=1,
+ padding=int((kernel_size - 1) / 2),
+ dilation=1, w_init_gain='relu'),
+ nn.BatchNorm1d(embedding_dim))
+ convolutions.append(conv_layer)
+ self.convolutions = nn.ModuleList(convolutions)
+
+ self.lstm = nn.LSTM(embedding_dim, int(embedding_dim / 2), 1,
+ batch_first=True, bidirectional=True)
+
+ def forward(self, x):
+ input_lengths = (x > 0).sum(-1)
+ input_lengths = input_lengths.cpu().numpy()
+
+ x = self.embedding(x)
+ x = x.transpose(1, 2) # [B, H, T]
+ for conv in self.convolutions:
+ x = F.dropout(F.relu(conv(x)), 0.5, self.training) + x
+ x = x.transpose(1, 2) # [B, T, H]
+
+ # pytorch tensor are not reversible, hence the conversion
+ x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False)
+
+ self.lstm.flatten_parameters()
+ outputs, _ = self.lstm(x)
+ outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
+
+ return outputs
+
+
+class DecoderRNN(torch.nn.Module):
+ def __init__(self, hidden_size, decoder_rnn_dim, dropout):
+ super(DecoderRNN, self).__init__()
+ self.in_conv1d = nn.Sequential(
+ torch.nn.Conv1d(
+ in_channels=hidden_size,
+ out_channels=hidden_size,
+ kernel_size=9, padding=4,
+ ),
+ torch.nn.ReLU(),
+ torch.nn.Conv1d(
+ in_channels=hidden_size,
+ out_channels=hidden_size,
+ kernel_size=9, padding=4,
+ ),
+ )
+ self.ln = nn.LayerNorm(hidden_size)
+ if decoder_rnn_dim == 0:
+ decoder_rnn_dim = hidden_size * 2
+ self.rnn = torch.nn.LSTM(
+ input_size=hidden_size,
+ hidden_size=decoder_rnn_dim,
+ num_layers=1,
+ batch_first=True,
+ bidirectional=True,
+ dropout=dropout
+ )
+ self.rnn.flatten_parameters()
+ self.conv1d = torch.nn.Conv1d(
+ in_channels=decoder_rnn_dim * 2,
+ out_channels=hidden_size,
+ kernel_size=3,
+ padding=1,
+ )
+
+ def forward(self, x):
+ input_masks = x.abs().sum(-1).ne(0).data[:, :, None]
+ input_lengths = input_masks.sum([-1, -2])
+ input_lengths = input_lengths.cpu().numpy()
+
+ x = self.in_conv1d(x.transpose(1, 2)).transpose(1, 2)
+ x = self.ln(x)
+ x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False)
+ self.rnn.flatten_parameters()
+ x, _ = self.rnn(x) # [B, T, C]
+ x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
+ x = x * input_masks
+ pre_mel = self.conv1d(x.transpose(1, 2)).transpose(1, 2) # [B, T, C]
+ pre_mel = pre_mel * input_masks
+ return pre_mel
diff --git a/preprocess/tools/note_transcription/modules/commons/transformer.py b/preprocess/tools/note_transcription/modules/commons/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a997d39cd569e688616391d31ad2a15c7ae7432
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/commons/transformer.py
@@ -0,0 +1,751 @@
+import math
+import torch
+from torch import nn
+from torch.nn import Parameter, Linear
+from .layers import LayerNorm, Embedding
+from ...utils.nn.seq_utils import (
+ get_incremental_state,
+ set_incremental_state,
+ softmax,
+ make_positions,
+)
+import torch.nn.functional as F
+
+DEFAULT_MAX_SOURCE_POSITIONS = 2000
+DEFAULT_MAX_TARGET_POSITIONS = 2000
+
+
+class SinusoidalPositionalEmbedding(nn.Module):
+ """This module produces sinusoidal positional embeddings of any length.
+
+ Padding symbols are ignored.
+ """
+
+ def __init__(self, embedding_dim, padding_idx, init_size=1024):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.padding_idx = padding_idx
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ init_size,
+ embedding_dim,
+ padding_idx,
+ )
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
+
+ @staticmethod
+ def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
+ """Build sinusoidal embeddings.
+
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
+ if embedding_dim % 2 == 1:
+ # zero pad
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+ if padding_idx is not None:
+ emb[padding_idx, :] = 0
+ return emb
+
+ def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
+ """Input is expected to be of size [bsz x seqlen]."""
+ bsz, seq_len = input.shape[:2]
+ max_pos = self.padding_idx + 1 + seq_len
+ if self.weights is None or max_pos > self.weights.size(0):
+ # recompute/expand embeddings if needed
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ max_pos,
+ self.embedding_dim,
+ self.padding_idx,
+ )
+ self.weights = self.weights.to(self._float_tensor)
+
+ if incremental_state is not None:
+ # positions is the same for every token when decoding a single step
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
+
+ positions = make_positions(input, self.padding_idx) if positions is None else positions
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
+
+ def max_positions(self):
+ """Maximum number of supported positions."""
+ return int(1e5) # an arbitrary large number
+
+
+class TransformerFFNLayer(nn.Module):
+ def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.dropout = dropout
+ self.act = act
+ if padding == 'SAME':
+ self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
+ elif padding == 'LEFT':
+ self.ffn_1 = nn.Sequential(
+ nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
+ nn.Conv1d(hidden_size, filter_size, kernel_size)
+ )
+ self.ffn_2 = Linear(filter_size, hidden_size)
+
+ def forward(self, x, incremental_state=None):
+ # x: T x B x C
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_input' in saved_state:
+ prev_input = saved_state['prev_input']
+ x = torch.cat((prev_input, x), dim=0)
+ x = x[-self.kernel_size:]
+ saved_state['prev_input'] = x
+ self._set_input_buffer(incremental_state, saved_state)
+
+ x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
+ x = x * self.kernel_size ** -0.5
+
+ if incremental_state is not None:
+ x = x[-1:]
+ if self.act == 'gelu':
+ x = F.gelu(x)
+ if self.act == 'relu':
+ x = F.relu(x)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = self.ffn_2(x)
+ return x
+
+ def _get_input_buffer(self, incremental_state):
+ return get_incremental_state(
+ self,
+ incremental_state,
+ 'f',
+ ) or {}
+
+ def _set_input_buffer(self, incremental_state, buffer):
+ set_incremental_state(
+ self,
+ incremental_state,
+ 'f',
+ buffer,
+ )
+
+ def clear_buffer(self, incremental_state):
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_input' in saved_state:
+ del saved_state['prev_input']
+ self._set_input_buffer(incremental_state, saved_state)
+
+
+class MultiheadAttention(nn.Module):
+ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
+ add_bias_kv=False, add_zero_attn=False, self_attention=False,
+ encoder_decoder_attention=False):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.kdim = kdim if kdim is not None else embed_dim
+ self.vdim = vdim if vdim is not None else embed_dim
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+ self.scaling = self.head_dim ** -0.5
+
+ self.self_attention = self_attention
+ self.encoder_decoder_attention = encoder_decoder_attention
+
+ assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
+ 'value to be of the same size'
+
+ if self.qkv_same_dim:
+ self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
+ else:
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+
+ if bias:
+ self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
+ else:
+ self.register_parameter('in_proj_bias', None)
+
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ if add_bias_kv:
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
+ else:
+ self.bias_k = self.bias_v = None
+
+ self.add_zero_attn = add_zero_attn
+
+ self.reset_parameters()
+
+ self.enable_torch_version = False
+ if hasattr(F, "multi_head_attention_forward"):
+ self.enable_torch_version = True
+ else:
+ self.enable_torch_version = False
+ self.last_attn_probs = None
+
+ def reset_parameters(self):
+ if self.qkv_same_dim:
+ nn.init.xavier_uniform_(self.in_proj_weight)
+ else:
+ nn.init.xavier_uniform_(self.k_proj_weight)
+ nn.init.xavier_uniform_(self.v_proj_weight)
+ nn.init.xavier_uniform_(self.q_proj_weight)
+
+ nn.init.xavier_uniform_(self.out_proj.weight)
+ if self.in_proj_bias is not None:
+ nn.init.constant_(self.in_proj_bias, 0.)
+ nn.init.constant_(self.out_proj.bias, 0.)
+ if self.bias_k is not None:
+ nn.init.xavier_normal_(self.bias_k)
+ if self.bias_v is not None:
+ nn.init.xavier_normal_(self.bias_v)
+
+ def forward(
+ self,
+ query, key, value,
+ key_padding_mask=None,
+ incremental_state=None,
+ need_weights=True,
+ static_kv=False,
+ attn_mask=None,
+ before_softmax=False,
+ need_head_weights=False,
+ enc_dec_attn_constraint_mask=None,
+ reset_attn_weight=None
+ ):
+ """Input shape: Time x Batch x Channel
+
+ Args:
+ key_padding_mask (ByteTensor, optional): mask to exclude
+ keys that are pads, of shape `(batch, src_len)`, where
+ padding elements are indicated by 1s.
+ need_weights (bool, optional): return the attention weights,
+ averaged over heads (default: False).
+ attn_mask (ByteTensor, optional): typically used to
+ implement causal attention, where the mask prevents the
+ attention from looking forward in time (default: None).
+ before_softmax (bool, optional): return the raw attention
+ weights and values before the attention softmax.
+ need_head_weights (bool, optional): return the attention
+ weights for each head. Implies *need_weights*. Default:
+ return the average attention weights over all heads.
+ """
+ if need_head_weights:
+ need_weights = True
+
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == self.embed_dim
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+ if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
+ if self.qkv_same_dim:
+ return F.multi_head_attention_forward(query, key, value,
+ self.embed_dim, self.num_heads,
+ self.in_proj_weight,
+ self.in_proj_bias, self.bias_k, self.bias_v,
+ self.add_zero_attn, self.dropout,
+ self.out_proj.weight, self.out_proj.bias,
+ self.training, key_padding_mask, need_weights,
+ attn_mask)
+ else:
+ return F.multi_head_attention_forward(query, key, value,
+ self.embed_dim, self.num_heads,
+ torch.empty([0]),
+ self.in_proj_bias, self.bias_k, self.bias_v,
+ self.add_zero_attn, self.dropout,
+ self.out_proj.weight, self.out_proj.bias,
+ self.training, key_padding_mask, need_weights,
+ attn_mask, use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj_weight,
+ k_proj_weight=self.k_proj_weight,
+ v_proj_weight=self.v_proj_weight)
+
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_key' in saved_state:
+ # previous time steps are cached - no need to recompute
+ # key and value if they are static
+ if static_kv:
+ assert self.encoder_decoder_attention and not self.self_attention
+ key = value = None
+ else:
+ saved_state = None
+
+ if self.self_attention:
+ # self-attention
+ q, k, v = self.in_proj_qkv(query)
+ elif self.encoder_decoder_attention:
+ # encoder-decoder attention
+ q = self.in_proj_q(query)
+ if key is None:
+ assert value is None
+ k = v = None
+ else:
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(key)
+
+ else:
+ q = self.in_proj_q(query)
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(value)
+ q *= self.scaling
+
+ if self.bias_k is not None:
+ assert self.bias_v is not None
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if k is not None:
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if v is not None:
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+ if saved_state is not None:
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+ if 'prev_key' in saved_state:
+ prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ k = prev_key
+ else:
+ k = torch.cat((prev_key, k), dim=1)
+ if 'prev_value' in saved_state:
+ prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ v = prev_value
+ else:
+ v = torch.cat((prev_value, v), dim=1)
+ if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
+ prev_key_padding_mask = saved_state['prev_key_padding_mask']
+ if static_kv:
+ key_padding_mask = prev_key_padding_mask
+ else:
+ key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
+
+ saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state['prev_key_padding_mask'] = key_padding_mask
+
+ self._set_input_buffer(incremental_state, saved_state)
+
+ src_len = k.size(1)
+
+ # This is part of a workaround to get around fork/join parallelism
+ # not supporting Optional types.
+ if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
+ key_padding_mask = None
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if self.add_zero_attn:
+ src_len += 1
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
+
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ if len(attn_mask.shape) == 2:
+ attn_mask = attn_mask.unsqueeze(0)
+ elif len(attn_mask.shape) == 3:
+ attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
+ bsz * self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights + attn_mask
+
+ if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.masked_fill(
+ enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
+ -1e8,
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if key_padding_mask is not None:
+ # don't attend to padding symbols
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ -1e8,
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+ if before_softmax:
+ return attn_weights, v
+
+ attn_weights_float = softmax(attn_weights, dim=-1)
+ attn_weights = attn_weights_float.type_as(attn_weights)
+ attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
+
+ if reset_attn_weight is not None:
+ if reset_attn_weight:
+ self.last_attn_probs = attn_probs.detach()
+ else:
+ assert self.last_attn_probs is not None
+ attn_probs = self.last_attn_probs
+ attn = torch.bmm(attn_probs, v)
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn = self.out_proj(attn)
+
+ if need_weights:
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+ if not need_head_weights:
+ # average attention weights over heads
+ attn_weights = attn_weights.mean(dim=0)
+ else:
+ attn_weights = None
+
+ return attn, (attn_weights, attn_logits)
+
+ def in_proj_qkv(self, query):
+ return self._in_proj(query).chunk(3, dim=-1)
+
+ def in_proj_q(self, query):
+ if self.qkv_same_dim:
+ return self._in_proj(query, end=self.embed_dim)
+ else:
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[:self.embed_dim]
+ return F.linear(query, self.q_proj_weight, bias)
+
+ def in_proj_k(self, key):
+ if self.qkv_same_dim:
+ return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
+ else:
+ weight = self.k_proj_weight
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[self.embed_dim:2 * self.embed_dim]
+ return F.linear(key, weight, bias)
+
+ def in_proj_v(self, value):
+ if self.qkv_same_dim:
+ return self._in_proj(value, start=2 * self.embed_dim)
+ else:
+ weight = self.v_proj_weight
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[2 * self.embed_dim:]
+ return F.linear(value, weight, bias)
+
+ def _in_proj(self, input, start=0, end=None):
+ weight = self.in_proj_weight
+ bias = self.in_proj_bias
+ weight = weight[start:end, :]
+ if bias is not None:
+ bias = bias[start:end]
+ return F.linear(input, weight, bias)
+
+ def _get_input_buffer(self, incremental_state):
+ return get_incremental_state(
+ self,
+ incremental_state,
+ 'attn_state',
+ ) or {}
+
+ def _set_input_buffer(self, incremental_state, buffer):
+ set_incremental_state(
+ self,
+ incremental_state,
+ 'attn_state',
+ buffer,
+ )
+
+ def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
+ return attn_weights
+
+ def clear_buffer(self, incremental_state=None):
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_key' in saved_state:
+ del saved_state['prev_key']
+ if 'prev_value' in saved_state:
+ del saved_state['prev_value']
+ self._set_input_buffer(incremental_state, saved_state)
+
+
+class EncSALayer(nn.Module):
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
+ relu_dropout=0.1, kernel_size=9, padding='SAME', act='gelu'):
+ super().__init__()
+ self.c = c
+ self.dropout = dropout
+ self.num_heads = num_heads
+ if num_heads > 0:
+ self.layer_norm1 = LayerNorm(c)
+ self.self_attn = MultiheadAttention(
+ self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False)
+ self.layer_norm2 = LayerNorm(c)
+ self.ffn = TransformerFFNLayer(
+ c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
+
+ def forward(self, x, encoder_padding_mask=None, **kwargs):
+ layer_norm_training = kwargs.get('layer_norm_training', None)
+ if layer_norm_training is not None:
+ self.layer_norm1.training = layer_norm_training
+ self.layer_norm2.training = layer_norm_training
+ if self.num_heads > 0:
+ residual = x
+ x = self.layer_norm1(x)
+ x, _, = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=encoder_padding_mask
+ )
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+
+ residual = x
+ x = self.layer_norm2(x)
+ x = self.ffn(x)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+ return x
+
+
+class DecSALayer(nn.Module):
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
+ kernel_size=9, act='gelu'):
+ super().__init__()
+ self.c = c
+ self.dropout = dropout
+ self.layer_norm1 = LayerNorm(c)
+ self.self_attn = MultiheadAttention(
+ c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
+ )
+ self.layer_norm2 = LayerNorm(c)
+ self.encoder_attn = MultiheadAttention(
+ c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
+ )
+ self.layer_norm3 = LayerNorm(c)
+ self.ffn = TransformerFFNLayer(
+ c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
+
+ def forward(
+ self,
+ x,
+ encoder_out=None,
+ encoder_padding_mask=None,
+ incremental_state=None,
+ self_attn_mask=None,
+ self_attn_padding_mask=None,
+ attn_out=None,
+ reset_attn_weight=None,
+ **kwargs,
+ ):
+ layer_norm_training = kwargs.get('layer_norm_training', None)
+ if layer_norm_training is not None:
+ self.layer_norm1.training = layer_norm_training
+ self.layer_norm2.training = layer_norm_training
+ self.layer_norm3.training = layer_norm_training
+ residual = x
+ x = self.layer_norm1(x)
+ x, _ = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ incremental_state=incremental_state,
+ attn_mask=self_attn_mask
+ )
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+
+ attn_logits = None
+ if encoder_out is not None or attn_out is not None:
+ residual = x
+ x = self.layer_norm2(x)
+ if encoder_out is not None:
+ x, attn = self.encoder_attn(
+ query=x,
+ key=encoder_out,
+ value=encoder_out,
+ key_padding_mask=encoder_padding_mask,
+ incremental_state=incremental_state,
+ static_kv=True,
+ enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
+ 'enc_dec_attn_constraint_mask'),
+ reset_attn_weight=reset_attn_weight
+ )
+ attn_logits = attn[1]
+ elif attn_out is not None:
+ x = self.encoder_attn.in_proj_v(attn_out)
+ if encoder_out is not None or attn_out is not None:
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+
+ residual = x
+ x = self.layer_norm3(x)
+ x = self.ffn(x, incremental_state=incremental_state)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ return x, attn_logits
+
+ def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
+ self.encoder_attn.clear_buffer(incremental_state)
+ self.ffn.clear_buffer(incremental_state)
+
+ def set_buffer(self, name, tensor, incremental_state):
+ return set_incremental_state(self, incremental_state, name, tensor)
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.dropout = dropout
+ self.num_heads = num_heads
+ self.op = EncSALayer(
+ hidden_size, num_heads, dropout=dropout,
+ attention_dropout=0.0, relu_dropout=dropout,
+ kernel_size=kernel_size)
+
+ def forward(self, x, **kwargs):
+ return self.op(x, **kwargs)
+
+
+class TransformerDecoderLayer(nn.Module):
+ def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.dropout = dropout
+ self.num_heads = num_heads
+ self.op = DecSALayer(
+ hidden_size, num_heads, dropout=dropout,
+ attention_dropout=0.0, relu_dropout=dropout,
+ kernel_size=kernel_size)
+
+ def forward(self, x, **kwargs):
+ return self.op(x, **kwargs)
+
+ def clear_buffer(self, *args):
+ return self.op.clear_buffer(*args)
+
+ def set_buffer(self, *args):
+ return self.op.set_buffer(*args)
+
+
+class FFTBlocks(nn.Module):
+ def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=0.0,
+ num_heads=2, use_pos_embed=True, use_last_norm=True,
+ use_pos_embed_alpha=True):
+ super().__init__()
+ self.num_layers = num_layers
+ embed_dim = self.hidden_size = hidden_size
+ self.dropout = dropout
+ self.use_pos_embed = use_pos_embed
+ self.use_last_norm = use_last_norm
+ if use_pos_embed:
+ self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
+ self.padding_idx = 0
+ self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
+ self.embed_positions = SinusoidalPositionalEmbedding(
+ embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
+ )
+
+ self.layers = nn.ModuleList([])
+ self.layers.extend([
+ TransformerEncoderLayer(self.hidden_size, self.dropout,
+ kernel_size=ffn_kernel_size, num_heads=num_heads)
+ for _ in range(self.num_layers)
+ ])
+ if self.use_last_norm:
+ self.layer_norm = nn.LayerNorm(embed_dim)
+ else:
+ self.layer_norm = None
+
+ def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
+ """
+ :param x: [B, T, C]
+ :param padding_mask: [B, T]
+ :return: [B, T, C] or [L, B, T, C]
+ """
+ padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
+ nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1]
+ if self.use_pos_embed:
+ positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
+ x = x + positions
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1) * nonpadding_mask_TB
+ hiddens = []
+ for layer in self.layers:
+ x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
+ hiddens.append(x)
+ if self.use_last_norm:
+ x = self.layer_norm(x) * nonpadding_mask_TB
+ if return_hiddens:
+ x = torch.stack(hiddens, 0) # [L, T, B, C]
+ x = x.transpose(1, 2) # [L, B, T, C]
+ else:
+ x = x.transpose(0, 1) # [B, T, C]
+ return x
+
+
+class FastSpeechEncoder(FFTBlocks):
+ def __init__(self, dict_size, hidden_size=256, num_layers=4, kernel_size=9, num_heads=2,
+ dropout=0.0):
+ super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
+ use_pos_embed=False, dropout=dropout) # use_pos_embed_alpha for compatibility
+ self.embed_tokens = Embedding(dict_size, hidden_size, 0)
+ self.embed_scale = math.sqrt(hidden_size)
+ self.padding_idx = 0
+ self.embed_positions = SinusoidalPositionalEmbedding(
+ hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
+ )
+
+ def forward(self, txt_tokens, attn_mask=None):
+ """
+
+ :param txt_tokens: [B, T]
+ :return: {
+ 'encoder_out': [B x T x C]
+ }
+ """
+ encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
+ x = self.forward_embedding(txt_tokens) # [B, T, H]
+ if self.num_layers > 0:
+ x = super(FastSpeechEncoder, self).forward(x, encoder_padding_mask, attn_mask=attn_mask)
+ return x
+
+ def forward_embedding(self, txt_tokens):
+ # embed tokens and positions
+ x = self.embed_scale * self.embed_tokens(txt_tokens)
+ positions = self.embed_positions(txt_tokens)
+ x = x + positions
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ return x
+
+
+class FastSpeechDecoder(FFTBlocks):
+ def __init__(self, hidden_size=256, num_layers=4, kernel_size=9, num_heads=2):
+ super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads)
diff --git a/preprocess/tools/note_transcription/modules/commons/wavenet.py b/preprocess/tools/note_transcription/modules/commons/wavenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..99870dbea53845d841df86fdc2aa3833bc415b11
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/commons/wavenet.py
@@ -0,0 +1,109 @@
+import torch
+from torch import nn
+from packaging import version
+
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+ n_channels_int = n_channels[0]
+ in_act = input_a + input_b
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+ return acts
+
+jit_fused_add_tanh_sigmoid_multiply = fused_add_tanh_sigmoid_multiply
+
+def script_function():
+ if version.parse(torch.__version__) >= version.parse('2.0'):
+ global jit_fused_add_tanh_sigmoid_multiply
+ jit_fused_add_tanh_sigmoid_multiply = torch.jit.script(fused_add_tanh_sigmoid_multiply)
+
+
+class WN(torch.nn.Module):
+ def __init__(self, hidden_size, kernel_size, dilation_rate, n_layers, c_cond=0,
+ p_dropout=0, share_cond_layers=False, is_BTC=False):
+ super(WN, self).__init__()
+ assert (kernel_size % 2 == 1)
+ assert (hidden_size % 2 == 0)
+ self.is_BTC = is_BTC
+ self.hidden_size = hidden_size
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = c_cond
+ self.p_dropout = p_dropout
+ self.share_cond_layers = share_cond_layers
+
+ self.in_layers = torch.nn.ModuleList()
+ self.res_skip_layers = torch.nn.ModuleList()
+ self.drop = nn.Dropout(p_dropout)
+
+ if c_cond != 0 and not share_cond_layers:
+ cond_layer = torch.nn.Conv1d(c_cond, 2 * hidden_size * n_layers, 1)
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
+
+ for i in range(n_layers):
+ dilation = dilation_rate ** i
+ padding = int((kernel_size * dilation - dilation) / 2)
+ in_layer = torch.nn.Conv1d(hidden_size, 2 * hidden_size, kernel_size,
+ dilation=dilation, padding=padding)
+ in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
+ self.in_layers.append(in_layer)
+
+ # last one is not necessary
+ if i < n_layers - 1:
+ res_skip_channels = 2 * hidden_size
+ else:
+ res_skip_channels = hidden_size
+
+ res_skip_layer = torch.nn.Conv1d(hidden_size, res_skip_channels, 1)
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
+ self.res_skip_layers.append(res_skip_layer)
+
+ script_function()
+
+ def forward(self, x, nonpadding=None, cond=None):
+ if self.is_BTC:
+ x = x.transpose(1, 2)
+ cond = cond.transpose(1, 2) if cond is not None else None
+ nonpadding = nonpadding.transpose(1, 2) if nonpadding is not None else None
+ if nonpadding is None:
+ nonpadding = 1
+ output = torch.zeros_like(x)
+ n_channels_tensor = torch.IntTensor([self.hidden_size])
+
+ if cond is not None and not self.share_cond_layers:
+ cond = self.cond_layer(cond)
+
+ for i in range(self.n_layers):
+ x_in = self.in_layers[i](x)
+ x_in = self.drop(x_in)
+ if cond is not None:
+ cond_offset = i * 2 * self.hidden_size
+ cond_l = cond[:, cond_offset:cond_offset + 2 * self.hidden_size, :]
+ else:
+ cond_l = torch.zeros_like(x_in)
+
+ if version.parse(torch.__version__) >= version.parse('2.0'):
+ acts = jit_fused_add_tanh_sigmoid_multiply(x_in, cond_l, n_channels_tensor)
+ else:
+ acts = fused_add_tanh_sigmoid_multiply(x_in, cond_l, n_channels_tensor)
+
+ res_skip_acts = self.res_skip_layers[i](acts)
+ if i < self.n_layers - 1:
+ x = (x + res_skip_acts[:, :self.hidden_size, :]) * nonpadding
+ output = output + res_skip_acts[:, self.hidden_size:, :]
+ else:
+ output = output + res_skip_acts
+ output = output * nonpadding
+ if self.is_BTC:
+ output = output.transpose(1, 2)
+ return output
+
+ def remove_weight_norm(self):
+ def remove_weight_norm(m):
+ try:
+ nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(remove_weight_norm)
diff --git a/preprocess/tools/note_transcription/modules/pe/__init__.py b/preprocess/tools/note_transcription/modules/pe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fda63eecd09d3ba4e5e5032a332421bfd760398
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/pe/__init__.py
@@ -0,0 +1 @@
+"""Pitch extractor modules for ROSVOT."""
diff --git a/preprocess/tools/note_transcription/modules/pe/rmvpe/__init__.py b/preprocess/tools/note_transcription/modules/pe/rmvpe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a03ccb69cb5cfe93701bb34b987c7f0ba73d28a
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/pe/rmvpe/__init__.py
@@ -0,0 +1,6 @@
+from .constants import *
+from .model import E2E0
+from .utils import to_local_average_f0, to_viterbi_f0
+from .inference import RMVPE
+from .spec import MelSpectrogram
+from .extractor import extract
diff --git a/preprocess/tools/note_transcription/modules/pe/rmvpe/constants.py b/preprocess/tools/note_transcription/modules/pe/rmvpe/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..525a2a0da43777cf67a494fcedc2a65d39349b15
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/pe/rmvpe/constants.py
@@ -0,0 +1,9 @@
+SAMPLE_RATE = 16000
+
+N_CLASS = 360
+
+N_MELS = 128
+MEL_FMIN = 30
+MEL_FMAX = 8000
+WINDOW_LENGTH = 1024
+CONST = 1997.3794084376191
diff --git a/preprocess/tools/note_transcription/modules/pe/rmvpe/deepunet.py b/preprocess/tools/note_transcription/modules/pe/rmvpe/deepunet.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e50d5e0b0af76fc507f905ae1ab64cebb7914d6
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/pe/rmvpe/deepunet.py
@@ -0,0 +1,173 @@
+import torch
+import torch.nn as nn
+from .constants import N_MELS
+
+
+class ConvBlockRes(nn.Module):
+ def __init__(self, in_channels, out_channels, momentum=0.01):
+ super(ConvBlockRes, self).__init__()
+ self.conv = nn.Sequential(
+ nn.Conv2d(in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(3, 3),
+ stride=(1, 1),
+ padding=(1, 1),
+ bias=False),
+ nn.BatchNorm2d(out_channels, momentum=momentum),
+ nn.ReLU(),
+
+ nn.Conv2d(in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=(3, 3),
+ stride=(1, 1),
+ padding=(1, 1),
+ bias=False),
+ nn.BatchNorm2d(out_channels, momentum=momentum),
+ nn.ReLU(),
+ )
+ if in_channels != out_channels:
+ self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
+ self.is_shortcut = True
+ else:
+ self.is_shortcut = False
+
+ def forward(self, x):
+ if self.is_shortcut:
+ return self.conv(x) + self.shortcut(x)
+ else:
+ return self.conv(x) + x
+
+
+class ResEncoderBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
+ super(ResEncoderBlock, self).__init__()
+ self.n_blocks = n_blocks
+ self.conv = nn.ModuleList()
+ self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
+ for i in range(n_blocks - 1):
+ self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
+ self.kernel_size = kernel_size
+ if self.kernel_size is not None:
+ self.pool = nn.AvgPool2d(kernel_size=kernel_size)
+
+ def forward(self, x):
+ for i in range(self.n_blocks):
+ x = self.conv[i](x)
+ if self.kernel_size is not None:
+ return x, self.pool(x)
+ else:
+ return x
+
+
+class ResDecoderBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
+ super(ResDecoderBlock, self).__init__()
+ out_padding = (0, 1) if stride == (1, 2) else (1, 1)
+ self.n_blocks = n_blocks
+ self.conv1 = nn.Sequential(
+ nn.ConvTranspose2d(in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(3, 3),
+ stride=stride,
+ padding=(1, 1),
+ output_padding=out_padding,
+ bias=False),
+ nn.BatchNorm2d(out_channels, momentum=momentum),
+ nn.ReLU(),
+ )
+ self.conv2 = nn.ModuleList()
+ self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
+ for i in range(n_blocks-1):
+ self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
+
+ def forward(self, x, concat_tensor):
+ x = self.conv1(x)
+ x = torch.cat((x, concat_tensor), dim=1)
+ for i in range(self.n_blocks):
+ x = self.conv2[i](x)
+ return x
+
+
+class Encoder(nn.Module):
+ def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
+ super(Encoder, self).__init__()
+ self.n_encoders = n_encoders
+ self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
+ self.layers = nn.ModuleList()
+ self.latent_channels = []
+ for i in range(self.n_encoders):
+ self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum))
+ self.latent_channels.append([out_channels, in_size])
+ in_channels = out_channels
+ out_channels *= 2
+ in_size //= 2
+ self.out_size = in_size
+ self.out_channel = out_channels
+
+ def forward(self, x):
+ concat_tensors = []
+ x = self.bn(x)
+ for i in range(self.n_encoders):
+ _, x = self.layers[i](x)
+ concat_tensors.append(_)
+ return x, concat_tensors
+
+
+class Intermediate(nn.Module):
+ def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
+ super(Intermediate, self).__init__()
+ self.n_inters = n_inters
+ self.layers = nn.ModuleList()
+ self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
+ for i in range(self.n_inters-1):
+ self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))
+
+ def forward(self, x):
+ for i in range(self.n_inters):
+ x = self.layers[i](x)
+ return x
+
+
+class Decoder(nn.Module):
+ def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
+ super(Decoder, self).__init__()
+ self.layers = nn.ModuleList()
+ self.n_decoders = n_decoders
+ for i in range(self.n_decoders):
+ out_channels = in_channels // 2
+ self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum))
+ in_channels = out_channels
+
+ def forward(self, x, concat_tensors):
+ for i in range(self.n_decoders):
+ x = self.layers[i](x, concat_tensors[-1-i])
+ return x
+
+
+class TimbreFilter(nn.Module):
+ def __init__(self, latent_rep_channels):
+ super(TimbreFilter, self).__init__()
+ self.layers = nn.ModuleList()
+ for latent_rep in latent_rep_channels:
+ self.layers.append(ConvBlockRes(latent_rep[0], latent_rep[0]))
+
+ def forward(self, x_tensors):
+ out_tensors = []
+ for i, layer in enumerate(self.layers):
+ out_tensors.append(layer(x_tensors[i]))
+ return out_tensors
+
+
+class DeepUnet0(nn.Module):
+ def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
+ super(DeepUnet0, self).__init__()
+ self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels)
+ self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
+ self.tf = TimbreFilter(self.encoder.latent_channels)
+ self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
+
+ def forward(self, x):
+ x, concat_tensors = self.encoder(x)
+ x = self.intermediate(x)
+ x = self.decoder(x, concat_tensors)
+ return x
diff --git a/preprocess/tools/note_transcription/modules/pe/rmvpe/extractor.py b/preprocess/tools/note_transcription/modules/pe/rmvpe/extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..7051c534dcbd7e0565f51119719cda56d9827b6b
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/pe/rmvpe/extractor.py
@@ -0,0 +1,183 @@
+import math
+import os
+
+from tqdm import tqdm
+import librosa
+import numpy as np
+import torch
+from torch.utils.data import Dataset, DataLoader, DistributedSampler
+import torch.multiprocessing as mp
+from torch.distributed import init_process_group
+import torch.distributed as dist
+
+from .inference import RMVPE
+from ....utils.commons.dataset_utils import batch_by_size, build_dataloader
+# import utils
+from ....utils.audio import get_wav_num_frames
+
+"""
+A convenient API for batch inference
+update: add ddp
+"""
+
+class RMVPEInferDataset(Dataset):
+ def __init__(self, wav_fns: list, id_and_sizes=None, sr=24000, hop_size=128, num_workers=0):
+ if id_and_sizes is None:
+ id_and_sizes = []
+ if type(wav_fns[0]) == str: # wav_paths
+ for idx, wav_path in enumerate(wav_fns):
+ total_frames = get_wav_num_frames(wav_path, sr)
+ id_and_sizes.append((idx, round(total_frames / hop_size)))
+ else: # numpy arrays, mono wavs
+ for idx, wav in enumerate(wav_fns):
+ id_and_sizes.append((idx, round(wav.shape[-1] / hop_size)))
+ self.wav_fns = wav_fns
+ self.id_and_sizes = id_and_sizes
+ self.sr = sr
+ self.num_workers = num_workers
+
+ def __getitem__(self, idx):
+ if type(self.wav_fns[idx]) == str:
+ wav_fn = self.wav_fns[idx]
+ wav, _ = librosa.core.load(wav_fn, sr=self.sr)
+ else:
+ wav = self.wav_fns[idx]
+ return idx, wav
+
+ def collater(self, samples: list):
+ return samples
+
+ def __len__(self):
+ return len(self.wav_fns)
+
+ def ordered_indices(self):
+ """Return an ordered list of indices. Batches will be constructed based
+ on this order."""
+ return np.arange(len(self))
+
+ def num_tokens(self, index):
+ return self.id_and_sizes[index][1]
+
+@torch.no_grad()
+def extract(wav_fns: list, id_and_sizes=None, ckpt=None, sr=24000, hop_size=128, bsz=128, max_tokens=100000,
+ fmax=900, fmin=50, ds_workers=0):
+ all_gpu_ids = [int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",") if x != '']
+ num_gpus = len(all_gpu_ids)
+ dist_config = {
+ "dist_backend": "nccl",
+ "dist_url": "tcp://localhost:54189",
+ "world_size": 1
+ }
+ # https://discuss.pytorch.org/t/how-to-fix-a-sigsegv-in-pytorch-when-using-distributed-training-e-g-ddp/113518/10#:~:text=Using%20start%20and%20join%20avoids
+ # https://github.com/pytorch/pytorch/issues/40403#issuecomment-648515174
+ # mp.set_start_method('spawn')
+ if num_gpus > 1:
+ result_queue = mp.Queue()
+ for rank in range(num_gpus):
+ mp.Process(target=extract_worker, args=(rank, wav_fns, id_and_sizes, ckpt, sr, hop_size, bsz, max_tokens, fmax,
+ fmin, dist_config, num_gpus, ds_workers, result_queue,)).start()
+ f0_res = [None] * len(wav_fns)
+ for _ in range(num_gpus):
+ f0_res_dict = result_queue.get()
+ for idx in f0_res_dict:
+ f0_res[idx] = f0_res_dict[idx]
+ del f0_res_dict
+ else:
+ # f0_res = extract_one_process(wav_fns, id_and_sizes, ckpt, sr, hop_size, bsz, max_tokens, fmax, fmin)
+ f0_res_dict = extract_worker(0, wav_fns, id_and_sizes, ckpt, sr, hop_size, bsz, max_tokens, fmax,
+ fmin, dist_config, num_gpus, ds_workers, None)
+ f0_res = [None] * len(wav_fns)
+ for idx in f0_res_dict:
+ f0_res[idx] = f0_res_dict[idx]
+ return f0_res
+
+@torch.no_grad()
+def extract_worker(rank, wav_fns: list, id_and_sizes=None, ckpt=None, sr=24000, hop_size=128, bsz=128, max_tokens=100000,
+ fmax=900, fmin=50, dist_config=None, num_gpus=1, ds_workers=0, q=None):
+ # print(f"rank: {rank}")
+ if num_gpus > 1:
+ init_process_group(backend=dist_config['dist_backend'], init_method=dist_config['dist_url'],
+ world_size=dist_config['world_size'] * num_gpus, rank=rank)
+ dataset = RMVPEInferDataset(wav_fns, id_and_sizes, sr, hop_size, num_workers=ds_workers)
+ # ds_sampler = DistributedSampler(dataset, shuffle=False) if num_gpus > 1 else None
+ # loader = DataLoader(dataset, sampler=ds_sampler, collate_fn=dataset.collator, batch_size=1, num_workers=40, drop_last=False)
+ loader = build_dataloader(dataset, shuffle=False, max_tokens=max_tokens, max_sentences=bsz, use_ddp=num_gpus > 1)
+ loader = tqdm(loader, desc=f'| Processing f0 in [n_ranks={num_gpus}; max_tokens={max_tokens}; max_sentences={bsz}]') if rank == 0 else loader
+
+ device = torch.device(f"cuda:{int(rank)}")
+ model = RMVPE(ckpt, device=device)
+ f0_res_dict = {}
+ for batch in loader:
+ if batch is None or len(batch) == 0:
+ continue
+ idxs = [item[0] for item in batch]
+ wavs = [item[1] for item in batch]
+ lengths = [(wav.shape[0] + hop_size - 1) // hop_size for wav in wavs]
+ with torch.no_grad():
+ f0s, uvs = model.get_pitch_batch(
+ wavs, sample_rate=sr,
+ hop_size=hop_size,
+ lengths=lengths,
+ fmax=fmax,
+ fmin=fmin
+ )
+ for i, idx in enumerate(idxs):
+ f0_res_dict[idx] = f0s[i]
+ if q is not None:
+ q.put(f0_res_dict)
+ else:
+ return f0_res_dict
+
+# old version
+def extract_one_process(wav_fns: list, id_and_sizes=None, ckpt=None, sr=24000, hop_size=128, bsz=128, max_tokens=100000,
+ fmax=900, fmin=50, device='cuda'):
+ assert ckpt is not None
+ rmvpe = RMVPE(ckpt, device=device)
+ if id_and_sizes is None:
+ id_and_sizes = []
+ if type(wav_fns[0]) == str: # wav_paths
+ for idx, wav_path in enumerate(wav_fns):
+ total_frames = get_wav_num_frames(wav_path, sr)
+ id_and_sizes.append((idx, round(total_frames / hop_size)))
+ else: # numpy arrays, mono wavs
+ for idx, wav in enumerate(wav_fns):
+ id_and_sizes.append((idx, round(wav.shape[-1] / hop_size)))
+ get_size = lambda x: x[1]
+ bs = batch_by_size(id_and_sizes, get_size, max_tokens=max_tokens, max_sentences=bsz)
+ for i in range(len(bs)):
+ bs[i] = [bs[i][j][0] for j in range(len(bs[i]))]
+
+ f0_res = [None] * len(wav_fns)
+ for batch in tqdm(bs, total=len(bs), desc=f'| Processing f0 in [max_tokens={max_tokens}; max_sentences={bsz}]'):
+ wavs, mel_lengths, lengths = [], [], []
+ for idx in batch:
+ if type(wav_fns[idx]) == str:
+ wav_fn = wav_fns[idx]
+ wav, _ = librosa.core.load(wav_fn, sr=sr)
+ else:
+ wav = wav_fns[idx]
+ wavs.append(wav)
+ mel_lengths.append(math.ceil((wav.shape[0] + 1) / hop_size))
+ lengths.append((wav.shape[0] + hop_size - 1) // hop_size)
+
+ with torch.no_grad():
+ f0s, uvs = rmvpe.get_pitch_batch(
+ wavs, sample_rate=sr,
+ hop_size=hop_size,
+ lengths=lengths,
+ fmax=fmax,
+ fmin=fmin
+ )
+
+ for i, idx in enumerate(batch):
+ f0_res[idx] = f0s[i]
+
+ if rmvpe is not None:
+ rmvpe.release_cuda()
+ torch.cuda.empty_cache()
+ rmvpe = None
+
+ return f0_res
+
+
+
diff --git a/preprocess/tools/note_transcription/modules/pe/rmvpe/inference.py b/preprocess/tools/note_transcription/modules/pe/rmvpe/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..d28022228c7f8335aa99206611149e1c49495b03
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/pe/rmvpe/inference.py
@@ -0,0 +1,134 @@
+import math
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torchaudio.transforms import Resample
+import pyworld as pw
+
+from ....utils.audio.pitch_utils import interp_f0, resample_align_curve
+from .constants import *
+from .model import E2E0
+from .spec import MelSpectrogram
+from .utils import to_local_average_f0, to_viterbi_f0
+
+
+class RMVPE:
+ def __init__(self, model_path, hop_length=160, device=None):
+ self.resample_kernel = {}
+ if device is None:
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ else:
+ self.device = device
+ self.model = E2E0(4, 1, (2, 2)).eval().to(self.device)
+ ckpt = torch.load(model_path, map_location=self.device)
+ self.model.load_state_dict(ckpt['model'], strict=False)
+ self.mel_extractor = MelSpectrogram(
+ N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX
+ ).to(self.device)
+ self.hop_length = hop_length
+
+ @torch.no_grad()
+ def mel2hidden(self, mel):
+ n_frames = mel.shape[-1]
+ mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='constant')
+ hidden = self.model(mel)
+ return hidden[:, :n_frames]
+
+ def decode(self, hidden, thred=0.03, use_viterbi=False):
+ if use_viterbi:
+ f0 = to_viterbi_f0(hidden, thred=thred)
+ else:
+ f0 = to_local_average_f0(hidden, thred=thred)
+ return f0
+
+ def postprocess(self, f0, fmin=50, fmax=1000, audio=None, min_gap=2):
+ if audio is not None:
+ # this doesn't work. deprecated
+ t = np.arange(0, f0.shape[0] * self.hop_length / 16000, self.hop_length / 16000)
+ f0 = pw.stonemask(audio.astype(np.float64), f0.astype(np.float64), t, 16000).astype(float)
+ f0[f0 < fmin] = 0
+ f0[f0 > fmax] = 0
+ # eliminate glitch
+ # min_gap: if successive positive f0 positions < min_gap, zero these positions
+ # eg: if min_gap=2, [0, 500, 500, 0] => [0, 0, 0, 0]
+ for idx in range(f0.shape[0] - min_gap - 1):
+ if f0[idx] == 0 and f0[idx + min_gap + 1] == 0 and np.sum(f0[idx: idx + min_gap + 2]) > 0:
+ f0[idx: idx + min_gap + 2] = 0
+ return f0
+
+ def infer_from_audio(self, audio, sample_rate=16000, thred=0.03, use_viterbi=False):
+ audio = torch.from_numpy(audio).float().unsqueeze(0).to(self.device)
+ if sample_rate == 16000:
+ audio_res = audio
+ else:
+ key_str = str(sample_rate)
+ if key_str not in self.resample_kernel:
+ self.resample_kernel[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128)
+ self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.device)
+ audio_res = self.resample_kernel[key_str](audio)
+ mel = self.mel_extractor(audio_res, center=True)
+ hidden = self.mel2hidden(mel)
+ f0 = self.decode(hidden, thred=thred, use_viterbi=use_viterbi).squeeze(0)
+ return f0
+
+ def get_pitch(self, waveform, sample_rate, hop_size, length, interp_uv=False, fmin=50, fmax=1000):
+ f0 = self.infer_from_audio(waveform, sample_rate=sample_rate)
+ f0 = self.postprocess(f0, fmin, fmax)
+ uv = f0 == 0
+ time_step = hop_size / sample_rate
+ f0_res = resample_align_curve(f0, 0.01, time_step, length)
+ uv_res = resample_align_curve(uv.astype(np.float32), 0.01, time_step, length) > 0.5
+ if not interp_uv:
+ f0_res[uv_res] = 0
+ return f0_res, uv_res
+
+ def infer_from_audio_batch(self, audios, sample_rate=16000, thred=0.03, use_viterbi=False):
+ from ....utils.commons.dataset_utils import collate_1d_or_2d
+ if isinstance(audios, list):
+ audios = [torch.from_numpy(audio).float() for audio in audios]
+ sizes = [math.ceil((audio.shape[0] + 1) / self.hop_length) for audio in audios]
+ audios = collate_1d_or_2d(audios, 0.0).to(self.device)
+ elif isinstance(audios, torch.Tensor):
+ sizes = None
+ if audios.device != self.device:
+ audios = audios.to(self.device)
+ else:
+ raise NotImplementedError
+ if sample_rate == 16000:
+ audios_res = audios
+ else:
+ key_str = str(sample_rate)
+ if key_str not in self.resample_kernel:
+ self.resample_kernel[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128)
+ self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.device)
+ audios_res = self.resample_kernel[key_str](audios)
+ mels = self.mel_extractor(audios_res, center=True)
+ hiddens = self.mel2hidden(mels)
+ f0 = self.decode(hiddens, thred=thred, use_viterbi=use_viterbi)
+ f0s = []
+ for i in range(f0.shape[0]):
+ f = f0[i, :sizes[i]] if sizes is not None else f0[i, :]
+ f0s.append(f)
+ return f0s
+
+ def get_pitch_batch(self, waveforms, sample_rate, hop_size, lengths, interp_uv=False, fmin=50, fmax=1000):
+ # hop_size, sample_rate: tgt params
+ f0s = self.infer_from_audio_batch(waveforms, sample_rate=sample_rate)
+ f0s_res, uvs_res = [], []
+ for idx, f0 in enumerate(f0s):
+ f0 = self.postprocess(f0, fmin, fmax, min_gap=6)
+ uv = f0 == 0
+ length = lengths[idx]
+ time_step = hop_size / sample_rate
+ f0_res = resample_align_curve(f0, 0.01, time_step, length)
+ uv_res = resample_align_curve(uv.astype(np.float32), 0.01, time_step, length) > 0.5
+ if not interp_uv:
+ f0_res[uv_res] = 0
+ f0s_res.append(f0_res)
+ uvs_res.append(uv_res)
+ return f0s_res, uvs_res
+
+ def release_cuda(self):
+ self.model = self.model.cpu()
+ self.mel_extractor = self.mel_extractor.cpu()
diff --git a/preprocess/tools/note_transcription/modules/pe/rmvpe/model.py b/preprocess/tools/note_transcription/modules/pe/rmvpe/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b2d72cfbde608869c3f6e884cd1a31326cc9164
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/pe/rmvpe/model.py
@@ -0,0 +1,32 @@
+from torch import nn
+
+from .constants import *
+from .deepunet import DeepUnet0
+from .seq import BiGRU
+
+
+class E2E0(nn.Module):
+ def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1,
+ en_out_channels=16):
+ super(E2E0, self).__init__()
+ self.unet = DeepUnet0(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
+ self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
+ if n_gru:
+ self.fc = nn.Sequential(
+ BiGRU(3 * N_MELS, 256, n_gru),
+ nn.Linear(512, N_CLASS),
+ nn.Dropout(0.25),
+ nn.Sigmoid()
+ )
+ else:
+ self.fc = nn.Sequential(
+ nn.Linear(3 * N_MELS, N_CLASS),
+ nn.Dropout(0.25),
+ nn.Sigmoid()
+ )
+
+ def forward(self, mel):
+ mel = mel.transpose(-1, -2).unsqueeze(1)
+ x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
+ x = self.fc(x)
+ return x
diff --git a/preprocess/tools/note_transcription/modules/pe/rmvpe/seq.py b/preprocess/tools/note_transcription/modules/pe/rmvpe/seq.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c4c8f880d502c8a4abc6bdf40af86f08be744d7
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/pe/rmvpe/seq.py
@@ -0,0 +1,10 @@
+import torch.nn as nn
+
+
+class BiGRU(nn.Module):
+ def __init__(self, input_features, hidden_features, num_layers):
+ super(BiGRU, self).__init__()
+ self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
+
+ def forward(self, x):
+ return self.gru(x)[0]
diff --git a/preprocess/tools/note_transcription/modules/pe/rmvpe/spec.py b/preprocess/tools/note_transcription/modules/pe/rmvpe/spec.py
new file mode 100644
index 0000000000000000000000000000000000000000..675b88ad03f1298724290829b82c5008832d4585
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/pe/rmvpe/spec.py
@@ -0,0 +1,72 @@
+import torch
+import numpy as np
+import torch.nn.functional as F
+from librosa.filters import mel
+
+
+class MelSpectrogram(torch.nn.Module):
+ def __init__(
+ self,
+ n_mel_channels,
+ sampling_rate,
+ win_length,
+ hop_length,
+ n_fft=None,
+ mel_fmin=0,
+ mel_fmax=None,
+ clamp=1e-5
+ ):
+ super().__init__()
+ n_fft = win_length if n_fft is None else n_fft
+ self.hann_window = {}
+ mel_basis = mel(
+ sr=sampling_rate,
+ n_fft=n_fft,
+ n_mels=n_mel_channels,
+ fmin=mel_fmin,
+ fmax=mel_fmax,
+ htk=True)
+ mel_basis = torch.from_numpy(mel_basis).float()
+ self.register_buffer("mel_basis", mel_basis)
+ self.n_fft = win_length if n_fft is None else n_fft
+ self.hop_length = hop_length
+ self.win_length = win_length
+ self.sampling_rate = sampling_rate
+ self.n_mel_channels = n_mel_channels
+ self.clamp = clamp
+
+ def forward(self, audio, keyshift=0, speed=1, center=True):
+ factor = 2 ** (keyshift / 12)
+ n_fft_new = int(np.round(self.n_fft * factor))
+ win_length_new = int(np.round(self.win_length * factor))
+ hop_length_new = int(np.round(self.hop_length * speed))
+
+ keyshift_key = str(keyshift) + '_' + str(audio.device)
+ if keyshift_key not in self.hann_window:
+ self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
+ if center:
+ pad_left = win_length_new // 2
+ pad_right = (win_length_new + 1) // 2
+ audio = F.pad(audio, (pad_left, pad_right))
+
+ fft = torch.stft(
+ audio,
+ n_fft=n_fft_new,
+ hop_length=hop_length_new,
+ win_length=win_length_new,
+ window=self.hann_window[keyshift_key],
+ center=False,
+ return_complex=True
+ )
+ magnitude = fft.abs()
+
+ if keyshift != 0:
+ size = self.n_fft // 2 + 1
+ resize = magnitude.size(1)
+ if resize < size:
+ magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
+ magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
+
+ mel_output = torch.matmul(self.mel_basis, magnitude)
+ log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
+ return log_mel_spec
diff --git a/preprocess/tools/note_transcription/modules/pe/rmvpe/utils.py b/preprocess/tools/note_transcription/modules/pe/rmvpe/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e852e4bcfeea4cbf2e11eee13e0226a0ae258a1
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/pe/rmvpe/utils.py
@@ -0,0 +1,43 @@
+import librosa
+import numpy as np
+import torch
+
+from .constants import *
+
+
+def to_local_average_f0(hidden, center=None, thred=0.03):
+ idx = torch.arange(N_CLASS, device=hidden.device)[None, None, :] # [B=1, T=1, N]
+ idx_cents = idx * 20 + CONST # [B=1, N]
+ if center is None:
+ center = torch.argmax(hidden, dim=2, keepdim=True) # [B, T, 1]
+ start = torch.clip(center - 4, min=0) # [B, T, 1]
+ end = torch.clip(center + 5, max=N_CLASS) # [B, T, 1]
+ idx_mask = (idx >= start) & (idx < end) # [B, T, N]
+ weights = hidden * idx_mask # [B, T, N]
+ product_sum = torch.sum(weights * idx_cents, dim=2) # [B, T]
+ weight_sum = torch.sum(weights, dim=2) # [B, T]
+ cents = product_sum / (weight_sum + (weight_sum == 0)) # avoid dividing by zero, [B, T]
+ f0 = 10 * 2 ** (cents / 1200)
+ uv = hidden.max(dim=2)[0] < thred # [B, T]
+ f0 = f0 * ~uv
+ return f0.cpu().numpy()
+
+
+def to_viterbi_f0(hidden, thred=0.03):
+ # Create viterbi transition matrix
+ if not hasattr(to_viterbi_f0, 'transition'):
+ xx, yy = np.meshgrid(range(N_CLASS), range(N_CLASS))
+ transition = np.maximum(30 - abs(xx - yy), 0)
+ transition = transition / transition.sum(axis=1, keepdims=True)
+ to_viterbi_f0.transition = transition
+
+ # Convert to probability
+ prob = hidden.squeeze(0).cpu().numpy()
+ prob = prob.T
+ prob = prob / prob.sum(axis=0)
+
+ # Perform viterbi decoding
+ path = librosa.sequence.viterbi(prob, to_viterbi_f0.transition).astype(np.int64)
+ center = torch.from_numpy(path).unsqueeze(0).unsqueeze(-1).to(hidden.device)
+
+ return to_local_average_f0(hidden, center=center, thred=thred)
diff --git a/preprocess/tools/note_transcription/modules/rosvot/__init__.py b/preprocess/tools/note_transcription/modules/rosvot/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ae04829795bef7e34b2cd4b922153eeff067d9d
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/rosvot/__init__.py
@@ -0,0 +1 @@
+"""Core ROSVOT model components."""
diff --git a/preprocess/tools/note_transcription/modules/rosvot/rosvot.py b/preprocess/tools/note_transcription/modules/rosvot/rosvot.py
new file mode 100644
index 0000000000000000000000000000000000000000..536802713694829c35276c7b8f04d5af8a46da44
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/rosvot/rosvot.py
@@ -0,0 +1,295 @@
+from copy import deepcopy
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from ...utils.commons.hparams import hparams
+from ...utils.commons.gpu_mem_track import MemTracker
+from ..commons.layers import Embedding
+from ..commons.conv import ResidualBlock, ConvBlocks
+from ..commons.conformer.conformer import ConformerLayers
+from .unet import Unet
+
+def regulate_boundary(bd_logits, threshold, min_gap=18, ref_bd=None, ref_bd_min_gap=8, non_padding=None):
+ # this doesn't preserve gradient
+ device = bd_logits.device
+ bd_logits = torch.sigmoid(bd_logits).data.cpu()
+ # bd_logits[0] = bd_logits[-1] = 1e-5 # avoid itv invalid problem
+ bd = (bd_logits > threshold).long()
+ bd_res = torch.zeros_like(bd).long()
+ for i in range(bd.shape[0]):
+ bd_i = bd[i]
+ last_bd_idx = -1
+ start = -1
+ for j in range(bd_i.shape[0]):
+ if bd_i[j] == 1:
+ if 0 <= start < j:
+ continue
+ elif start < 0:
+ start = j
+ else:
+ if 0 <= start < j:
+ if j - 1 > start:
+ bd_idx = start + int(torch.argmax(bd_logits[i, start: j]).item())
+ else:
+ bd_idx = start
+ if bd_idx - last_bd_idx < min_gap and last_bd_idx > 0:
+ bd_idx = round((bd_idx + last_bd_idx) / 2)
+ bd_res[i, last_bd_idx] = 0
+ bd_res[i, bd_idx] = 1
+ last_bd_idx = bd_idx
+ start = -1
+
+ # assert ref_bd_min_gap <= min_gap // 2
+ if ref_bd is not None and ref_bd_min_gap > 0:
+ ref = ref_bd.data.cpu()
+ for i in range(bd_res.shape[0]):
+ ref_bd_i = ref[i]
+ ref_bd_i_js = []
+ for j in range(ref_bd_i.shape[0]):
+ if ref_bd_i[j] == 1:
+ ref_bd_i_js.append(j)
+ seg_sum = torch.sum(bd_res[i, max(0, j - ref_bd_min_gap): j + ref_bd_min_gap])
+ if seg_sum == 0:
+ bd_res[i, j] = 1
+ elif seg_sum == 1 and bd_res[i, j] != 1:
+ bd_res[i, max(0, j - ref_bd_min_gap): j + ref_bd_min_gap] = \
+ ref_bd_i[max(0, j - ref_bd_min_gap): j + ref_bd_min_gap]
+ elif seg_sum > 1:
+ for k in range(1, ref_bd_min_gap+1):
+ if bd_res[i, max(0, j - k)] == 1 and ref_bd_i[max(0, j - k)] != 1:
+ bd_res[i, max(0, j - k)] = 0
+ break
+ if bd_res[i, min(bd_res.shape[1] - 1, j + k)] == 1 and ref_bd_i[min(bd_res.shape[1] - 1, j + k)] != 1:
+ bd_res[i, min(bd_res.shape[1] - 1, j + k)] = 0
+ break
+ bd_res[i, j] = 1
+ # final check
+ assert torch.sum(bd_res[i, ref_bd_i_js]) == len(ref_bd_i_js), \
+ f"{torch.sum(bd_res[i, ref_bd_i_js])} {len(ref_bd_i_js)}"
+
+ bd_res = bd_res.to(device)
+
+ # force valid begin and end
+ bd_res[:, 0] = 0
+ if non_padding is not None:
+ for i in range(bd_res.shape[0]):
+ bd_res[i, sum(non_padding[i]) - 1:] = 0
+ else:
+ bd_res[:, -1] = 0
+
+ return bd_res
+
+class BackboneNet(nn.Module):
+ def __init__(self, hparams):
+ super().__init__()
+ self.hidden_size = hidden_size = hparams['hidden_size']
+ self.dropout = hparams.get('dropout', 0.0)
+ updown_rates = [2, 2, 2]
+ channel_multiples = [1, 1, 1]
+ if hparams.get('updown_rates', None) is not None:
+ updown_rates = [int(i) for i in hparams.get('updown_rates', None).split('-')]
+ if hparams.get('channel_multiples', None) is not None:
+ channel_multiples = [float(i) for i in hparams.get('channel_multiples', None).split('-')]
+ assert len(updown_rates) == len(channel_multiples)
+ # convs
+ if hparams.get('bkb_net', 'conv') == 'conv':
+ self.net = Unet(hidden_size, down_layers=len(updown_rates), mid_layers=hparams.get('bkb_layers', 12),
+ up_layers=len(updown_rates), kernel_size=3, updown_rates=updown_rates,
+ channel_multiples=channel_multiples, dropout=0, is_BTC=True,
+ constant_channels=False, mid_net=None, use_skip_layer=hparams.get('unet_skip_layer', False))
+ # conformer
+ elif hparams.get('bkb_net', 'conv') == 'conformer':
+ mid_net = ConformerLayers(
+ hidden_size, num_layers=hparams.get('bkb_layers', 12), kernel_size=hparams.get('conformer_kernel', 9),
+ dropout=self.dropout, num_heads=4)
+ self.net = Unet(hidden_size, down_layers=len(updown_rates), up_layers=len(updown_rates), kernel_size=3,
+ updown_rates=updown_rates, channel_multiples=channel_multiples, dropout=0,
+ is_BTC=True, constant_channels=False, mid_net=mid_net,
+ use_skip_layer=hparams.get('unet_skip_layer', False))
+
+ def forward(self, x):
+ return self.net(x)
+
+class PitchDecoder(nn.Module):
+ def __init__(self, hparams):
+ super().__init__()
+ self.hidden_size = hidden_size = hparams['hidden_size']
+ self.dropout = hparams.get('dropout', 0.0)
+ self.note_bd_out = nn.Linear(hidden_size, 1)
+ self.note_bd_temperature = max(1e-7, hparams.get('note_bd_temperature', 1.0))
+
+ # note prediction
+ self.pitch_attn_num_head = hparams.get('pitch_attn_num_head', 1)
+ self.multihead_dot_attn = nn.Linear(hidden_size, self.pitch_attn_num_head)
+ self.post = ConvBlocks(hidden_size, out_dims=hidden_size, dilations=None, kernel_size=3,
+ layers_in_block=1, c_multiple=1, dropout=self.dropout, num_layers=1,
+ post_net_kernel=3, act_type='leakyrelu')
+ self.pitch_out = nn.Linear(hidden_size, hparams.get('note_num', 100) + 4)
+ self.note_num = hparams.get('note_num', 100)
+ self.note_start = hparams.get('note_start', 30)
+ self.pitch_temperature = max(1e-7, hparams.get('note_pitch_temperature', 1.0))
+
+ def forward(self, feat, note_bd, train=True):
+ bsz, T, _ = feat.shape
+
+ attn = torch.sigmoid(self.multihead_dot_attn(feat)) # [B, T, C] -> [B, T, num_head]
+ attn = F.dropout(attn, self.dropout, train)
+ attn_feat = feat.unsqueeze(3) * attn.unsqueeze(2) # [B, T, C, 1] x [B, T, 1, num_head] -> [B, T, C, num_head]
+ attn_feat = torch.mean(attn_feat, dim=-1) # [B, T, C, num_head] -> [B, T, C]
+ mel2note = torch.cumsum(note_bd, 1)
+ note_length = torch.max(torch.sum(note_bd, dim=1)).item() + 1 # max length
+ note_lengths = torch.sum(note_bd, dim=1) + 1 # [B]
+ # print('note_length', note_length)
+
+ attn = torch.mean(attn, dim=-1, keepdim=True) # [B, T, num_head] -> [B, T, 1]
+ denom = mel2note.new_zeros(bsz, note_length, dtype=attn.dtype).scatter_add_(
+ dim=1, index=mel2note, src=attn.squeeze(-1)
+ ) # [B, T] -> [B, note_length] count the note frames of each note (with padding excluded)
+ frame2note = mel2note.unsqueeze(-1).repeat(1, 1, self.hidden_size) # [B, T] -> [B, T, C], with padding included
+ note_aggregate = frame2note.new_zeros(bsz, note_length, self.hidden_size, dtype=attn_feat.dtype).scatter_add_(
+ dim=1, index=frame2note, src=attn_feat
+ ) # [B, T, C] -> [B, note_length, C]
+ note_aggregate = note_aggregate / (denom.unsqueeze(-1) + 1e-5)
+ note_aggregate = F.dropout(note_aggregate, self.dropout, train)
+ note_logits = self.post(note_aggregate)
+ note_logits = self.pitch_out(note_logits) / self.pitch_temperature
+ # note_logits = torch.clamp(note_logits, min=-16., max=16.) # don't know need it or not
+
+ note_pred = torch.softmax(note_logits, dim=-1) # [B, note_length, note_num]
+ note_pred = torch.argmax(note_pred, dim=-1) # [B, note_length]
+ # for some reason, note idx maybe 130 (why?)
+ note_pred[note_pred > self.note_num] = 0
+ note_pred[note_pred < self.note_start] = 0
+
+ return note_lengths, note_logits, note_pred
+
+class MidiExtractor(nn.Module):
+ def __init__(self, hparams):
+ super(MidiExtractor, self).__init__()
+ self.hparams = deepcopy(hparams)
+ self.hidden_size = hidden_size = hparams['hidden_size']
+ self.dropout = hparams.get('dropout', 0.0)
+ self.note_bd_threshold = hparams.get('note_bd_threshold', 0.5)
+ self.note_bd_min_gap = round(hparams.get('note_bd_min_gap', 100) * hparams['audio_sample_rate'] / 1000 / hparams['hop_size'])
+ self.note_bd_ref_min_gap = round(hparams.get('note_bd_ref_min_gap', 50) * hparams['audio_sample_rate'] / 1000 / hparams['hop_size'])
+
+ self.mel_proj = nn.Conv1d(hparams['use_mel_bins'], hidden_size, kernel_size=3, padding=1)
+ self.mel_encoder = ConvBlocks(hidden_size, out_dims=hidden_size, dilations=None, kernel_size=3,
+ layers_in_block=2, c_multiple=1, dropout=self.dropout, num_layers=1,
+ post_net_kernel=3, act_type='leakyrelu')
+ self.use_pitch = hparams.get('use_pitch_embed', True)
+ if self.use_pitch:
+ self.pitch_embed = Embedding(300, hidden_size, 0, 'kaiming')
+ self.uv_embed = Embedding(3, hidden_size, 0, 'kaiming')
+ self.use_wbd = hparams.get('use_wbd', True)
+ if self.use_wbd:
+ self.word_bd_embed = Embedding(3, hidden_size, 0, 'kaiming')
+ self.cond_encoder = ConvBlocks(hidden_size, out_dims=hidden_size, dilations=None, kernel_size=3,
+ layers_in_block=1, c_multiple=1, dropout=self.dropout, num_layers=1,
+ post_net_kernel=3, act_type='leakyrelu')
+
+ # backbone
+ self.net = BackboneNet(hparams)
+
+ # note bd prediction
+ self.note_bd_out = nn.Linear(hidden_size, 1)
+ self.note_bd_temperature = max(1e-7, hparams.get('note_bd_temperature', 1.0))
+
+ # note prediction
+ self.pitch_decoder = PitchDecoder(hparams)
+
+ self.reset_parameters()
+
+ def run_encoder(self, mel=None, word_bd=None, pitch=None, uv=None, non_padding=None):
+ mel_embed = self.mel_proj(mel.transpose(1, 2)).transpose(1, 2)
+ mel_embed = self.mel_encoder(mel_embed)
+ pitch_embed = word_bd_embed = 0
+ if self.use_pitch and pitch is not None and uv is not None:
+ pitch_embed = self.pitch_embed(pitch) + self.uv_embed(uv) # [B, T, C]
+ if self.use_wbd and word_bd is not None:
+ word_bd_embed = self.word_bd_embed(word_bd)
+ feat = self.cond_encoder(mel_embed + pitch_embed + word_bd_embed)
+
+ return feat
+
+ def forward(self, mel=None, word_bd=None, note_bd=None, pitch=None, uv=None, non_padding=None, train=True):
+ ret = {}
+ bsz, T, _ = mel.shape
+
+ feat = self.run_encoder(mel, word_bd, pitch, uv, non_padding)
+ feat = self.net(feat) # [B, T, C]
+
+ # note bd prediction
+ note_bd_logits = self.note_bd_out(F.dropout(feat, self.dropout, train)).squeeze(-1) / self.note_bd_temperature
+ note_bd_logits = torch.clamp(note_bd_logits, min=-16., max=16.)
+ ret['note_bd_logits'] = note_bd_logits # [B, T]
+ if note_bd is None or not train:
+ note_bd = regulate_boundary(note_bd_logits, self.note_bd_threshold, self.note_bd_min_gap,
+ word_bd, self.note_bd_ref_min_gap, non_padding)
+ ret['note_bd_pred'] = note_bd # [B, T]
+
+ # note pitch prediction
+ note_lengths, note_logits, note_pred = self.pitch_decoder(feat, note_bd, train)
+ ret['note_lengths'], ret['note_logits'], ret['note_pred'] = note_lengths, note_logits, note_pred
+
+ return ret
+
+ def reset_parameters(self):
+ nn.init.kaiming_normal_(self.pitch_decoder.multihead_dot_attn.weight, mode='fan_in')
+ nn.init.kaiming_normal_(self.note_bd_out.weight, mode='fan_in')
+ nn.init.kaiming_normal_(self.pitch_decoder.pitch_out.weight, mode='fan_in')
+ nn.init.kaiming_normal_(self.mel_proj.weight, mode='fan_in')
+ nn.init.constant_(self.pitch_decoder.multihead_dot_attn.bias, 0.0)
+ nn.init.constant_(self.note_bd_out.bias, 0.0)
+ nn.init.constant_(self.pitch_decoder.pitch_out.bias, 0.0)
+
+
+class WordbdExtractor(MidiExtractor):
+ def __init__(self, hparams):
+ super().__init__(hparams)
+ self.use_wbd = False
+ self.word_bd_embed = None
+ self.note_bd_out = self.note_bd_temperature = self.pitch_decoder = None
+
+ self.word_bd_threshold = hparams.get('word_bd_threshold', 0.5)
+ self.word_bd_min_gap = round(
+ hparams.get('word_bd_min_gap', 100) * hparams['audio_sample_rate'] / 1000 / hparams['hop_size'])
+
+ self.word_bd_out = nn.Linear(self.hidden_size, 1)
+ self.word_bd_temperature = max(1e-7, hparams.get('word_bd_temperature', 1.0))
+ nn.init.kaiming_normal_(self.word_bd_out.weight, mode='fan_in')
+ nn.init.constant_(self.word_bd_out.bias, 0.0)
+
+ def forward(self, mel=None, pitch=None, uv=None, non_padding=None, train=True):
+ # gpu_tracker.track()
+ ret = {}
+ bsz, T, _ = mel.shape
+
+ feat = self.run_encoder(mel=mel, pitch=pitch, uv=uv, non_padding=non_padding)
+ feat = self.net(feat) # [B, T, C]
+
+ word_bd_logits = self.word_bd_out(F.dropout(feat, self.dropout, train)).squeeze(-1) / self.word_bd_temperature
+ word_bd_logits = torch.clamp(word_bd_logits, min=-16., max=16.)
+ ret['word_bd_logits'] = word_bd_logits # [B, T]
+
+ if not train:
+ word_bd = regulate_boundary(word_bd_logits, self.word_bd_threshold, self.word_bd_min_gap,
+ non_padding=non_padding)
+ ret['word_bd_pred'] = word_bd # [B, T]
+
+ return ret
+
+ def reset_parameters(self):
+ if self.use_pitch:
+ nn.init.kaiming_normal_(self.pitch_embed.weight, mode='fan_in')
+ nn.init.kaiming_normal_(self.uv_embed.weight, mode='fan_in')
+ nn.init.kaiming_normal_(self.mel_proj.weight, mode='fan_in')
+ if self.use_pitch:
+ nn.init.constant_(self.pitch_embed.weight[self.pitch_embed.padding_idx], 0.0)
+ nn.init.constant_(self.uv_embed.weight[self.uv_embed.padding_idx], 0.0)
+
+
diff --git a/preprocess/tools/note_transcription/modules/rosvot/unet.py b/preprocess/tools/note_transcription/modules/rosvot/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7e56ea087970a8be1896a7183b6bdda6b6c0479
--- /dev/null
+++ b/preprocess/tools/note_transcription/modules/rosvot/unet.py
@@ -0,0 +1,172 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..commons.layers import LayerNorm, Embedding
+from ..commons.conv import ConvBlocks, ResidualBlock, get_norm_builder, get_act_builder
+
+class UnetDown(nn.Module):
+ def __init__(self, hidden_size, n_layers, kernel_size, down_rates, channel_multiples=None, dropout=0.0,
+ is_BTC=True, constant_channels=False):
+ super(UnetDown, self).__init__()
+ assert n_layers == len(down_rates) # downs, down sample rate
+ down_rates = [int(i) for i in down_rates]
+ self.n_layers = n_layers
+ self.hidden_size = hidden_size
+ self.is_BTC = is_BTC
+ channel_multiples = channel_multiples if channel_multiples is not None else down_rates
+ self.layers = nn.ModuleList()
+ self.downs = nn.ModuleList()
+ in_channels = hidden_size
+ for i in range(self.n_layers):
+ out_channels = int(in_channels * channel_multiples[i]) if not constant_channels else in_channels
+ self.layers.append(nn.Sequential(
+ ResidualBlock(in_channels, kernel_size, dilation=1, n=1, norm_type='ln', dropout=dropout,
+ c_multiple=1, ln_eps=1e-5, act_type='leakyrelu'),
+ nn.Conv1d(in_channels, out_channels, kernel_size, padding=(kernel_size - 1) // 2),
+ ResidualBlock(out_channels, kernel_size, dilation=1, n=1, norm_type='ln',
+ dropout=dropout, c_multiple=1, ln_eps=1e-5, act_type='leakyrelu')
+ ))
+ self.downs.append(nn.Sequential(
+ nn.AvgPool1d(down_rates[i])
+ ))
+ in_channels = out_channels
+ self.last_norm = get_norm_builder('ln', out_channels)()
+ self.post_net = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size,
+ padding=kernel_size // 2)
+
+ def forward(self, x, **kwargs):
+ # x [B, T, C]
+ if self.is_BTC:
+ x = x.transpose(1, 2) # [B, C, T]
+ skip_xs = []
+ for i in range(self.n_layers):
+ skip_x = self.layers[i](x)
+ x = self.downs[i](skip_x)
+ if self.is_BTC:
+ skip_xs.append(skip_x.transpose(1, 2)) # [B, T, C]
+ else:
+ skip_xs.append(skip_x)
+ x = self.post_net(self.last_norm(x))
+ if self.is_BTC:
+ x = x.transpose(1, 2)
+ return x, skip_xs
+
+class UnetMid(nn.Module):
+ def __init__(self, hidden_size, kernel_size, n_layers=None, in_dims=None, out_dims=None,
+ dropout=0.0, is_BTC=True, net=None):
+ super(UnetMid, self).__init__()
+ in_dims = in_dims if in_dims is not None else hidden_size
+ out_dims = out_dims if out_dims is not None else hidden_size
+ self.pre = nn.Conv1d(in_dims, hidden_size, kernel_size, padding=kernel_size // 2)
+ self.post = nn.Conv1d(hidden_size, out_dims, kernel_size, padding=kernel_size // 2)
+ self.is_BTC = is_BTC
+ if net is not None:
+ self.net = net
+ else:
+ self.net = ConvBlocks(hidden_size, out_dims=hidden_size, dilations=None, kernel_size=kernel_size,
+ layers_in_block=2, c_multiple=2, dropout=dropout, num_layers=n_layers,
+ post_net_kernel=3, act_type='leakyrelu', is_BTC=is_BTC)
+
+ def forward(self, x, cond=None, **kwargs):
+ # x [B, T, C]
+ if self.is_BTC:
+ x = self.pre(x.transpose(1, 2)).transpose(1, 2)
+ else:
+ x = self.pre(x)
+ if cond is None:
+ cond = 0
+ x = self.net(x + cond)
+ if self.is_BTC:
+ x = self.post(x.transpose(1, 2)).transpose(1, 2)
+ else:
+ x = self.post(x)
+ return x
+
+class UnetUp(nn.Module):
+ def __init__(self, hidden_size, n_layers, kernel_size, up_rates, channel_multiples=None, dropout=0.0,
+ is_BTC=True, constant_channels=False, use_skip_layer=False, skip_scale=1.0):
+ super(UnetUp, self).__init__()
+ assert n_layers == len(up_rates) # this is reversed in up module, from the output to the interface with middle
+ up_rates = [int(i) for i in up_rates]
+ self.n_layers = n_layers
+ self.hidden_size = hidden_size
+ self.is_BTC = is_BTC
+ self.skip_scale = skip_scale
+ channel_multiples = channel_multiples if channel_multiples is not None else up_rates
+ # in_channels = int(np.cumprod(channel_multiples)[-1] * hidden_size) if not constant_channels else hidden_size
+ self.in_channels_lst = (np.cumprod([1] + channel_multiples) * hidden_size).astype(int) if not constant_channels \
+ else [hidden_size for _ in range(self.n_layers + 1)]
+ in_channels = self.in_channels_lst[-1]
+ self.ups = nn.ModuleList()
+ self.skip_layers = nn.ModuleList()
+ self.layers = nn.ModuleList()
+ for i in range(self.n_layers-1, -1, -1):
+ out_channels = self.in_channels_lst[i] if not constant_channels else in_channels
+ self.ups.append(nn.Sequential(
+ nn.ConvTranspose1d(in_channels, in_channels, kernel_size=kernel_size, stride=up_rates[i],
+ padding=kernel_size//2, output_padding=up_rates[i]-1),
+ get_norm_builder('ln', in_channels)(),
+ get_act_builder('leakyrelu')()
+ ))
+ self.layers.append(nn.Sequential(
+ # ResidualBlock(in_channels*2, kernel_size, dilation=1, n=1, norm_type='ln', dropout=dropout,
+ # c_multiple=1, ln_eps=1e-5, act_type='leakyrelu'),
+ nn.Conv1d(in_channels*2, out_channels, kernel_size, padding=(kernel_size - 1) // 2),
+ ResidualBlock(out_channels, kernel_size, dilation=1, n=1, norm_type='ln',
+ dropout=dropout, c_multiple=1, ln_eps=1e-5, act_type='leakyrelu')
+ ))
+ if use_skip_layer:
+ self.skip_layers.append(
+ ResidualBlock(in_channels, kernel_size, dilation=1, n=1, norm_type='ln', dropout=dropout,
+ c_multiple=1, ln_eps=1e-5, act_type='leakyrelu')
+ )
+ else:
+ self.skip_layers.append(nn.Identity())
+
+ in_channels = out_channels
+ self.out_channels = out_channels
+ self.last_norm = get_norm_builder('ln', out_channels)()
+ self.post_net = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size,
+ padding=kernel_size // 2)
+
+ def forward(self, x, skips, **kwargs):
+ # x [B, T, C]
+ if self.is_BTC:
+ x = x.transpose(1, 2) # [B, C, T]
+ for i in range(self.n_layers):
+ x = self.ups[i](x)
+ skip_x = skips[self.n_layers - i - 1] if not self.is_BTC \
+ else skips[self.n_layers - i - 1].transpose(1, 2) # [B, T, C] -> [B, C, T]
+ skip_x = self.skip_layers[i](skip_x) * self.skip_scale
+ x = torch.cat((x, skip_x), dim=1) # [B, C, T]
+ x = self.layers[i](x)
+ x = self.post_net(self.last_norm(x))
+ if self.is_BTC:
+ x = x.transpose(1, 2)
+ return x
+
+class Unet(nn.Module):
+ def __init__(self, hidden_size, down_layers, up_layers, kernel_size,
+ updown_rates, mid_layers=None, channel_multiples=None, dropout=0.0,
+ is_BTC=True, constant_channels=False, mid_net=None, use_skip_layer=False, skip_scale=1.0):
+ super(Unet, self).__init__()
+ assert len(updown_rates) == down_layers == up_layers, f"{len(updown_rates)}, {down_layers}, {up_layers}"
+ if channel_multiples is not None:
+ assert len(channel_multiples) == len(updown_rates)
+ else:
+ channel_multiples = updown_rates
+ self.down = UnetDown(hidden_size, down_layers, kernel_size, updown_rates,
+ channel_multiples, dropout, is_BTC, constant_channels)
+ down_out_dims = int(np.cumprod(channel_multiples)[-1] * hidden_size) if not constant_channels else hidden_size
+ self.mid = UnetMid(hidden_size, kernel_size, mid_layers,
+ in_dims=down_out_dims, out_dims=down_out_dims, dropout=dropout, is_BTC=is_BTC, net=mid_net)
+ self.up = UnetUp(hidden_size, up_layers, kernel_size, updown_rates,
+ channel_multiples, dropout, is_BTC, constant_channels, use_skip_layer, skip_scale)
+
+ def forward(self, x, mid_cond=None, **kwargs):
+ x, skips = self.down(x)
+ x = self.mid(x, mid_cond)
+ x = self.up(x, skips)
+ return x
diff --git a/preprocess/tools/note_transcription/utils/__init__.py b/preprocess/tools/note_transcription/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6730536de2088d8559c947eac60001764ed2e89
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/__init__.py
@@ -0,0 +1,15 @@
+
+def seed_everything(seed: int, seed_cudnn=False):
+ import random, os
+ import numpy as np
+ import torch
+
+ random.seed(seed)
+ os.environ['PYTHONHASHSEED'] = str(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ if seed_cudnn:
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = True
+
diff --git a/preprocess/tools/note_transcription/utils/audio/__init__.py b/preprocess/tools/note_transcription/utils/audio/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdaae23d8239b4d695a1cf690d4f70df6c352d40
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/audio/__init__.py
@@ -0,0 +1,100 @@
+import librosa
+import numpy as np
+import wave
+import soundfile as sf
+
+
+def librosa_pad_lr(x, fsize, fshift, pad_sides=1):
+ '''compute right padding (final frame) or both sides padding (first and final frames)
+ '''
+ assert pad_sides in (1, 2)
+ # return int(fsize // 2)
+ pad = (x.shape[0] // fshift + 1) * fshift - x.shape[0]
+ if pad_sides == 1:
+ return 0, pad
+ else:
+ return pad // 2, pad // 2 + pad % 2
+
+
+def amp_to_db(x):
+ return 20 * np.log10(np.maximum(1e-5, x))
+
+
+def db_to_amp(x):
+ return 10.0 ** (x * 0.05)
+
+
+def normalize(S, min_level_db):
+ return (S - min_level_db) / -min_level_db
+
+
+def denormalize(D, min_level_db):
+ return (D * -min_level_db) + min_level_db
+
+
+def librosa_wav2spec(wav_path,
+ fft_size=1024,
+ hop_size=256,
+ win_length=1024,
+ window="hann",
+ num_mels=80,
+ fmin=80,
+ fmax=-1,
+ eps=1e-6,
+ sample_rate=22050,
+ loud_norm=False,
+ trim_long_sil=False):
+ import pyloudnorm as pyln
+ if isinstance(wav_path, str):
+ if trim_long_sil:
+ from .vad import trim_long_silences
+ wav, _, _ = trim_long_silences(wav_path, sample_rate)
+ else:
+ wav, _ = librosa.core.load(wav_path, sr=sample_rate)
+ else:
+ wav = wav_path
+ wav_orig = np.copy(wav)
+
+ if loud_norm:
+ meter = pyln.Meter(sample_rate) # create BS.1770 meter
+ loudness = meter.integrated_loudness(wav)
+ wav = pyln.normalize.loudness(wav, loudness, -22.0)
+ if np.abs(wav).max() > 1:
+ wav = wav / np.abs(wav).max()
+
+ # get amplitude spectrogram
+ x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size,
+ win_length=win_length, window=window, pad_mode="constant")
+ linear_spc = np.abs(x_stft) # (n_bins, T)
+
+ # get mel basis
+ fmin = 0 if fmin == -1 else fmin
+ fmax = sample_rate / 2 if fmax == -1 else fmax
+ mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
+
+ # calculate mel spec
+ mel = mel_basis @ linear_spc
+ mel = np.log10(np.maximum(eps, mel)) # (n_mel_bins, T)
+ l_pad, r_pad = librosa_pad_lr(wav, fft_size, hop_size, 1)
+ wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0)
+ wav = wav[:mel.shape[1] * hop_size]
+
+ # log linear spec
+ linear_spc = np.log10(np.maximum(eps, linear_spc))
+ return {'wav': wav, 'mel': mel.T, 'linear': linear_spc.T, 'mel_basis': mel_basis, 'wav_orig': wav_orig}
+
+def get_wav_num_frames(path, sr=None):
+ try:
+ with wave.open(path, 'rb') as f:
+ sr_ = f.getframerate()
+ if sr is None:
+ sr = sr_
+ return int(f.getnframes() / (sr_ / sr))
+ except wave.Error:
+ wav_file, sr_ = sf.read(path, dtype='float32')
+ if sr is None:
+ sr = sr_
+ return int(len(wav_file) / (sr_ / sr))
+ except:
+ wav_file, sr_ = librosa.core.load(path, sr=sr)
+ return len(wav_file)
diff --git a/preprocess/tools/note_transcription/utils/audio/align.py b/preprocess/tools/note_transcription/utils/audio/align.py
new file mode 100644
index 0000000000000000000000000000000000000000..4518deb2d1cd9415a710ac3dc0f349d1824a1d58
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/audio/align.py
@@ -0,0 +1,90 @@
+import re
+
+import torch
+import numpy as np
+
+from ..text.text_encoder import is_sil_phoneme
+
+
+def get_mel2ph(tg_fn, ph, mel, hop_size, audio_sample_rate, min_sil_duration=0):
+ from textgrid import TextGrid
+ ph_list = ph.split(" ")
+ itvs = TextGrid.fromFile(tg_fn)[1]
+ itvs_ = []
+ for i in range(len(itvs)):
+ if itvs[i].maxTime - itvs[i].minTime < min_sil_duration and i > 0 and is_sil_phoneme(itvs[i].mark):
+ itvs_[-1].maxTime = itvs[i].maxTime
+ else:
+ itvs_.append(itvs[i])
+ itvs.intervals = itvs_
+ itv_marks = [itv.mark for itv in itvs]
+ tg_len = len([x for x in itvs if not is_sil_phoneme(x.mark)])
+ ph_len = len([x for x in ph_list if not is_sil_phoneme(x)])
+ assert tg_len == ph_len, (tg_len, ph_len, itv_marks, ph_list, tg_fn)
+ mel2ph = np.zeros([mel.shape[0]], int)
+ i_itv = 0
+ i_ph = 0
+ while i_itv < len(itvs):
+ itv = itvs[i_itv]
+ ph = ph_list[i_ph]
+ itv_ph = itv.mark
+ start_frame = int(itv.minTime * audio_sample_rate / hop_size + 0.5)
+ end_frame = int(itv.maxTime * audio_sample_rate / hop_size + 0.5)
+ if is_sil_phoneme(itv_ph) and not is_sil_phoneme(ph):
+ mel2ph[start_frame:end_frame] = i_ph
+ i_itv += 1
+ elif not is_sil_phoneme(itv_ph) and is_sil_phoneme(ph):
+ i_ph += 1
+ else:
+ if not ((is_sil_phoneme(itv_ph) and is_sil_phoneme(ph)) \
+ or re.sub(r'\d+', '', itv_ph.lower()) == re.sub(r'\d+', '', ph.lower())):
+ print(f"| WARN: {tg_fn} phs are not same: ", itv_ph, ph, itv_marks, ph_list)
+ mel2ph[start_frame:end_frame] = i_ph + 1
+ i_ph += 1
+ i_itv += 1
+ mel2ph[-1] = mel2ph[-2]
+ assert not np.any(mel2ph == 0)
+ T_t = len(ph_list)
+ dur = mel2token_to_dur(mel2ph, T_t)
+ return mel2ph.tolist(), dur.tolist()
+
+
+def split_audio_by_mel2ph(audio, mel2ph, hop_size, audio_num_mel_bins):
+ if isinstance(audio, torch.Tensor):
+ audio = audio.numpy()
+ if isinstance(mel2ph, torch.Tensor):
+ mel2ph = mel2ph.numpy()
+ assert len(audio.shape) == 1, len(mel2ph.shape) == 1
+ split_locs = []
+ for i in range(1, len(mel2ph)):
+ if mel2ph[i] != mel2ph[i - 1]:
+ split_loc = i * hop_size
+ split_locs.append(split_loc)
+
+ new_audio = []
+ for i in range(len(split_locs) - 1):
+ new_audio.append(audio[split_locs[i]:split_locs[i + 1]])
+ new_audio.append(np.zeros([0.5 * audio_num_mel_bins]))
+ return np.concatenate(new_audio)
+
+
+def mel2token_to_dur(mel2token, T_txt=None, max_dur=None):
+ is_torch = isinstance(mel2token, torch.Tensor)
+ has_batch_dim = True
+ if not is_torch:
+ mel2token = torch.LongTensor(mel2token)
+ if T_txt is None:
+ T_txt = mel2token.max()
+ if len(mel2token.shape) == 1:
+ mel2token = mel2token[None, ...]
+ has_batch_dim = False
+ B, _ = mel2token.shape
+ dur = mel2token.new_zeros(B, T_txt + 1).scatter_add(1, mel2token, torch.ones_like(mel2token))
+ dur = dur[:, 1:]
+ if max_dur is not None:
+ dur = dur.clamp(max=max_dur)
+ if not is_torch:
+ dur = dur.numpy()
+ if not has_batch_dim:
+ dur = dur[0]
+ return dur
diff --git a/preprocess/tools/note_transcription/utils/audio/io.py b/preprocess/tools/note_transcription/utils/audio/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..34d5d20ae13e9aa481b1bc85117ad6539af8a624
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/audio/io.py
@@ -0,0 +1,22 @@
+import subprocess
+
+import numpy as np
+from scipy.io import wavfile
+
+
+def save_wav(wav, path, sr, norm=False):
+ if norm:
+ wav = wav / np.abs(wav).max()
+ wav = wav * 32767
+ wavfile.write(path[:-4] + '.wav', sr, wav.astype(np.int16))
+ if path[-4:] == '.mp3':
+ to_mp3(path[:-4])
+
+
+def to_mp3(out_path):
+ if out_path[-4:] == '.wav':
+ out_path = out_path[:-4]
+ subprocess.check_call(
+ f'ffmpeg -threads 1 -loglevel error -i "{out_path}.wav" -vn -b:a 192k -y -hide_banner -async 1 "{out_path}.mp3"',
+ shell=True, stdin=subprocess.PIPE)
+ subprocess.check_call(f'rm -f "{out_path}.wav"', shell=True)
diff --git a/preprocess/tools/note_transcription/utils/audio/mel.py b/preprocess/tools/note_transcription/utils/audio/mel.py
new file mode 100644
index 0000000000000000000000000000000000000000..37119319f544b9cc7203477c856d0d0d55648b23
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/audio/mel.py
@@ -0,0 +1,139 @@
+import math
+import numpy as np
+import torch
+import torch.utils.data
+from librosa.filters import mel as librosa_mel_fn
+from scipy.io.wavfile import read
+import torch
+import torch.nn as nn
+
+MAX_WAV_VALUE = 32768.0
+
+
+def load_wav(full_path):
+ sampling_rate, data = read(full_path)
+ return data, sampling_rate
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+ return np.log10(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+ return np.exp(x) / C
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+ return torch.log10(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+ return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+ output = dynamic_range_compression_torch(magnitudes)
+ return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+ output = dynamic_range_decompression_torch(magnitudes)
+ return output
+
+
+class MelNet(nn.Module):
+ def __init__(self, hparams, device='cpu') -> None:
+ super().__init__()
+ self.n_fft = hparams['fft_size']
+ self.num_mels = hparams['audio_num_mel_bins']
+ self.sampling_rate = hparams['audio_sample_rate']
+ self.hop_size = hparams['hop_size']
+ self.win_size = hparams['win_size']
+ self.fmin = hparams['fmin']
+ self.fmax = hparams['fmax']
+ self.device = device
+
+ mel = librosa_mel_fn(sr=self.sampling_rate, n_fft=self.n_fft, n_mels=self.num_mels, fmin=self.fmin,
+ fmax=self.fmax)
+ self.mel_basis = torch.from_numpy(mel).float().to(self.device)
+ self.hann_window = torch.hann_window(self.win_size).to(self.device)
+
+ def to(self, device, **kwagrs):
+ super().to(device=device, **kwagrs)
+ self.mel_basis = self.mel_basis.to(device)
+ self.hann_window = self.hann_window.to(device)
+ self.device = device
+
+ def forward(self, y, center=False, complex=False):
+ if isinstance(y, np.ndarray):
+ y = torch.FloatTensor(y)
+ if len(y.shape) == 1:
+ y = y.unsqueeze(0)
+ y = y.clamp(min=-1., max=1.).to(self.device)
+
+ pad_length = math.ceil(y.shape[1] / self.hop_size) * self.hop_size - y.shape[1]
+ y = torch.nn.functional.pad(y.unsqueeze(1),
+ [int((self.n_fft - self.hop_size) / 2),
+ int((self.n_fft - self.hop_size) / 2 + pad_length)],
+ mode='reflect')
+ y = y.squeeze(1)
+
+ spec = torch.stft(y, self.n_fft, hop_length=self.hop_size, win_length=self.win_size, window=self.hann_window,
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
+ if not complex:
+ spec = torch.view_as_real(spec)
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) # [B, n_fft, T]
+ spec = torch.matmul(self.mel_basis, spec)
+ spec = spectral_normalize_torch(spec)
+ spec = spec.transpose(1, 2) # [B, T, n_fft]
+ else:
+ B, C, T, _ = spec.shape
+ spec = spec.transpose(1, 2) # [B, T, n_fft, 2]
+ return spec
+
+
+## below can be used in one gpu, but not ddp
+mel_basis = {}
+hann_window = {}
+
+
+def mel_spectrogram(y, hparams, center=False, complex=False): # y should be a tensor with shape (b,wav_len)
+ # hop_size: 512 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate)
+ # win_size: 2048 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate)
+ # fmin: 55 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
+ # fmax: 10000 # To be increased/reduced depending on data.
+ # fft_size: 2048 # Extra window size is filled with 0 paddings to match this parameter
+ # n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax,
+ n_fft = hparams['fft_size']
+ num_mels = hparams['audio_num_mel_bins']
+ sampling_rate = hparams['audio_sample_rate']
+ hop_size = hparams['hop_size']
+ win_size = hparams['win_size']
+ fmin = hparams['fmin']
+ fmax = hparams['fmax']
+ if isinstance(y, np.ndarray):
+ y = torch.FloatTensor(y)
+ if len(y.shape) == 1:
+ y = y.unsqueeze(0)
+ y = y.clamp(min=-1., max=1.)
+ global mel_basis, hann_window
+ if fmax not in mel_basis:
+ mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
+ mel_basis[str(fmax) + '_' + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+ y = torch.nn.functional.pad(y.unsqueeze(1), [int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)],
+ mode='reflect')
+ y = y.squeeze(1)
+
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=complex)
+
+ if not complex:
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+ spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec)
+ spec = spectral_normalize_torch(spec)
+ else:
+ B, C, T, _ = spec.shape
+ spec = spec.transpose(1, 2) # [B, T, n_fft, 2]
+ return spec
diff --git a/preprocess/tools/note_transcription/utils/audio/pitch_extractors.py b/preprocess/tools/note_transcription/utils/audio/pitch_extractors.py
new file mode 100644
index 0000000000000000000000000000000000000000..34868e343085d204ef058f180ac249bbe0aa26db
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/audio/pitch_extractors.py
@@ -0,0 +1,60 @@
+import math
+import numpy as np
+
+PITCH_EXTRACTOR = {}
+
+
+def register_pitch_extractor(name):
+ def register_pitch_extractor_(cls):
+ PITCH_EXTRACTOR[name] = cls
+ return cls
+
+ return register_pitch_extractor_
+
+
+def get_pitch_extractor(name):
+ return PITCH_EXTRACTOR[name]
+
+
+def extract_pitch_simple(wav):
+ from ..commons.hparams import hparams
+ return extract_pitch(hparams['pitch_extractor'], wav,
+ hparams['hop_size'], hparams['audio_sample_rate'],
+ f0_min=hparams['f0_min'], f0_max=hparams['f0_max'])
+
+
+def extract_pitch(extractor_name, wav_data, hop_size, audio_sample_rate, f0_min=75, f0_max=800, **kwargs):
+ return get_pitch_extractor(extractor_name)(wav_data, hop_size, audio_sample_rate, f0_min, f0_max, **kwargs)
+
+
+@register_pitch_extractor('parselmouth')
+def parselmouth_pitch(wav_data, hop_size, audio_sample_rate, f0_min, f0_max,
+ voicing_threshold=0.6, *args, **kwargs):
+ import parselmouth
+ time_step = hop_size / audio_sample_rate * 1000
+ n_mel_frames = int(len(wav_data) // hop_size)
+ f0_pm = parselmouth.Sound(wav_data, audio_sample_rate).to_pitch_ac(
+ time_step=time_step / 1000, voicing_threshold=voicing_threshold,
+ pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
+ pad_size = (n_mel_frames - len(f0_pm) + 1) // 2
+ f0 = np.pad(f0_pm, [[pad_size, n_mel_frames - len(f0_pm) - pad_size]], mode='constant')
+ return f0
+
+@register_pitch_extractor('pyworld')
+def pyworld_pitch(wav_data, hop_size, audio_sample_rate, f0_min, f0_max,
+ voicing_threshold=0.6, *args, **kwargs):
+ import pyworld as pw
+ # f0, _ = pw.harvest(wav_data.astype(np.double), audio_sample_rate, f0_floor=f0_min, f0_ceil=f0_max,
+ # frame_period=hop_size * 1000 / audio_sample_rate)
+ f0, _ = pw.dio(wav_data.astype(np.double), audio_sample_rate, f0_floor=f0_min, f0_ceil=f0_max, frame_period=hop_size * 1000 / audio_sample_rate)
+ f0[f0 < f0_min] = 0.0
+ f0[f0 > f0_max] = 0.0
+ n_mel_frames = math.ceil(len(wav_data) / hop_size)
+ if n_mel_frames > len(f0):
+ pad_size = (n_mel_frames - len(f0) + 1) // 2
+ f0 = np.pad(f0, [[pad_size, n_mel_frames - len(f0) - pad_size]], mode='constant')
+ elif n_mel_frames < len(f0):
+ left_del = (len(f0) - n_mel_frames + 1) // 2
+ right_del = len(f0) - n_mel_frames - left_del
+ f0 = f0[left_del: (-right_del if right_del > 0 else len(f0))]
+ return f0
diff --git a/preprocess/tools/note_transcription/utils/audio/pitch_utils.py b/preprocess/tools/note_transcription/utils/audio/pitch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6dd8f84bbd2fc774e7604b9fcb1d4a06f88ded4c
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/audio/pitch_utils.py
@@ -0,0 +1,303 @@
+import numpy as np
+import torch
+import pretty_midi
+
+def to_lf0(f0):
+ f0[f0 < 1.0e-5] = 1.0e-6
+ lf0 = f0.log() if isinstance(f0, torch.Tensor) else np.log(f0)
+ lf0[f0 < 1.0e-5] = - 1.0E+10
+ return lf0
+
+
+def to_f0(lf0):
+ f0 = np.where(lf0 <= 0, 0.0, np.exp(lf0))
+ return f0.flatten()
+
+
+def f0_to_coarse(f0, f0_bin=256, f0_max=900.0, f0_min=50.0):
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
+ is_torch = isinstance(f0, torch.Tensor)
+ f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
+ f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
+
+ f0_mel[f0_mel <= 1] = 1
+ f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
+ f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(int)
+ assert f0_coarse.max() <= f0_bin-1 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min(), f0.min(), f0.max())
+ return f0_coarse
+
+
+def coarse_to_f0(f0_coarse, f0_bin=256, f0_max=900.0, f0_min=50.0):
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
+ uv = f0_coarse == 1
+ f0 = f0_mel_min + (f0_coarse - 1) * (f0_mel_max - f0_mel_min) / (f0_bin - 2)
+ f0 = ((f0 / 1127).exp() - 1) * 700
+ f0[uv] = 0
+ return f0
+
+
+def norm_f0(f0, uv, pitch_norm='log', f0_mean=400, f0_std=100):
+ is_torch = isinstance(f0, torch.Tensor)
+ if pitch_norm == 'standard':
+ f0 = (f0 - f0_mean) / f0_std
+ if pitch_norm == 'log':
+ f0 = torch.log2(f0 + 1e-8) if is_torch else np.log2(f0 + 1e-8)
+ if uv is not None:
+ f0[uv > 0] = 0
+ return f0
+
+
+def norm_interp_f0(f0, pitch_norm='log', f0_mean=None, f0_std=None):
+ is_torch = isinstance(f0, torch.Tensor)
+ if is_torch:
+ device = f0.device
+ f0 = f0.data.cpu().numpy()
+ uv = f0 == 0
+ f0 = norm_f0(f0, uv, pitch_norm, f0_mean, f0_std)
+ if sum(uv) == len(f0):
+ f0[uv] = 0
+ elif sum(uv) > 0:
+ f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
+ if is_torch:
+ uv = torch.FloatTensor(uv)
+ f0 = torch.FloatTensor(f0)
+ f0 = f0.to(device)
+ uv = uv.to(device)
+ return f0, uv
+
+
+def denorm_f0(f0, uv, pitch_norm='log', f0_mean=400, f0_std=100, pitch_padding=None, min=50, max=900):
+ is_torch = isinstance(f0, torch.Tensor)
+ if pitch_norm == 'standard':
+ f0 = f0 * f0_std + f0_mean
+ if pitch_norm == 'log':
+ f0 = 2 ** f0
+ f0 = f0.clamp(min=min, max=max) if is_torch else np.clip(f0, a_min=min, a_max=max)
+ if uv is not None:
+ f0[uv > 0] = 0
+ if pitch_padding is not None:
+ f0[pitch_padding] = 0
+ return f0
+
+def interp_f0(f0, uv=None):
+ if uv is None:
+ uv = f0 == 0
+ f0 = norm_f0(f0, uv)
+ if uv.any() and not uv.all():
+ f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
+ return denorm_f0(f0, uv=None), uv
+
+def resample_align_curve(points: np.ndarray, original_timestep: float, target_timestep: float, align_length=-1):
+ t_max = (len(points) - 1) * original_timestep
+ curve_interp = np.interp(
+ np.arange(0, t_max, target_timestep),
+ original_timestep * np.arange(len(points)),
+ points
+ ).astype(points.dtype)
+ if align_length > 0:
+ delta_l = align_length - len(curve_interp)
+ if delta_l < 0:
+ curve_interp = curve_interp[:align_length]
+ elif delta_l > 0:
+ curve_interp = np.concatenate((curve_interp, np.full(delta_l, fill_value=curve_interp[-1])), axis=0)
+ return curve_interp
+
+def midi_to_hz(midi):
+ if type(midi) == np.ndarray:
+ non_mask = midi == 0
+ freq_hz = 440.0 * 2.0 ** ((midi - 69.0) / 12.0)
+ freq_hz[non_mask] = 0
+ else:
+ freq_hz = 440.0 * 2.0 ** ((midi - 69.0) / 12.0)
+ return freq_hz
+
+def hz_to_midi(hz):
+ if type(hz) == torch.Tensor:
+ non_mask = hz == 0
+ midi = 69.0 + 12.0 * (torch.log2(hz) - torch.log2(torch.Tensor(440.0)))
+ midi[non_mask] = 0
+ elif type(hz) == np.ndarray:
+ non_mask = hz == 0
+ midi = 69.0 + 12.0 * (np.log2(hz) - np.log2(440.0))
+ midi[non_mask] = 0
+ else:
+ midi = 69.0 + 12.0 * (np.log2(hz) - np.log2(440.0))
+ if hz == 0:
+ midi = 0
+ return midi
+
+def boundary2Interval(bd):
+ # bd has a shape of [T] with T frames
+ is_torch = isinstance(bd, torch.Tensor)
+ if is_torch:
+ device = bd.device
+ bd = bd.data.cpu().numpy()
+ assert len(bd.shape) == 1
+ # force valid begin and end
+ # bd[0] = 0 # took care of in regulate_boundary()
+ # bd[-1] = 0
+ ret = np.zeros(shape=(bd.sum() + 1, 2), dtype=int)
+ ret_idx = 0
+ ret[0, 0] = 0
+ for i, u in enumerate(bd):
+ if i == 0:
+ continue
+ if u == 1:
+ ret[ret_idx, 1] = i
+ ret[ret_idx+1, 0] = i
+ ret_idx += 1
+ ret[-1, 1] = bd.shape[0] - 1
+ if is_torch:
+ ret = torch.LongTensor(ret).to(device)
+ return ret
+
+def validate_pitch_and_itv(notes, note_itv):
+ # notes [T]
+ # note_itv [T, 2]
+ assert notes.shape[0] == note_itv.shape[0]
+ res_notes = []
+ res_note_itv = []
+ for idx in range(notes.shape[0]):
+ pitch, itv = notes[idx], note_itv[idx]
+ if itv[0] >= itv[1]:
+ raise RuntimeError("The note duration should be positive")
+ if pitch == 0:
+ continue
+ res_notes.append(pitch)
+ res_note_itv.append([itv[0], itv[1]])
+ res_notes = np.array(res_notes)
+ res_note_itv = np.array(res_note_itv)
+ return res_notes, res_note_itv
+
+def save_midi(notes, note_itv, midi_path):
+ # notes [T]
+ # note_itv [T, 2]
+ notes, note_itv = validate_pitch_and_itv(notes, note_itv)
+ if notes.shape == (0,):
+ return None
+ assert notes.shape[0] == note_itv.shape[0]
+ piano_chord = pretty_midi.PrettyMIDI()
+ piano_program = pretty_midi.instrument_name_to_program('Acoustic Grand Piano')
+ piano = pretty_midi.Instrument(program=piano_program)
+ for idx in range(notes.shape[0]):
+ pitch, itv = notes[idx], note_itv[idx]
+ note = pretty_midi.Note(velocity=120, pitch=pitch, start=itv[0], end=itv[1])
+ piano.notes.append(note)
+ piano_chord.remove_invalid_notes()
+ piano_chord.instruments.append(piano)
+ piano_chord.write(midi_path)
+ return piano_chord
+
+def midi2NoteInterval(mid):
+ assert type(mid) == pretty_midi.PrettyMIDI
+ if len(mid.instruments) == 0 or len(mid.instruments[0].notes) == 0:
+ return None
+ ret = np.zeros(shape=(len(mid.instruments[0].notes), 2))
+ for i, note in enumerate(mid.instruments[0].notes):
+ ret[i, 0] = note.start
+ ret[i, 1] = note.end
+ return ret
+
+def midi2NotePitch(mid):
+ assert type(mid) == pretty_midi.PrettyMIDI
+ if len(mid.instruments) == 0 or len(mid.instruments[0].notes) == 0:
+ return None
+ ret = np.zeros(shape=len(mid.instruments[0].notes))
+ for i, note in enumerate(mid.instruments[0].notes):
+ ret[i] = note.pitch
+ return ret
+
+def midi_onset_eval(mid_gt, mid_pred):
+ import mir_eval
+ interval_true = midi2NoteInterval(mid_gt)
+ if interval_true is None:
+ raise RuntimeError('Midi ground truth is None')
+ interval_pred = midi2NoteInterval(mid_pred)
+ if interval_pred is None:
+ return 0, 0, 0
+ onset_p, onset_r, onset_f = mir_eval.transcription.onset_precision_recall_f1(
+ interval_true, interval_pred, onset_tolerance=0.05, strict=False, beta=1.0)
+ return onset_p, onset_r, onset_f
+
+def midi_offset_eval(mid_gt, mid_pred):
+ import mir_eval
+ interval_true = midi2NoteInterval(mid_gt)
+ if interval_true is None:
+ raise RuntimeError('Midi ground truth is None')
+ interval_pred = midi2NoteInterval(mid_pred)
+ if interval_pred is None:
+ return 0, 0, 0
+ offset_p, offset_r, offset_f = mir_eval.transcription.offset_precision_recall_f1(
+ interval_true, interval_pred, offset_ratio=0.2, offset_min_tolerance=0.05, strict=False, beta=1.0)
+ return offset_p, offset_r, offset_f
+
+def midi_pitch_eval(mid_gt, mid_pred, offset_ratio=0.2):
+ import mir_eval
+ interval_true = midi2NoteInterval(mid_gt)
+ pitch_true = midi_to_hz(midi2NotePitch(mid_gt))
+ if interval_true is None or pitch_true is None:
+ raise RuntimeError('Midi ground truth is None')
+ interval_pred = midi2NoteInterval(mid_pred)
+ pitch_pred = midi2NotePitch(mid_pred)
+ if interval_pred is None:
+ return 0, 0, 0, 0
+ if pitch_pred is None:
+ pitch_pred = np.zeros(interval_pred.shape[0])
+ pitch_pred = midi_to_hz(pitch_pred)
+ overlap_p, overlap_r, overlap_f, avg_overlap_ratio = mir_eval.transcription.precision_recall_f1_overlap(
+ interval_true, pitch_true, interval_pred, pitch_pred, onset_tolerance=0.05, pitch_tolerance=50.0,
+ offset_ratio=offset_ratio, offset_min_tolerance=0.05, strict=False, beta=1.0)
+ return overlap_p, overlap_r, overlap_f, avg_overlap_ratio
+
+def midi_COn_eval(mid_gt, mid_pred):
+ return midi_onset_eval(mid_gt, mid_pred)
+
+def midi_COnP_eval(mid_gt, mid_pred):
+ return midi_pitch_eval(mid_gt, mid_pred, offset_ratio=None)
+
+def midi_COnPOff_eval(mid_gt, mid_pred):
+ return midi_pitch_eval(mid_gt, mid_pred)
+
+def midi_melody_eval(mid_gt, mid_pred, hop_size=256, sample_rate=48000):
+ interval_true = midi2NoteInterval(mid_gt)
+ pitch_true = midi_to_hz(midi2NotePitch(mid_gt))
+ if interval_true is None or pitch_true is None:
+ raise RuntimeError('Midi ground truth is None')
+ interval_pred = midi2NoteInterval(mid_pred)
+ pitch_pred = midi2NotePitch(mid_pred)
+ if interval_pred is None:
+ return 0, 0, 0, 0
+ if pitch_pred is None:
+ pitch_pred = np.zeros(interval_pred.shape[0])
+ pitch_pred = midi_to_hz(pitch_pred)
+
+ vr, vfa, rpa, rca, oa = melody_eval_pitch_and_itv(
+ pitch_true, interval_true, pitch_pred, interval_pred, hop_size, sample_rate)
+
+ return vr, vfa, rpa, rca, oa
+
+def melody_eval_pitch_and_itv(pitch_true, interval_true, pitch_pred, interval_pred, hop_size=256, sample_rate=48000):
+ import mir_eval
+ t_gt = np.arange(0, interval_true[-1][1], hop_size / sample_rate)
+ freq_gt = np.zeros_like(t_gt)
+ for idx in range(len(pitch_true)):
+ freq_gt[min(len(freq_gt) - 1, round(interval_true[idx][0] * sample_rate / hop_size)): round(
+ interval_true[idx][1] * sample_rate / hop_size)] = pitch_true[idx]
+
+ t_pred = np.arange(0, interval_pred[-1][1], hop_size / sample_rate)
+ freq_pred = np.zeros_like(t_pred)
+ for idx in range(len(pitch_pred)):
+ freq_pred[min(len(freq_pred) - 1, round(interval_pred[idx][0] * sample_rate / hop_size)): round(
+ interval_pred[idx][1] * sample_rate / hop_size)] = pitch_pred[idx]
+
+ ref_voicing, ref_cent, est_voicing, est_cent = mir_eval.melody.to_cent_voicing(t_gt, freq_gt,
+ t_pred, freq_pred)
+ vr, vfa = mir_eval.melody.voicing_measures(ref_voicing,
+ est_voicing) # voicing recall, voicing false alarm
+ rpa = mir_eval.melody.raw_pitch_accuracy(ref_voicing, ref_cent, est_voicing, est_cent)
+ rca = mir_eval.melody.raw_chroma_accuracy(ref_voicing, ref_cent, est_voicing, est_cent)
+ oa = mir_eval.melody.overall_accuracy(ref_voicing, ref_cent, est_voicing, est_cent)
+
+ return vr, vfa, rpa, rca, oa
diff --git a/preprocess/tools/note_transcription/utils/audio/vad.py b/preprocess/tools/note_transcription/utils/audio/vad.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbe9c7a6417f234ae46e1754d6736b26e22b2427
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/audio/vad.py
@@ -0,0 +1,78 @@
+from skimage.transform import resize
+import struct
+import webrtcvad
+from scipy.ndimage.morphology import binary_dilation
+import librosa
+import numpy as np
+import pyloudnorm as pyln
+import warnings
+
+warnings.filterwarnings("ignore", message="Possible clipped samples in output")
+
+int16_max = (2 ** 15) - 1
+
+
+def trim_long_silences(path, sr=None, return_raw_wav=False, norm=True, vad_max_silence_length=12):
+ """
+ Ensures that segments without voice in the waveform remain no longer than a
+ threshold determined by the VAD parameters in params.py.
+ :param wav: the raw waveform as a numpy array of floats
+ :param vad_max_silence_length: Maximum number of consecutive silent frames a segment can have.
+ :return: the same waveform with silences trimmed away (length <= original wav length)
+ """
+
+ ## Voice Activation Detection
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
+ # This sets the granularity of the VAD. Should not need to be changed.
+ sampling_rate = 16000
+ wav_raw, sr = librosa.core.load(path, sr=sr)
+
+ if norm:
+ meter = pyln.Meter(sr) # create BS.1770 meter
+ loudness = meter.integrated_loudness(wav_raw)
+ wav_raw = pyln.normalize.loudness(wav_raw, loudness, -20.0)
+ if np.abs(wav_raw).max() > 1.0:
+ wav_raw = wav_raw / np.abs(wav_raw).max()
+
+ wav = librosa.resample(wav_raw, sr, sampling_rate, res_type='kaiser_best')
+
+ vad_window_length = 30 # In milliseconds
+ # Number of frames to average together when performing the moving average smoothing.
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
+ vad_moving_average_width = 8
+
+ # Compute the voice detection window size
+ samples_per_window = (vad_window_length * sampling_rate) // 1000
+
+ # Trim the end of the audio to have a multiple of the window size
+ wav = wav[:len(wav) - (len(wav) % samples_per_window)]
+
+ # Convert the float waveform to 16-bit mono PCM
+ pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
+
+ # Perform voice activation detection
+ voice_flags = []
+ vad = webrtcvad.Vad(mode=3)
+ for window_start in range(0, len(wav), samples_per_window):
+ window_end = window_start + samples_per_window
+ voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
+ sample_rate=sampling_rate))
+ voice_flags = np.array(voice_flags)
+
+ # Smooth the voice detection with a moving average
+ def moving_average(array, width):
+ array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
+ ret = np.cumsum(array_padded, dtype=float)
+ ret[width:] = ret[width:] - ret[:-width]
+ return ret[width - 1:] / width
+
+ audio_mask = moving_average(voice_flags, vad_moving_average_width)
+ audio_mask = np.round(audio_mask).astype(np.bool)
+
+ # Dilate the voiced regions
+ audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
+ audio_mask = np.repeat(audio_mask, samples_per_window)
+ audio_mask = resize(audio_mask, (len(wav_raw),)) > 0
+ if return_raw_wav:
+ return wav_raw, audio_mask, sr
+ return wav_raw[audio_mask], audio_mask, sr
diff --git a/preprocess/tools/note_transcription/utils/commons/base_task.py b/preprocess/tools/note_transcription/utils/commons/base_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..31753b6df657a4df6a3b5999a5201b28b4ee4ad5
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/commons/base_task.py
@@ -0,0 +1,235 @@
+import logging
+import os
+import random
+import subprocess
+import sys
+from datetime import datetime
+import numpy as np
+import torch.utils.data
+from torch import nn
+from torch.utils.tensorboard import SummaryWriter
+from .dataset_utils import data_loader
+from .hparams import hparams
+from .meters import AvgrageMeter
+from .tensor_utils import tensors_to_scalars
+from .trainer import Trainer
+
+torch.multiprocessing.set_sharing_strategy(os.getenv('TORCH_SHARE_STRATEGY', 'file_system'))
+
+log_format = '%(asctime)s %(message)s'
+logging.basicConfig(stream=sys.stdout, level=logging.INFO,
+ format=log_format, datefmt='%m/%d %I:%M:%S %p')
+
+
+class BaseTask(nn.Module):
+ def __init__(self, *args, **kwargs):
+ super(BaseTask, self).__init__()
+ self.current_epoch = 0
+ self.global_step = 0
+ self.trainer = None
+ self.use_ddp = False
+ self.gradient_clip_norm = hparams['clip_grad_norm']
+ self.gradient_clip_val = hparams.get('clip_grad_value', 0)
+ self.model = None
+ self.training_losses_meter = None
+ self.logger: SummaryWriter = None
+
+ ######################
+ # build model, dataloaders, optimizer, scheduler and tensorboard
+ ######################
+ def build_model(self):
+ raise NotImplementedError
+
+ @data_loader
+ def train_dataloader(self):
+ raise NotImplementedError
+
+ @data_loader
+ def test_dataloader(self):
+ raise NotImplementedError
+
+ @data_loader
+ def val_dataloader(self):
+ raise NotImplementedError
+
+ def build_scheduler(self, optimizer):
+ return None
+
+ def build_optimizer(self, model):
+ raise NotImplementedError
+
+ def configure_optimizers(self):
+ optm = self.build_optimizer(self.model)
+ self.scheduler = self.build_scheduler(optm)
+ if isinstance(optm, (list, tuple)):
+ return optm
+ return [optm]
+
+ def build_tensorboard(self, save_dir, name, **kwargs):
+ log_dir = os.path.join(save_dir, name)
+ os.makedirs(log_dir, exist_ok=True)
+ self.logger = SummaryWriter(log_dir=log_dir, **kwargs)
+
+ ######################
+ # training
+ ######################
+ def on_train_start(self):
+ pass
+
+ def on_train_end(self):
+ pass
+
+ def on_epoch_start(self):
+ self.training_losses_meter = {'total_loss': AvgrageMeter()}
+
+ def on_epoch_end(self):
+ loss_outputs = {k: round(v.avg, 4) for k, v in self.training_losses_meter.items()}
+ print(f"Epoch {self.current_epoch} ended. Steps: {self.global_step}. {loss_outputs}")
+
+ def _training_step(self, sample, batch_idx, optimizer_idx):
+ """
+
+ :param sample:
+ :param batch_idx:
+ :return: total loss: torch.Tensor, loss_log: dict
+ """
+ raise NotImplementedError
+
+ def training_step(self, sample, batch_idx, optimizer_idx=-1):
+ """
+
+ :param sample:
+ :param batch_idx:
+ :param optimizer_idx:
+ :return: {'loss': torch.Tensor, 'progress_bar': dict, 'tb_log': dict}
+ """
+ loss_ret = self._training_step(sample, batch_idx, optimizer_idx)
+ if loss_ret is None:
+ return {'loss': None}
+ total_loss, log_outputs = loss_ret
+ log_outputs = tensors_to_scalars(log_outputs)
+ for k, v in log_outputs.items():
+ if k not in self.training_losses_meter:
+ self.training_losses_meter[k] = AvgrageMeter()
+ if not np.isnan(v):
+ self.training_losses_meter[k].update(v)
+ self.training_losses_meter['total_loss'].update(total_loss.item())
+
+ if optimizer_idx >= 0:
+ log_outputs[f'lr_{optimizer_idx}'] = self.trainer.optimizers[optimizer_idx].param_groups[0]['lr']
+
+ progress_bar_log = log_outputs
+ tb_log = {f'tr/{k}': v for k, v in log_outputs.items()}
+ return {
+ 'loss': total_loss,
+ 'progress_bar': progress_bar_log,
+ 'tb_log': tb_log
+ }
+
+ def on_before_optimization(self, opt_idx):
+ if self.gradient_clip_norm > 0:
+ torch.nn.utils.clip_grad_norm_(self.parameters(), self.gradient_clip_norm)
+ if self.gradient_clip_val > 0:
+ torch.nn.utils.clip_grad_value_(self.parameters(), self.gradient_clip_val)
+
+ def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx):
+ if self.scheduler is not None:
+ # self.scheduler.step(self.global_step // hparams['accumulate_grad_batches'])
+ # the code above causes EPOCH_DEPRECATION_WARNING, changed it and changed the optimizer init with
+ # step_size divided by accumulate_grad_batches
+ self.scheduler.step()
+
+ ######################
+ # validation
+ ######################
+ def validation_start(self):
+ pass
+
+ def validation_step(self, sample, batch_idx):
+ """
+
+ :param sample:
+ :param batch_idx:
+ :return: output: {"losses": {...}, "total_loss": float, ...} or (total loss: torch.Tensor, loss_log: dict)
+ """
+ raise NotImplementedError
+
+ def validation_end(self, outputs):
+ """
+
+ :param outputs:
+ :return: loss_output: dict
+ """
+ all_losses_meter = {'total_loss': AvgrageMeter()}
+ for output in outputs:
+ if len(output) == 0 or output is None:
+ continue
+ if isinstance(output, dict):
+ assert 'losses' in output, 'Key "losses" should exist in validation output.'
+ n = output.pop('nsamples', 1)
+ losses = tensors_to_scalars(output['losses'])
+ total_loss = output.get('total_loss', sum(losses.values()))
+ else:
+ assert len(output) == 2, 'Validation output should only consist of two elements: (total_loss, losses)'
+ n = 1
+ total_loss, losses = output
+ losses = tensors_to_scalars(losses)
+ if isinstance(total_loss, torch.Tensor):
+ total_loss = total_loss.item()
+ for k, v in losses.items():
+ if k not in all_losses_meter:
+ all_losses_meter[k] = AvgrageMeter()
+ all_losses_meter[k].update(v, n)
+ all_losses_meter['total_loss'].update(total_loss, n)
+ loss_output = {k: round(v.avg, 4) for k, v in all_losses_meter.items()}
+ print(f"| Validation results@{self.global_step}: {loss_output}")
+ return {
+ 'tb_log': {f'val/{k}': v for k, v in loss_output.items()},
+ 'val_loss': loss_output['total_loss']
+ }
+
+ ######################
+ # testing
+ ######################
+ def test_start(self):
+ pass
+
+ def test_step(self, sample, batch_idx):
+ return self.validation_step(sample, batch_idx)
+
+ def test_end(self, outputs):
+ return self.validation_end(outputs)
+
+ ######################
+ # start training/testing
+ ######################
+ @classmethod
+ def start(cls):
+ os.environ['MASTER_PORT'] = str(random.randint(15000, 30000))
+ random.seed(hparams['seed'])
+ np.random.seed(hparams['seed'])
+ work_dir = hparams['work_dir']
+ trainer = Trainer(
+ work_dir=work_dir,
+ val_check_interval=hparams['val_check_interval'],
+ tb_log_interval=hparams['tb_log_interval'],
+ max_updates=hparams['max_updates'],
+ num_sanity_val_steps=hparams['num_sanity_val_steps'] if not hparams['validate'] else 10000,
+ accumulate_grad_batches=hparams['accumulate_grad_batches'],
+ print_nan_grads=hparams['print_nan_grads'],
+ resume_from_checkpoint=hparams.get('resume_from_checkpoint', 0),
+ amp=hparams['amp'],
+ monitor_key=hparams['valid_monitor_key'],
+ monitor_mode=hparams['valid_monitor_mode'],
+ num_ckpt_keep=hparams['num_ckpt_keep'],
+ save_best=hparams['save_best'],
+ seed=hparams['seed'],
+ debug=hparams['debug']
+ )
+ if not hparams['infer']: # train
+ trainer.fit(cls)
+ else:
+ trainer.test(cls)
+
+ def on_keyboard_interrupt(self):
+ pass
diff --git a/preprocess/tools/note_transcription/utils/commons/ckpt_utils.py b/preprocess/tools/note_transcription/utils/commons/ckpt_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..34d4b4831242cbc8d65f837d3bb68fd497395197
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/commons/ckpt_utils.py
@@ -0,0 +1,68 @@
+import glob
+import os
+import re
+import torch
+
+
+def get_last_checkpoint(work_dir, steps=None):
+ checkpoint = None
+ last_ckpt_path = None
+ ckpt_paths = get_all_ckpts(work_dir, steps)
+ if len(ckpt_paths) > 0:
+ last_ckpt_path = ckpt_paths[0]
+ checkpoint = torch.load(last_ckpt_path, map_location='cpu')
+ return checkpoint, last_ckpt_path
+
+
+def get_all_ckpts(work_dir, steps=None):
+ if steps is None:
+ ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt'
+ else:
+ ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt'
+ return sorted(glob.glob(ckpt_path_pattern),
+ key=lambda x: -int(re.findall('.*steps\_(\d+)\.ckpt', x)[0]))
+
+
+def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True, verbose=True):
+ if os.path.isfile(ckpt_base_dir):
+ base_dir = os.path.dirname(ckpt_base_dir)
+ ckpt_path = ckpt_base_dir
+ checkpoint = torch.load(ckpt_base_dir, map_location='cpu')
+ else:
+ base_dir = ckpt_base_dir
+ checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir)
+ if checkpoint is not None:
+ state_dict = checkpoint["state_dict"]
+ if len([k for k in state_dict.keys() if '.' in k]) > 0:
+ state_dict = {k[len(model_name) + 1:]: v for k, v in state_dict.items()
+ if k.startswith(f'{model_name}.')}
+ else:
+ if '.' not in model_name:
+ state_dict = state_dict[model_name]
+ else:
+ base_model_name = model_name.split('.')[0]
+ rest_model_name = model_name[len(base_model_name) + 1:]
+ state_dict = {
+ k[len(rest_model_name) + 1:]: v for k, v in state_dict[base_model_name].items()
+ if k.startswith(f'{rest_model_name}.')}
+ if not strict:
+ cur_model_state_dict = cur_model.state_dict()
+ unmatched_keys = []
+ for key, param in state_dict.items():
+ if key in cur_model_state_dict:
+ new_param = cur_model_state_dict[key]
+ if new_param.shape != param.shape:
+ unmatched_keys.append(key)
+ print("| Unmatched keys: ", key, new_param.shape, param.shape)
+ for key in unmatched_keys:
+ del state_dict[key]
+ # print(state_dict)
+ cur_model.load_state_dict(state_dict, strict=strict)
+ if verbose:
+ print(f"| load '{model_name}' from '{ckpt_path}'.")
+ else:
+ e_msg = f"| ckpt not found in {base_dir}."
+ if force:
+ assert False, e_msg
+ else:
+ print(e_msg)
diff --git a/preprocess/tools/note_transcription/utils/commons/dataset_utils.py b/preprocess/tools/note_transcription/utils/commons/dataset_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc9fe7b29782d678e160ab0479060756f868d373
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/commons/dataset_utils.py
@@ -0,0 +1,372 @@
+import os
+import sys
+import traceback
+import types
+from functools import wraps
+from itertools import chain
+import numpy as np
+import torch.utils.data
+import torch.nn.functional as F
+from torch.utils.data import ConcatDataset
+from .hparams import hparams
+
+
+def collate_1d_or_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
+ if len(values[0].shape) == 1:
+ return collate_1d(values, pad_idx, left_pad, shift_right, max_len, shift_id)
+ else:
+ return collate_2d(values, pad_idx, left_pad, shift_right, max_len)
+
+
+def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
+ """Convert a list of 1d tensors into a padded 2d tensor."""
+ size = max(v.size(0) for v in values) if max_len is None else max_len
+ res = values[0].new(len(values), size).fill_(pad_idx)
+
+ def copy_tensor(src, dst):
+ assert dst.numel() == src.numel()
+ if shift_right:
+ dst[1:] = src[:-1]
+ dst[0] = shift_id
+ else:
+ dst.copy_(src)
+
+ for i, v in enumerate(values):
+ copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
+ return res
+
+
+def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None):
+ """Convert a list of 2d tensors into a padded 3d tensor."""
+ size = max(v.size(0) for v in values) if max_len is None else max_len
+ res = values[0].new(len(values), size, values[0].shape[1]).fill_(pad_idx)
+
+ def copy_tensor(src, dst):
+ assert dst.numel() == src.numel()
+ if shift_right:
+ dst[1:] = src[:-1]
+ else:
+ dst.copy_(src)
+
+ for i, v in enumerate(values):
+ copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
+ return res
+
+def collate_xd(values, pad_value=0, max_len=None):
+ size = ((max(v.size(0) for v in values) if max_len is None else max_len), *values[0].shape[1:])
+ res = torch.full((len(values), *size), fill_value=pad_value, dtype=values[0].dtype, device=values[0].device)
+
+ for i, v in enumerate(values):
+ res[i, :len(v), ...] = v
+ return res
+
+def pad_or_cut_1d(values: torch.tensor, tgt_len, pad_value=0):
+ src_len = values.shape[0]
+ if src_len < tgt_len:
+ res = F.pad(values, [0, tgt_len - src_len], value=pad_value)
+ else:
+ res = values[:tgt_len]
+ return res
+
+def pad_or_cut_2d(values: torch.tensor, tgt_len, dim=-1, pad_value=0):
+ if dim == 0 or dim == -2:
+ src_len = values.shape[0]
+ if src_len < tgt_len:
+ res = F.pad(values, [0, 0, 0, tgt_len - src_len], value=pad_value)
+ else:
+ res = values[:tgt_len]
+ elif dim == 1 or dim == -1:
+ src_len = values.shape[1]
+ if src_len < tgt_len:
+ res = F.pad(values, [0, tgt_len - src_len], value=pad_value)
+ else:
+ res = values[:, :tgt_len]
+ else:
+ raise RuntimeError(f"Wrong dim number {dim} while the tensor only has {len(values.shape)} dimensions.")
+ return res
+
+def pad_or_cut_3d(values: torch.tensor, tgt_len, dim=-1, pad_value=0):
+ if dim == 0 or dim == -3:
+ src_len = values.shape[0]
+ if src_len < tgt_len:
+ res = F.pad(values, [0, 0, 0, 0, 0, tgt_len - src_len], value=pad_value)
+ else:
+ res = values[:tgt_len]
+ elif dim == 1 or dim == -2:
+ src_len = values.shape[1]
+ if src_len < tgt_len:
+ res = F.pad(values, [0, 0, 0, tgt_len - src_len], value=pad_value)
+ else:
+ res = values[:, :tgt_len]
+ elif dim == 2 or dim == -1:
+ src_len = values.shape[2]
+ if src_len < tgt_len:
+ res = F.pad(values, [0, tgt_len - src_len], value=pad_value)
+ else:
+ res = values[:, :, :tgt_len]
+ else:
+ raise RuntimeError(f"Wrong dim number {dim} while the tensor only has {len(values.shape)} dimensions.")
+ return res
+
+def pad_or_cut_xd(values, tgt_len, dim=-1, pad_value=0):
+ if len(values.shape) == 1:
+ return pad_or_cut_1d(values, tgt_len, pad_value)
+ elif len(values.shape) == 2:
+ return pad_or_cut_2d(values, tgt_len, dim, pad_value)
+ elif len(values.shape) == 3:
+ return pad_or_cut_3d(values, tgt_len, dim, pad_value)
+ else:
+ raise NotImplementedError
+
+def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
+ if len(batch) == 0:
+ return 0
+ if len(batch) == max_sentences:
+ return 1
+ if num_tokens > max_tokens:
+ return 1
+ return 0
+
+
+def batch_by_size(
+ indices, num_tokens_fn, max_tokens=None, max_sentences=None,
+ required_batch_size_multiple=1, distributed=False
+):
+ """
+ Yield mini-batches of indices bucketed by size. Batches may contain
+ sequences of different lengths.
+
+ Args:
+ indices (List[int]): ordered list of dataset indices
+ num_tokens_fn (callable): function that returns the number of tokens at
+ a given index
+ max_tokens (int, optional): max number of tokens in each batch
+ (default: None).
+ max_sentences (int, optional): max number of sentences in each
+ batch (default: None).
+ required_batch_size_multiple (int, optional): require batch size to
+ be a multiple of N (default: 1).
+ """
+ max_tokens = max_tokens if max_tokens is not None else sys.maxsize
+ max_sentences = max_sentences if max_sentences is not None else sys.maxsize
+ bsz_mult = required_batch_size_multiple
+
+ if isinstance(indices, types.GeneratorType):
+ indices = np.fromiter(indices, dtype=np.int64, count=-1)
+
+ sample_len = 0
+ sample_lens = []
+ batch = []
+ batches = []
+ for i in range(len(indices)):
+ idx = indices[i]
+ num_tokens = num_tokens_fn(idx)
+ sample_lens.append(num_tokens)
+ sample_len = max(sample_len, num_tokens)
+
+ assert sample_len <= max_tokens, (
+ "sentence at index {} of size {} exceeds max_tokens "
+ "limit of {}!".format(idx, sample_len, max_tokens)
+ )
+ num_tokens = (len(batch) + 1) * sample_len
+
+ if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
+ mod_len = max(
+ bsz_mult * (len(batch) // bsz_mult),
+ len(batch) % bsz_mult,
+ )
+ batches.append(batch[:mod_len])
+ batch = batch[mod_len:]
+ sample_lens = sample_lens[mod_len:]
+ sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
+ batch.append(idx)
+ if len(batch) > 0:
+ batches.append(batch)
+ return batches
+
+
+def build_dataloader(dataset, shuffle, max_tokens=None, max_sentences=None,
+ required_batch_size_multiple=-1, endless=False, apply_batch_by_size=True, pin_memory=False, use_ddp=False):
+ import torch.distributed as dist
+ devices_cnt = torch.cuda.device_count()
+ if devices_cnt == 0:
+ devices_cnt = 1
+ if not use_ddp:
+ devices_cnt = 1
+ if required_batch_size_multiple == -1:
+ required_batch_size_multiple = devices_cnt
+
+ def shuffle_batches(batches):
+ np.random.shuffle(batches)
+ return batches
+
+ if max_tokens is not None:
+ max_tokens *= devices_cnt
+ if max_sentences is not None:
+ max_sentences *= devices_cnt
+ indices = dataset.ordered_indices()
+ if apply_batch_by_size:
+ batch_sampler = batch_by_size(
+ indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
+ required_batch_size_multiple=required_batch_size_multiple,
+ )
+ else:
+ batch_sampler = []
+ for i in range(0, len(indices), max_sentences):
+ batch_sampler.append(indices[i:i + max_sentences])
+
+ if shuffle:
+ batches = shuffle_batches(list(batch_sampler))
+ if endless:
+ batches = [b for _ in range(1000) for b in shuffle_batches(list(batch_sampler))]
+ else:
+ batches = batch_sampler
+ if endless:
+ batches = [b for _ in range(1000) for b in batches]
+ num_workers = dataset.num_workers
+ if use_ddp:
+ num_replicas = dist.get_world_size()
+ rank = dist.get_rank()
+ # batches = [x[rank::num_replicas] for x in batches if len(x) % num_replicas == 0]
+ # ensure that every sample in the dataset is covered
+ batches_ = []
+ for x in batches:
+ if len(x) % num_replicas == 0:
+ batches_.append(x[rank::num_replicas])
+ else:
+ x_ = x + [x[-1]] * (len(x) - len(x) // num_replicas * num_replicas)
+ batches_.append(x_[rank::num_replicas])
+ batches = batches_
+ return torch.utils.data.DataLoader(dataset,
+ collate_fn=dataset.collater,
+ batch_sampler=batches,
+ num_workers=num_workers,
+ pin_memory=pin_memory)
+
+
+def unpack_dict_to_list(samples):
+ samples_ = []
+ bsz = samples.get('outputs').size(0)
+ for i in range(bsz):
+ res = {}
+ for k, v in samples.items():
+ try:
+ res[k] = v[i]
+ except:
+ pass
+ samples_.append(res)
+ return samples_
+
+
+def remove_padding(x, padding_idx=0):
+ if x is None:
+ return None
+ assert len(x.shape) in [1, 2]
+ if len(x.shape) == 2: # [T, H]
+ return x[np.abs(x).sum(-1) != padding_idx]
+ elif len(x.shape) == 1: # [T]
+ return x[x != padding_idx]
+
+
+def data_loader(fn):
+ """
+ Decorator to make any fx with this use the lazy property
+ :param fn:
+ :return:
+ """
+
+ wraps(fn)
+ attr_name = '_lazy_' + fn.__name__
+
+ def _get_data_loader(self):
+ try:
+ value = getattr(self, attr_name)
+ except AttributeError:
+ try:
+ value = fn(self) # Lazy evaluation, done only once.
+ except AttributeError as e:
+ # Guard against AttributeError suppression. (Issue #142)
+ traceback.print_exc()
+ error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e)
+ raise RuntimeError(error) from e
+ setattr(self, attr_name, value) # Memoize evaluation.
+ return value
+
+ return _get_data_loader
+
+
+class BaseDataset(torch.utils.data.Dataset):
+ def __init__(self, shuffle):
+ super().__init__()
+ self.hparams = hparams
+ self.shuffle = shuffle
+ self.sort_by_len = hparams['sort_by_len']
+ self.sizes = None
+
+ @property
+ def _sizes(self):
+ return self.sizes
+
+ def __getitem__(self, index):
+ raise NotImplementedError
+
+ def collater(self, samples):
+ raise NotImplementedError
+
+ def __len__(self):
+ return len(self._sizes)
+
+ def num_tokens(self, index):
+ return self.size(index)
+
+ def size(self, index):
+ """Return an example's size as a float or tuple. This value is used when
+ filtering a dataset with ``--max-positions``."""
+ return min(self._sizes[index], hparams['max_frames'])
+
+ def ordered_indices(self):
+ """Return an ordered list of indices. Batches will be constructed based
+ on this order."""
+ if self.shuffle:
+ indices = np.random.permutation(len(self))
+ if self.sort_by_len:
+ indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')]
+ else:
+ indices = np.arange(len(self))
+ return indices.tolist()
+
+ @property
+ def num_workers(self):
+ return int(os.getenv('NUM_WORKERS', hparams['ds_workers']))
+
+
+class BaseConcatDataset(ConcatDataset):
+ def collater(self, samples):
+ return self.datasets[0].collater(samples)
+
+ @property
+ def _sizes(self):
+ if not hasattr(self, 'sizes'):
+ self.sizes = list(chain.from_iterable([d._sizes for d in self.datasets]))
+ return self.sizes
+
+ def size(self, index):
+ return min(self._sizes[index], hparams['max_frames'])
+
+ def num_tokens(self, index):
+ return self.size(index)
+
+ def ordered_indices(self):
+ """Return an ordered list of indices. Batches will be constructed based
+ on this order."""
+ if self.datasets[0].shuffle:
+ indices = np.random.permutation(len(self))
+ if self.datasets[0].sort_by_len:
+ indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')]
+ else:
+ indices = np.arange(len(self))
+ return indices
+
+ @property
+ def num_workers(self):
+ return self.datasets[0].num_workers
diff --git a/preprocess/tools/note_transcription/utils/commons/ddp_utils.py b/preprocess/tools/note_transcription/utils/commons/ddp_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b928569ed17a92209562a2405e6bdd7b2202b2b
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/commons/ddp_utils.py
@@ -0,0 +1,164 @@
+from torch.nn.parallel import DistributedDataParallel
+from torch.nn.parallel.distributed import _find_tensors
+import torch.optim
+import torch.utils.data
+import torch
+from packaging import version
+
+class DDP(DistributedDataParallel):
+ """
+ Override the forward call in lightning so it goes to training and validation step respectively
+ """
+
+ def forward(self, *inputs, **kwargs): # pragma: no cover
+ # if version.parse(torch.__version__[:6]) < version.parse("1.11"):
+ if version.parse(torch.__version__) < version.parse("1.11"): # fix the hard [:6] problem
+ self._sync_params()
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+ assert len(self.device_ids) == 1
+ if self.module.training:
+ output = self.module.training_step(*inputs[0], **kwargs[0])
+ elif self.module.testing:
+ output = self.module.test_step(*inputs[0], **kwargs[0])
+ else:
+ output = self.module.validation_step(*inputs[0], **kwargs[0])
+ if torch.is_grad_enabled():
+ # We'll return the output object verbatim since it is a freeform
+ # object. We need to find any tensors in this object, though,
+ # because we need to figure out which parameters were used during
+ # this forward pass, to ensure we short circuit reduction for any
+ # unused parameters. Only if `find_unused_parameters` is set.
+ if self.find_unused_parameters:
+ self.reducer.prepare_for_backward(list(_find_tensors(output)))
+ else:
+ self.reducer.prepare_for_backward([])
+ elif version.parse("1.11") <= version.parse(torch.__version__) < version.parse("2.0"):
+ from torch.nn.parallel.distributed import \
+ logging, Join, _DDPSink, _tree_flatten_with_rref, _tree_unflatten_with_rref
+ with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
+ if torch.is_grad_enabled() and self.require_backward_grad_sync:
+ self.logger.set_runtime_stats_and_log()
+ self.num_iterations += 1
+ self.reducer.prepare_for_forward()
+
+ # Notify the join context that this process has not joined, if
+ # needed
+ work = Join.notify_join_context(self)
+ if work:
+ self.reducer._set_forward_pass_work_handle(
+ work, self._divide_by_initial_world_size
+ )
+
+ # Calling _rebuild_buckets before forward compuation,
+ # It may allocate new buckets before deallocating old buckets
+ # inside _rebuild_buckets. To save peak memory usage,
+ # call _rebuild_buckets before the peak memory usage increases
+ # during forward computation.
+ # This should be called only once during whole training period.
+ if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
+ logging.info("Reducer buckets have been rebuilt in this iteration.")
+ self._has_rebuilt_buckets = True
+
+ # sync params according to location (before/after forward) user
+ # specified as part of hook, if hook was specified.
+ buffer_hook_registered = hasattr(self, 'buffer_hook')
+ if self._check_sync_bufs_pre_fwd():
+ self._sync_buffers()
+
+ if self._join_config.enable:
+ # Notify joined ranks whether they should sync in backwards pass or not.
+ self._check_global_requires_backward_grad_sync(is_joined_rank=False)
+
+ # modified part
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+ if self.module.training:
+ output = self.module.training_step(*inputs[0], **kwargs[0])
+ elif self.module.testing:
+ output = self.module.test_step(*inputs[0], **kwargs[0])
+ else:
+ output = self.module.validation_step(*inputs[0], **kwargs[0])
+
+ # sync params according to location (before/after forward) user
+ # specified as part of hook, if hook was specified.
+ if self._check_sync_bufs_post_fwd():
+ self._sync_buffers()
+
+ if torch.is_grad_enabled() and self.require_backward_grad_sync:
+ self.require_forward_param_sync = True
+ # We'll return the output object verbatim since it is a freeform
+ # object. We need to find any tensors in this object, though,
+ # because we need to figure out which parameters were used during
+ # this forward pass, to ensure we short circuit reduction for any
+ # unused parameters. Only if `find_unused_parameters` is set.
+ if self.find_unused_parameters and not self.static_graph:
+ # Do not need to populate this for static graph.
+ self.reducer.prepare_for_backward(list(_find_tensors(output)))
+ else:
+ self.reducer.prepare_for_backward([])
+ else:
+ self.require_forward_param_sync = False
+
+ # TODO: DDPSink is currently enabled for unused parameter detection and
+ # static graph training for first iteration.
+ if (self.find_unused_parameters and not self.static_graph) or (
+ self.static_graph and self.num_iterations == 1
+ ):
+ state_dict = {
+ 'static_graph': self.static_graph,
+ 'num_iterations': self.num_iterations,
+ }
+
+ output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(
+ output
+ )
+ output_placeholders = [None for _ in range(len(output_tensor_list))]
+ # Do not touch tensors that have no grad_fn, which can cause issues
+ # such as https://github.com/pytorch/pytorch/issues/60733
+ for i, output in enumerate(output_tensor_list):
+ if torch.is_tensor(output) and output.grad_fn is None:
+ output_placeholders[i] = output
+
+ # When find_unused_parameters=True, makes tensors which require grad
+ # run through the DDPSink backward pass. When not all outputs are
+ # used in loss, this makes those corresponding tensors receive
+ # undefined gradient which the reducer then handles to ensure
+ # param.grad field is not touched and we don't error out.
+ passthrough_tensor_list = _DDPSink.apply(
+ self.reducer,
+ state_dict,
+ *output_tensor_list,
+ )
+ for i in range(len(output_placeholders)):
+ if output_placeholders[i] is None:
+ output_placeholders[i] = passthrough_tensor_list[i]
+
+ # Reconstruct output data structure.
+ output = _tree_unflatten_with_rref(
+ output_placeholders, treespec, output_is_rref
+ )
+ else:
+ # now pytorch version >= 2.0
+ with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
+ inputs, kwargs = self._pre_forward(*inputs, **kwargs)
+ output = (
+ # self.module.forward(*inputs, **kwargs)
+ # if self._delay_all_reduce_all_params
+ # else self._run_ddp_forward(*inputs, **kwargs)
+ # modified: delete 'delay_all_reduce_named_params' function
+ self._run_ddp_forward(*inputs, **kwargs)
+ )
+ return self._post_forward(output)
+ return output
+
+ def _run_ddp_forward(self, *inputs, **kwargs):
+ if version.parse(torch.__version__) >= version.parse("2.0"):
+ with self._inside_ddp_forward():
+ if self.module.training:
+ output = self.module.training_step(*inputs, **kwargs)
+ elif self.module.testing:
+ output = self.module.test_step(*inputs, **kwargs)
+ else:
+ output = self.module.validation_step(*inputs, **kwargs)
+ return output # type: ignore[index]
+ else:
+ return super(DDP, self)._run_ddp_forward(*inputs, **kwargs)
diff --git a/preprocess/tools/note_transcription/utils/commons/gpu_mem_track.py b/preprocess/tools/note_transcription/utils/commons/gpu_mem_track.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b7fb04dee5aa4168e1b6e4cb887ab1b3c514f1c
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/commons/gpu_mem_track.py
@@ -0,0 +1,113 @@
+import gc
+import datetime
+import inspect
+
+import torch
+import numpy as np
+
+dtype_memory_size_dict = {
+ torch.float64: 64/8,
+ torch.double: 64/8,
+ torch.float32: 32/8,
+ torch.float: 32/8,
+ torch.float16: 16/8,
+ torch.half: 16/8,
+ torch.int64: 64/8,
+ torch.long: 64/8,
+ torch.int32: 32/8,
+ torch.int: 32/8,
+ torch.int16: 16/8,
+ torch.short: 16/6,
+ torch.uint8: 8/8,
+ torch.int8: 8/8,
+}
+# compatibility of torch1.0
+if getattr(torch, "bfloat16", None) is not None:
+ dtype_memory_size_dict[torch.bfloat16] = 16/8
+if getattr(torch, "bool", None) is not None:
+ dtype_memory_size_dict[torch.bool] = 8/8 # pytorch use 1 byte for a bool, see https://github.com/pytorch/pytorch/issues/41571
+
+def get_mem_space(x):
+ try:
+ ret = dtype_memory_size_dict[x]
+ except KeyError:
+ print(f"dtype {x} is not supported!")
+ return ret
+
+class MemTracker(object):
+ """
+ Class used to track pytorch memory usage
+ Arguments:
+ detail(bool, default True): whether the function shows the detail gpu memory usage
+ path(str): where to save log file
+ verbose(bool, default False): whether show the trivial exception
+ device(int): GPU number, default is 0
+ """
+ def __init__(self, detail=True, path='', verbose=False, device=0):
+ self.print_detail = detail
+ self.last_tensor_sizes = set()
+ self.gpu_profile_fn = path + f'{datetime.datetime.now():%d-%b-%y-%H:%M:%S}-gpu_mem_track.txt'
+ self.verbose = verbose
+ self.begin = True
+ self.device = device
+
+ def get_tensors(self):
+ for obj in gc.get_objects():
+ try:
+ if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
+ tensor = obj
+ else:
+ continue
+ if tensor.is_cuda:
+ yield tensor
+ except Exception as e:
+ if self.verbose:
+ print('A trivial exception occured: {}'.format(e))
+
+ def get_tensor_usage(self):
+ sizes = [np.prod(np.array(tensor.size())) * get_mem_space(tensor.dtype) for tensor in self.get_tensors()]
+ return np.sum(sizes) / 1024**2
+
+ def get_allocate_usage(self):
+ return torch.cuda.memory_allocated() / 1024**2
+
+ def clear_cache(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def print_all_gpu_tensor(self, file=None):
+ for x in self.get_tensors():
+ print(x.size(), x.dtype, np.prod(np.array(x.size()))*get_mem_space(x.dtype)/1024**2, file=file)
+
+ def track(self):
+ """
+ Track the GPU memory usage
+ """
+ frameinfo = inspect.stack()[1]
+ where_str = frameinfo.filename + ' line ' + str(frameinfo.lineno) + ': ' + frameinfo.function
+
+ with open(self.gpu_profile_fn, 'a+') as f:
+
+ if self.begin:
+ f.write(f"GPU Memory Track | {datetime.datetime.now():%d-%b-%y-%H:%M:%S} |"
+ f" Total Tensor Used Memory:{self.get_tensor_usage():<7.1f}Mb"
+ f" Total Allocated Memory:{self.get_allocate_usage():<7.1f}Mb\n\n")
+ self.begin = False
+
+ if self.print_detail is True:
+ ts_list = [(tensor.size(), tensor.dtype) for tensor in self.get_tensors()]
+ new_tensor_sizes = {(type(x),
+ tuple(x.size()),
+ ts_list.count((x.size(), x.dtype)),
+ np.prod(np.array(x.size()))*get_mem_space(x.dtype)/1024**2,
+ x.dtype) for x in self.get_tensors()}
+ for t, s, n, m, data_type in new_tensor_sizes - self.last_tensor_sizes:
+ f.write(f'+ | {str(n)} * Size:{str(s):<20} | Memory: {str(m*n)[:6]} M | {str(t):<20} | {data_type}\n')
+ for t, s, n, m, data_type in self.last_tensor_sizes - new_tensor_sizes:
+ f.write(f'- | {str(n)} * Size:{str(s):<20} | Memory: {str(m*n)[:6]} M | {str(t):<20} | {data_type}\n')
+
+ self.last_tensor_sizes = new_tensor_sizes
+
+ f.write(f"\nAt {where_str:<50}"
+ f" Total Tensor Used Memory:{self.get_tensor_usage():<7.1f}Mb"
+ f" Total Allocated Memory:{self.get_allocate_usage():<7.1f}Mb\n\n")
\ No newline at end of file
diff --git a/preprocess/tools/note_transcription/utils/commons/hparams.py b/preprocess/tools/note_transcription/utils/commons/hparams.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce5db5e69006856b81ca1cf04ccb6dd2f2bca203
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/commons/hparams.py
@@ -0,0 +1,131 @@
+import argparse
+import os
+import yaml
+
+global_print_hparams = True
+hparams = {}
+
+
+class Args:
+ def __init__(self, **kwargs):
+ for k, v in kwargs.items():
+ self.__setattr__(k, v)
+
+
+def override_config(old_config: dict, new_config: dict):
+ for k, v in new_config.items():
+ if isinstance(v, dict) and k in old_config:
+ override_config(old_config[k], new_config[k])
+ else:
+ old_config[k] = v
+
+
+def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True, root_dir=''):
+ if config == '' and exp_name == '':
+ parser = argparse.ArgumentParser(description='')
+ parser.add_argument('--config', type=str, default='',
+ help='location of the data corpus')
+ parser.add_argument('--exp_name', type=str, default='', help='exp_name')
+ parser.add_argument('-hp', '--hparams', type=str, default='',
+ help='location of the data corpus')
+ parser.add_argument('--infer', action='store_true', help='infer')
+ parser.add_argument('--validate', action='store_true', help='validate')
+ parser.add_argument('--reset', action='store_true', help='reset hparams')
+ parser.add_argument('--remove', action='store_true', help='remove old ckpt')
+ parser.add_argument('--debug', action='store_true', help='debug')
+ parser.add_argument('--root_dir', type=str, default='', help='root directory of the project.')
+ args, unknown = parser.parse_known_args()
+ print("| Unknow hparams: ", unknown)
+ else:
+ args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
+ infer=False, validate=False, reset=False, debug=False, remove=False, root_dir=root_dir)
+ global hparams
+ assert args.config != '' or args.exp_name != ''
+ root_dir = args.root_dir
+ if args.config != '':
+ assert os.path.exists(os.path.join(root_dir, args.config)), f'| Wrong config path! root_dir: {root_dir}, config_path: {args.config}'
+
+ config_chains = []
+ loaded_config = set()
+
+ def load_config(config_fn):
+ # deep first inheritance and avoid the second visit of one node
+ if not os.path.exists(os.path.join(root_dir, config_fn)):
+ return {}
+ with open(os.path.join(root_dir, config_fn)) as f:
+ hparams_ = yaml.safe_load(f)
+ loaded_config.add(config_fn)
+ if 'base_config' in hparams_:
+ ret_hparams = {}
+ if not isinstance(hparams_['base_config'], list):
+ hparams_['base_config'] = [hparams_['base_config']]
+ for c in hparams_['base_config']:
+ if c.startswith('.'):
+ c = f'{os.path.dirname(config_fn)}/{c}'
+ c = os.path.normpath(c)
+ if c not in loaded_config:
+ override_config(ret_hparams, load_config(c))
+ override_config(ret_hparams, hparams_)
+ else:
+ ret_hparams = hparams_
+ config_chains.append(config_fn)
+ return ret_hparams
+
+ saved_hparams = {}
+ args_work_dir = ''
+ if args.exp_name != '':
+ args_work_dir = os.path.join(root_dir, f'checkpoints/{args.exp_name}')
+ ckpt_config_path = f'{args_work_dir}/config.yaml'
+ if os.path.exists(ckpt_config_path):
+ with open(ckpt_config_path) as f:
+ saved_hparams_ = yaml.safe_load(f)
+ if saved_hparams_ is not None:
+ saved_hparams.update(saved_hparams_)
+ hparams_ = {}
+ if args.config != '':
+ hparams_.update(load_config(args.config))
+ if not args.reset:
+ hparams_.update(saved_hparams)
+ hparams_['work_dir'] = args_work_dir
+
+ # Support config overriding in command line. Support list type config overriding.
+ # Examples: --hparams="a=1,b.c=2,d=[1 1 1]"
+ if args.hparams != "":
+ for new_hparam in args.hparams.split(","):
+ k, v = new_hparam.split("=")
+ v = v.strip("\'\" ")
+ config_node = hparams_
+ for k_ in k.split(".")[:-1]:
+ config_node = config_node[k_]
+ k = k.split(".")[-1]
+ if v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]:
+ if type(config_node[k]) == list:
+ v = v.replace(" ", ",")
+ config_node[k] = eval(v)
+ else:
+ config_node[k] = type(config_node[k])(v)
+ if args_work_dir != '' and args.remove:
+ answer = input("REMOVE old checkpoint? Y/N [Default: N]: ")
+ if answer.lower() == "y":
+ pass
+ if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
+ os.makedirs(hparams_['work_dir'], exist_ok=True)
+ with open(ckpt_config_path, 'w') as f:
+ yaml.safe_dump(hparams_, f)
+
+ hparams_['infer'] = args.infer
+ hparams_['debug'] = args.debug
+ hparams_['validate'] = args.validate
+ hparams_['exp_name'] = args.exp_name
+ global global_print_hparams
+ if global_hparams:
+ hparams.clear()
+ hparams.update(hparams_)
+ if print_hparams and global_print_hparams and global_hparams:
+ # print('| Hparams chains: ', config_chains)
+ # print('| Hparams: ')
+ # for i, (k, v) in enumerate(sorted(hparams_.items())):
+ # print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
+ # print("")
+ global_print_hparams = False
+ return hparams_
diff --git a/preprocess/tools/note_transcription/utils/commons/indexed_datasets.py b/preprocess/tools/note_transcription/utils/commons/indexed_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..e15632be30d6296a3c9aa80a1f351058003698b3
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/commons/indexed_datasets.py
@@ -0,0 +1,71 @@
+import pickle
+from copy import deepcopy
+
+import numpy as np
+
+
+class IndexedDataset:
+ def __init__(self, path, num_cache=1):
+ super().__init__()
+ self.path = path
+ self.data_file = None
+ self.data_offsets = np.load(f"{path}.idx", allow_pickle=True).item()['offsets']
+ self.data_file = open(f"{path}.data", 'rb', buffering=-1)
+ self.cache = []
+ self.num_cache = num_cache
+
+ def check_index(self, i):
+ if i < 0 or i >= len(self.data_offsets) - 1:
+ raise IndexError('index out of range')
+
+ def __del__(self):
+ if self.data_file:
+ self.data_file.close()
+
+ def __getitem__(self, i):
+ self.check_index(i)
+ if self.num_cache > 0:
+ for c in self.cache:
+ if c[0] == i:
+ return c[1]
+ self.data_file.seek(self.data_offsets[i])
+ b = self.data_file.read(self.data_offsets[i + 1] - self.data_offsets[i])
+ item = pickle.loads(b)
+ if self.num_cache > 0:
+ self.cache = [(i, deepcopy(item))] + self.cache[:-1]
+ return item
+
+ def __len__(self):
+ return len(self.data_offsets) - 1
+
+class IndexedDatasetBuilder:
+ def __init__(self, path):
+ self.path = path
+ self.out_file = open(f"{path}.data", 'wb')
+ self.byte_offsets = [0]
+
+ def add_item(self, item):
+ s = pickle.dumps(item)
+ bytes = self.out_file.write(s)
+ self.byte_offsets.append(self.byte_offsets[-1] + bytes)
+
+ def finalize(self):
+ self.out_file.close()
+ np.save(open(f"{self.path}.idx", 'wb'), {'offsets': self.byte_offsets})
+
+
+if __name__ == "__main__":
+ import random
+ from tqdm import tqdm
+ ds_path = '/tmp/indexed_ds_example'
+ size = 100
+ items = [{"a": np.random.normal(size=[10000, 10]),
+ "b": np.random.normal(size=[10000, 10])} for i in range(size)]
+ builder = IndexedDatasetBuilder(ds_path)
+ for i in tqdm(range(size)):
+ builder.add_item(items[i])
+ builder.finalize()
+ ds = IndexedDataset(ds_path)
+ for i in tqdm(range(10000)):
+ idx = random.randint(0, size - 1)
+ assert (ds[idx]['a'] == items[idx]['a']).all()
diff --git a/preprocess/tools/note_transcription/utils/commons/losses.py b/preprocess/tools/note_transcription/utils/commons/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..33a6f95aeceacbbf78cdbf15537d609396a8e93e
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/commons/losses.py
@@ -0,0 +1,53 @@
+import torch
+import torch.nn.functional as F
+
+def sigmoid_focal_loss(
+ inputs: torch.Tensor,
+ targets: torch.Tensor,
+ alpha: float = 0.25,
+ gamma: float = 2,
+ reduction: str = "none",
+) -> torch.Tensor:
+ """
+ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+
+ Args:
+ inputs (Tensor): A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ alpha (float): Weighting factor in range (0,1) to balance
+ positive vs negative examples or -1 for ignore. Default: ``0.25``.
+ gamma (float): Exponent of the modulating factor (1 - p_t) to
+ balance easy vs hard examples. Default: ``2``.
+ reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
+ ``'none'``: No reduction will be applied to the output.
+ ``'mean'``: The output will be averaged.
+ ``'sum'``: The output will be summed. Default: ``'none'``.
+ Returns:
+ Loss tensor with the reduction option applied.
+ """
+ # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
+ p = torch.sigmoid(inputs)
+ ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+ p_t = p * targets + (1 - p) * (1 - targets)
+ loss = ce_loss * ((1 - p_t) ** gamma)
+
+ if alpha >= 0: # decrease the importance of negative samples
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+ loss = alpha_t * loss
+
+ # Check reduction option and return loss accordingly
+ if reduction == "none":
+ pass
+ elif reduction == "mean":
+ loss = loss.mean()
+ elif reduction == "sum":
+ loss = loss.sum()
+ else:
+ raise ValueError(
+ f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
+ )
+ return loss
+
diff --git a/preprocess/tools/note_transcription/utils/commons/meters.py b/preprocess/tools/note_transcription/utils/commons/meters.py
new file mode 100644
index 0000000000000000000000000000000000000000..e38790e9f292ec843a820dad73c9795eb2ab8daa
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/commons/meters.py
@@ -0,0 +1,42 @@
+import time
+import torch
+
+
+class AvgrageMeter(object):
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.avg = 0
+ self.sum = 0
+ self.cnt = 0
+
+ def update(self, val, n=1):
+ self.sum += val * n
+ self.cnt += n
+ self.avg = self.sum / self.cnt
+
+
+class Timer:
+ timer_map = {}
+
+ def __init__(self, name, enable=False):
+ if name not in Timer.timer_map:
+ Timer.timer_map[name] = 0
+ self.name = name
+ self.enable = enable
+
+ def __enter__(self):
+ if self.enable:
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ self.t = time.time()
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if self.enable:
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ Timer.timer_map[self.name] += time.time() - self.t
+ if self.enable:
+ print(f'[Timer] {self.name}: {Timer.timer_map[self.name]}')
diff --git a/preprocess/tools/note_transcription/utils/commons/multiprocess_utils.py b/preprocess/tools/note_transcription/utils/commons/multiprocess_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..62cb37872f9c6b7a2b572d666fbd9c43ec00b2a9
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/commons/multiprocess_utils.py
@@ -0,0 +1,180 @@
+import os
+import traceback
+from functools import partial
+from tqdm import tqdm
+import torch
+
+def chunked_worker(worker_id, args_queue=None, results_queue=None, init_ctx_func=None):
+ ctx = init_ctx_func(worker_id) if init_ctx_func is not None else None
+ while True:
+ args = args_queue.get()
+ if args == '':
+ return
+ job_idx, map_func, arg = args
+ try:
+ map_func_ = partial(map_func, ctx=ctx) if ctx is not None else map_func
+ if isinstance(arg, dict):
+ res = map_func_(**arg)
+ elif isinstance(arg, (list, tuple)):
+ res = map_func_(*arg)
+ else:
+ res = map_func_(arg)
+ results_queue.put((job_idx, res))
+ except:
+ traceback.print_exc()
+ results_queue.put((job_idx, None))
+
+
+class MultiprocessManager:
+ def __init__(self, num_workers=None, init_ctx_func=None, multithread=False, queue_max=-1):
+ if multithread:
+ from multiprocessing.dummy import Queue, Process
+ else:
+ from multiprocessing import Queue, Process
+ if num_workers is None:
+ num_workers = int(os.getenv('N_PROC', os.cpu_count()))
+ self.num_workers = num_workers
+ self.results_queue = Queue(maxsize=-1)
+ self.jobs_pending = []
+ self.args_queue = Queue(maxsize=queue_max)
+ self.workers = []
+ self.total_jobs = 0
+ self.multithread = multithread
+ for i in range(num_workers):
+ if multithread:
+ p = Process(target=chunked_worker,
+ args=(i, self.args_queue, self.results_queue, init_ctx_func))
+ else:
+ p = Process(target=chunked_worker,
+ args=(i, self.args_queue, self.results_queue, init_ctx_func),
+ daemon=True)
+ self.workers.append(p)
+ p.start()
+
+ def add_job(self, func, args):
+ if not self.args_queue.full():
+ self.args_queue.put((self.total_jobs, func, args))
+ else:
+ self.jobs_pending.append((self.total_jobs, func, args))
+ self.total_jobs += 1
+
+ def get_results(self):
+ self.n_finished = 0
+ while self.n_finished < self.total_jobs:
+ while len(self.jobs_pending) > 0 and not self.args_queue.full():
+ self.args_queue.put(self.jobs_pending[0])
+ self.jobs_pending = self.jobs_pending[1:]
+ job_id, res = self.results_queue.get()
+ yield job_id, res
+ self.n_finished += 1
+ for w in range(self.num_workers):
+ self.args_queue.put("")
+ for w in self.workers:
+ w.join()
+
+ def close(self):
+ if not self.multithread:
+ for w in self.workers:
+ w.terminate()
+
+ def __len__(self):
+ return self.total_jobs
+
+
+def multiprocess_run_tqdm(map_func, args, num_workers=None, ordered=True, init_ctx_func=None,
+ multithread=False, queue_max=-1, desc=None):
+ for i, res in tqdm(
+ multiprocess_run(map_func, args, num_workers, ordered, init_ctx_func, multithread,
+ queue_max=queue_max),
+ total=len(args), desc=desc):
+ yield i, res
+
+
+def multiprocess_run(map_func, args, num_workers=None, ordered=True, init_ctx_func=None, multithread=False,
+ queue_max=-1):
+ """
+ Multiprocessing running chunked jobs.
+
+ Examples:
+ >>> for res in tqdm(multiprocess_run(job_func, args):
+ >>> print(res)
+
+ :param map_func:
+ :param args:
+ :param num_workers:
+ :param ordered:
+ :param init_ctx_func:
+ :param q_max_size:
+ :param multithread:
+ :return:
+ """
+ if num_workers is None:
+ num_workers = int(os.getenv('N_PROC', os.cpu_count()))
+ manager = MultiprocessManager(num_workers, init_ctx_func, multithread, queue_max=queue_max)
+ for arg in args:
+ manager.add_job(map_func, arg)
+ if ordered:
+ n_jobs = len(args)
+ results = ['' for _ in range(n_jobs)]
+ i_now = 0
+ for job_i, res in manager.get_results():
+ results[job_i] = res
+ while i_now < n_jobs and (not isinstance(results[i_now], str) or results[i_now] != ''):
+ yield i_now, results[i_now]
+ results[i_now] = None
+ i_now += 1
+ else:
+ for job_i, res in manager.get_results():
+ yield job_i, res
+ manager.close()
+
+# #### this is the old version of chunked_multiprocess_run
+def chunked_worker_old(worker_id, map_func, args, results_queue=None, init_ctx_func=None):
+ ctx = init_ctx_func(worker_id) if init_ctx_func is not None else None
+ for job_idx, arg in args:
+ try:
+ if not isinstance(arg, tuple) and not isinstance(arg, list):
+ arg = [arg]
+ if ctx is not None:
+ res = map_func(*arg, ctx=ctx)
+ else:
+ res = map_func(*arg)
+ results_queue.put((job_idx, res))
+ except:
+ traceback.print_exc()
+ results_queue.put((job_idx, None))
+
+def chunked_multiprocess_run(
+ map_func, args, num_workers=None, ordered=True,
+ init_ctx_func=None, q_max_size=1000, multithread=False):
+ if multithread:
+ from multiprocessing.dummy import Queue, Process
+ else:
+ from multiprocessing import Queue, Process
+ args = zip(range(len(args)), args)
+ args = list(args)
+ n_jobs = len(args)
+ if num_workers is None:
+ num_workers = int(os.getenv('N_PROC', os.cpu_count()))
+ results_queues = []
+ if ordered:
+ for i in range(num_workers):
+ results_queues.append(Queue(maxsize=q_max_size // num_workers))
+ else:
+ results_queue = Queue(maxsize=q_max_size)
+ for i in range(num_workers):
+ results_queues.append(results_queue)
+ workers = []
+ for i in range(num_workers):
+ args_worker = args[i::num_workers]
+ p = Process(target=chunked_worker_old, args=(
+ i, map_func, args_worker, results_queues[i], init_ctx_func), daemon=True)
+ workers.append(p)
+ p.start()
+ for n_finished in range(n_jobs):
+ results_queue = results_queues[n_finished % num_workers]
+ job_idx, res = results_queue.get()
+ assert job_idx == n_finished or not ordered, (job_idx, n_finished)
+ yield res
+ for w in workers:
+ w.join()
diff --git a/preprocess/tools/note_transcription/utils/commons/signal.py b/preprocess/tools/note_transcription/utils/commons/signal.py
new file mode 100644
index 0000000000000000000000000000000000000000..8499c6bed9d2dcad2f7bd6fb09a73eefb4ad8abb
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/commons/signal.py
@@ -0,0 +1,74 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+def get_filter_2d(kernel, kernel_size, channels, no_grad=True):
+ # Reshape to 2d depthwise convolutional weight
+ kernel = kernel.view(1, 1, kernel_size, kernel_size)
+ kernel = kernel.repeat(channels, 1, 1, 1)
+
+ filter = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, groups=channels,
+ bias=False, padding=kernel_size // 2)
+
+ filter.weight.data = kernel
+ if no_grad:
+ filter.weight.requires_grad = False
+
+ return filter
+
+def get_filter_1d(kernel, kernel_size, channels, no_grad=True):
+ kernel = kernel.view(1, 1, kernel_size)
+ kernel = kernel.repeat(channels, 1, 1)
+
+ filter = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, groups=channels,
+ bias=False, padding=kernel_size // 2)
+
+ filter.weight.data = kernel
+ if no_grad:
+ filter.weight.requires_grad = False
+
+ return filter
+
+def get_gaussian_kernel_2d(kernel_size, sigma):
+ # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
+ x_coord = torch.arange(kernel_size)
+ x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
+ y_grid = x_grid.t()
+ xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
+
+ mean = (kernel_size - 1) / 2.
+ variance = sigma ** 2.
+
+ # Calculate the 2-dimensional gaussian kernel which is
+ # the product of two gaussian distributions for two different
+ # variables (in this case called x and y)
+ gaussian_kernel = (1. / (2. * np.pi * variance)) * torch.exp(
+ -torch.sum((xy_grid - mean) ** 2., dim=-1) / (2 * variance))
+
+ # Make sure sum of values in gaussian kernel equals 1.
+ gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
+
+ return gaussian_kernel
+
+def get_gaussian_kernel_1d(kernel_size, sigma):
+ x_grid = torch.arange(kernel_size)
+ mean = (kernel_size - 1) / 2.
+ variance = sigma ** 2.
+ gaussian_kernel = (1. / ((2. * np.pi) ** 0.5 * sigma)) * torch.exp(-(x_grid - mean) ** 2. / (2 * variance))
+ gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
+ return gaussian_kernel
+
+def get_hann_kernel_1d(kernel_size, periodic=False):
+ # periodic=False gives symmetric kernel, otherwise equivalent to hann(kernel_size + 1)
+ return torch.hann_window(kernel_size, periodic)
+
+def get_triangle_kernel_1d(kernel_size):
+ kernel = torch.zeros(kernel_size)
+ for idx in range(kernel_size):
+ kernel[idx] = 1 - abs((idx - (kernel_size - 1) / 2) / ((kernel_size - 1) / 2))
+ return kernel
+
+def add_gaussian_noise(tensor, mean=0, std=1):
+ noise = torch.randn(tensor.size()) * std + mean
+ noisy_tensor = tensor + noise
+ return noisy_tensor
diff --git a/preprocess/tools/note_transcription/utils/commons/single_thread_env.py b/preprocess/tools/note_transcription/utils/commons/single_thread_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..849219afd2cddec2ec6d489f12f60a34994bfb80
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/commons/single_thread_env.py
@@ -0,0 +1,5 @@
+import os
+
+os.environ["OMP_NUM_THREADS"] = "1"
+os.environ['TF_NUM_INTEROP_THREADS'] = '1'
+os.environ['TF_NUM_INTRAOP_THREADS'] = '1'
diff --git a/preprocess/tools/note_transcription/utils/commons/tensor_utils.py b/preprocess/tools/note_transcription/utils/commons/tensor_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..be4b69a4f135b95fcf18618668ed909314f24871
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/commons/tensor_utils.py
@@ -0,0 +1,92 @@
+import torch
+import torch.distributed as dist
+
+
+def reduce_tensors(metrics):
+ new_metrics = {}
+ for k, v in metrics.items():
+ if isinstance(v, torch.Tensor):
+ dist.all_reduce(v)
+ v = v / dist.get_world_size()
+ if type(v) is dict:
+ v = reduce_tensors(v)
+ new_metrics[k] = v
+ return new_metrics
+
+
+def tensors_to_scalars(tensors):
+ if isinstance(tensors, torch.Tensor):
+ tensors = tensors.item()
+ return tensors
+ elif isinstance(tensors, dict):
+ new_tensors = {}
+ for k, v in tensors.items():
+ v = tensors_to_scalars(v)
+ new_tensors[k] = v
+ return new_tensors
+ elif isinstance(tensors, list):
+ return [tensors_to_scalars(v) for v in tensors]
+ else:
+ return tensors
+
+
+def tensors_to_np(tensors):
+ if isinstance(tensors, dict):
+ new_np = {}
+ for k, v in tensors.items():
+ if isinstance(v, torch.Tensor):
+ v = v.cpu().numpy()
+ if type(v) is dict:
+ v = tensors_to_np(v)
+ new_np[k] = v
+ elif isinstance(tensors, list):
+ new_np = []
+ for v in tensors:
+ if isinstance(v, torch.Tensor):
+ v = v.cpu().numpy()
+ if type(v) is dict:
+ v = tensors_to_np(v)
+ new_np.append(v)
+ elif isinstance(tensors, torch.Tensor):
+ v = tensors
+ if isinstance(v, torch.Tensor):
+ v = v.cpu().numpy()
+ if type(v) is dict:
+ v = tensors_to_np(v)
+ new_np = v
+ else:
+ raise Exception(f'tensors_to_np does not support type {type(tensors)}.')
+ return new_np
+
+
+def move_to_cpu(tensors):
+ ret = {}
+ for k, v in tensors.items():
+ if isinstance(v, torch.Tensor):
+ v = v.cpu()
+ if type(v) is dict:
+ v = move_to_cpu(v)
+ ret[k] = v
+ return ret
+
+
+def move_to_cuda(batch, gpu_id=0):
+ # base case: object can be directly moved using `cuda` or `to`
+ if callable(getattr(batch, 'cuda', None)):
+ return batch.cuda(gpu_id, non_blocking=True)
+ elif callable(getattr(batch, 'to', None)):
+ return batch.to(torch.device('cuda', gpu_id), non_blocking=True)
+ elif isinstance(batch, list):
+ for i, x in enumerate(batch):
+ batch[i] = move_to_cuda(x, gpu_id)
+ return batch
+ elif isinstance(batch, tuple):
+ batch = list(batch)
+ for i, x in enumerate(batch):
+ batch[i] = move_to_cuda(x, gpu_id)
+ return tuple(batch)
+ elif isinstance(batch, dict):
+ for k, v in batch.items():
+ batch[k] = move_to_cuda(v, gpu_id)
+ return batch
+ return batch
diff --git a/preprocess/tools/note_transcription/utils/commons/trainer.py b/preprocess/tools/note_transcription/utils/commons/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c67139a2f633fb3d85aa35f2807afea210ec1651
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/commons/trainer.py
@@ -0,0 +1,557 @@
+import random
+import subprocess
+import traceback
+from datetime import datetime
+
+from torch.cuda.amp import GradScaler, autocast
+import numpy as np
+import torch.optim
+import torch.utils.data
+import copy
+import logging
+import os
+import re
+import sys
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import tqdm
+
+from .ckpt_utils import get_last_checkpoint, get_all_ckpts
+from .ddp_utils import DDP
+from .hparams import hparams
+from .tensor_utils import move_to_cuda
+
+
+class Tee(object):
+ def __init__(self, name, mode):
+ self.file = open(name, mode)
+ self.stdout = sys.stdout
+ sys.stdout = self
+
+ def __del__(self):
+ sys.stdout = self.stdout
+ self.file.close()
+
+ def write(self, data):
+ self.file.write(data)
+ self.stdout.write(data)
+
+ def flush(self):
+ self.file.flush()
+
+
+class Trainer:
+ def __init__(
+ self,
+ work_dir,
+ default_save_path=None,
+ accumulate_grad_batches=1,
+ max_updates=160000,
+ print_nan_grads=False,
+ val_check_interval=2000,
+ num_sanity_val_steps=5,
+ amp=False,
+ # tb logger
+ log_save_interval=100,
+ tb_log_interval=10,
+ # checkpoint
+ monitor_key='val_loss',
+ monitor_mode='min',
+ num_ckpt_keep=5,
+ save_best=True,
+ resume_from_checkpoint=0,
+ seed=1234,
+ debug=False,
+ ):
+ os.makedirs(work_dir, exist_ok=True)
+ self.work_dir = work_dir
+ self.accumulate_grad_batches = accumulate_grad_batches
+ self.max_updates = max_updates
+ self.num_sanity_val_steps = num_sanity_val_steps
+ self.print_nan_grads = print_nan_grads
+ self.default_save_path = default_save_path
+ self.resume_from_checkpoint = resume_from_checkpoint if resume_from_checkpoint > 0 else None
+ self.seed = seed
+ self.debug = debug
+ # model and optm
+ self.task = None
+ self.optimizers = []
+
+ # trainer state
+ self.testing = False
+ self.global_step = 0
+ self.current_epoch = 0
+ self.total_batches = 0
+
+ # configure checkpoint
+ self.monitor_key = monitor_key
+ self.num_ckpt_keep = num_ckpt_keep
+ self.save_best = save_best
+ self.monitor_op = np.less if monitor_mode == 'min' else np.greater
+ self.best_val_results = np.Inf if monitor_mode == 'min' else -np.Inf
+ self.mode = 'min'
+
+ # allow int, string and gpu list
+ self.all_gpu_ids = [
+ int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",") if x != '']
+ self.num_gpus = len(self.all_gpu_ids)
+ self.on_gpu = self.num_gpus > 0
+ self.root_gpu = 0
+ logging.info(f'GPU available: {torch.cuda.is_available()}, GPU used: {self.all_gpu_ids}')
+ self.use_ddp = self.num_gpus > 1
+ self.proc_rank = 0
+ # Tensorboard logging
+ self.log_save_interval = log_save_interval
+ self.val_check_interval = val_check_interval
+ self.tb_log_interval = tb_log_interval
+ self.amp = amp
+ self.amp_scalar = GradScaler()
+
+ def test(self, task_cls):
+ self.testing = True
+ self.fit(task_cls)
+
+ def fit(self, task_cls):
+ if len(self.all_gpu_ids) > 1:
+ mp.spawn(self.ddp_run, nprocs=self.num_gpus, args=(task_cls, copy.deepcopy(hparams)))
+ else:
+ self.task = task_cls()
+ self.task.trainer = self
+ self.run_single_process(self.task)
+ return 1
+
+ def ddp_run(self, gpu_idx, task_cls, hparams_):
+ hparams.update(hparams_)
+ self.proc_rank = gpu_idx
+ self.init_ddp_connection(self.proc_rank, self.num_gpus)
+ if dist.get_rank() != 0 and not self.debug:
+ sys.stdout = open(os.devnull, "w")
+ sys.stderr = open(os.devnull, "w")
+ task = task_cls()
+ task.trainer = self
+ torch.cuda.set_device(gpu_idx)
+ self.root_gpu = gpu_idx
+ self.task = task
+ self.run_single_process(task)
+
+ def run_single_process(self, task):
+ """Sanity check a few things before starting actual training.
+
+ :param task:
+ """
+ # build model, optm and load checkpoint
+ if self.proc_rank == 0:
+ self.save_terminal_logs()
+ if not self.testing:
+ self.save_codes()
+
+ model = task.build_model()
+ if model is not None:
+ task.model = model
+ checkpoint, _ = get_last_checkpoint(self.work_dir, self.resume_from_checkpoint)
+ if checkpoint is not None:
+ self.restore_weights(checkpoint)
+ elif self.on_gpu:
+ task.cuda(self.root_gpu)
+ if not self.testing:
+ self.optimizers = task.configure_optimizers()
+ self.fisrt_epoch = True
+ if checkpoint is not None:
+ self.restore_opt_state(checkpoint)
+ del checkpoint
+ # clear cache after restore
+ if self.on_gpu:
+ torch.cuda.empty_cache()
+
+ if self.use_ddp:
+ self.task = self.configure_ddp(self.task)
+ dist.barrier()
+
+ task_ref = self.get_task_ref()
+ task_ref.trainer = self
+ task_ref.testing = self.testing
+ # link up experiment object
+ if self.proc_rank == 0:
+ task_ref.build_tensorboard(save_dir=self.work_dir, name='tb_logs')
+ else:
+ os.makedirs('tmp', exist_ok=True)
+ task_ref.build_tensorboard(save_dir='tmp', name='tb_tmp')
+ self.logger = task_ref.logger
+ try:
+ if self.testing:
+ self.run_evaluation(test=True)
+ else:
+ self.train()
+ except KeyboardInterrupt as e:
+ traceback.print_exc()
+ task_ref.on_keyboard_interrupt()
+
+ ####################
+ # valid and test
+ ####################
+ def run_evaluation(self, test=False):
+ eval_results = self.evaluate(self.task, test, tqdm_desc='Valid' if not test else 'test',
+ max_batches=hparams['eval_max_batches'])
+ if eval_results is not None and 'tb_log' in eval_results:
+ tb_log_output = eval_results['tb_log']
+ self.log_metrics_to_tb(tb_log_output)
+ if self.proc_rank == 0 and not test:
+ self.save_checkpoint(epoch=self.current_epoch, logs=eval_results)
+
+ def evaluate(self, task, test=False, tqdm_desc='Valid', max_batches=None):
+ if max_batches == -1:
+ max_batches = None
+ # enable eval mode
+ task.zero_grad()
+ task.eval()
+ torch.set_grad_enabled(False)
+
+ task_ref = self.get_task_ref()
+ if test:
+ ret = task_ref.test_start()
+ if ret == 'EXIT':
+ return
+ else:
+ task_ref.validation_start()
+ outputs = []
+ dataloader = task_ref.test_dataloader() if test else task_ref.val_dataloader()
+ pbar = tqdm.tqdm(dataloader, desc=tqdm_desc, total=max_batches, dynamic_ncols=True, unit='step',
+ disable=self.root_gpu > 0)
+ # give model a chance to do something with the outputs (and method defined)
+ for batch_idx, batch in enumerate(pbar):
+ if batch is None: # pragma: no cover
+ continue
+ # stop short when on fast_dev_run (sets max_batch=1)
+ if max_batches is not None and batch_idx >= max_batches:
+ break
+
+ # make dataloader_idx arg in validation_step optional
+ if self.on_gpu:
+ batch = move_to_cuda(batch, self.root_gpu)
+ args = [batch, batch_idx]
+ if self.use_ddp:
+ output = task(*args)
+ else:
+ if test:
+ output = task_ref.test_step(*args)
+ else:
+ output = task_ref.validation_step(*args)
+ # track outputs for collation
+ outputs.append(output)
+ # give model a chance to do something with the outputs (and method defined)
+ if test:
+ eval_results = task_ref.test_end(outputs)
+ else:
+ eval_results = task_ref.validation_end(outputs)
+ # enable train mode again
+ task.train()
+ torch.set_grad_enabled(True)
+ return eval_results
+
+ ####################
+ # train
+ ####################
+ def train(self):
+ task_ref = self.get_task_ref()
+ task_ref.on_train_start()
+ if self.num_sanity_val_steps > 0:
+ # run tiny validation (if validation defined) to make sure program won't crash during val
+ self.evaluate(self.task, False, 'Sanity Val', max_batches=self.num_sanity_val_steps)
+ # clear cache before training
+ if self.on_gpu:
+ torch.cuda.empty_cache()
+ dataloader = task_ref.train_dataloader()
+ epoch = self.current_epoch
+ # run all epochs
+ while True:
+ # set seed for distributed sampler (enables shuffling for each epoch)
+ if self.use_ddp and hasattr(dataloader.sampler, 'set_epoch'):
+ dataloader.sampler.set_epoch(epoch)
+ # update training progress in trainer and model
+ task_ref.current_epoch = epoch
+ self.current_epoch = epoch
+ # total batches includes multiple val checks
+ self.batch_loss_value = 0 # accumulated grads
+ # before epoch hook
+ task_ref.on_epoch_start()
+
+ # run epoch
+ train_pbar = tqdm.tqdm(dataloader, initial=self.global_step, total=float('inf'),
+ dynamic_ncols=True, unit='step', disable=self.root_gpu > 0)
+ for batch_idx, batch in enumerate(train_pbar):
+ if self.global_step % self.val_check_interval == 0 and not self.fisrt_epoch:
+ self.run_evaluation()
+ pbar_metrics, tb_metrics = self.run_training_batch(batch_idx, batch)
+ train_pbar.set_postfix(**pbar_metrics)
+ self.fisrt_epoch = False
+ # when metrics should be logged
+ if (self.global_step + 1) % self.tb_log_interval == 0:
+ # logs user requested information to logger
+ self.log_metrics_to_tb(tb_metrics)
+
+ self.global_step += 1
+ task_ref.global_step = self.global_step
+ if self.global_step > self.max_updates:
+ print("| Training end..")
+ break
+ # epoch end hook
+ task_ref.on_epoch_end()
+ epoch += 1
+ if self.global_step > self.max_updates:
+ break
+ task_ref.on_train_end()
+
+ def run_training_batch(self, batch_idx, batch):
+ if batch is None:
+ return {}
+ all_progress_bar_metrics = []
+ all_log_metrics = []
+ task_ref = self.get_task_ref()
+ for opt_idx, optimizer in enumerate(self.optimizers):
+ if optimizer is None:
+ continue
+ # make sure only the gradients of the current optimizer's paramaters are calculated
+ # in the training step to prevent dangling gradients in multiple-optimizer setup.
+ if len(self.optimizers) > 1:
+ for param in task_ref.parameters():
+ param.requires_grad = False
+ for group in optimizer.param_groups:
+ for param in group['params']:
+ param.requires_grad = True
+
+ # forward pass
+ with autocast(enabled=self.amp):
+ if self.on_gpu:
+ batch = move_to_cuda(copy.copy(batch), self.root_gpu)
+ args = [batch, batch_idx, opt_idx]
+ if self.use_ddp:
+ output = self.task(*args)
+ else:
+ output = task_ref.training_step(*args)
+ loss = output['loss']
+ if loss is None:
+ continue
+ progress_bar_metrics = output['progress_bar']
+ log_metrics = output['tb_log']
+ # accumulate loss
+ loss = loss / self.accumulate_grad_batches
+
+ # backward pass
+ if loss.requires_grad:
+ if self.amp:
+ self.amp_scalar.scale(loss).backward()
+ else:
+ loss.backward()
+
+ # track progress bar metrics
+ all_log_metrics.append(log_metrics)
+ all_progress_bar_metrics.append(progress_bar_metrics)
+
+ if loss is None:
+ continue
+
+ # nan grads
+ if self.print_nan_grads:
+ has_nan_grad = False
+ for name, param in task_ref.named_parameters():
+ if (param.grad is not None) and torch.isnan(param.grad.float()).any():
+ print("| NaN params: ", name, param, param.grad)
+ has_nan_grad = True
+ if has_nan_grad:
+ exit(0)
+
+ # gradient update with accumulated gradients
+ if (self.global_step + 1) % self.accumulate_grad_batches == 0:
+ task_ref.on_before_optimization(opt_idx)
+ if self.amp:
+ self.amp_scalar.step(optimizer)
+ self.amp_scalar.update()
+ else:
+ optimizer.step()
+ optimizer.zero_grad()
+ task_ref.on_after_optimization(self.current_epoch, batch_idx, optimizer, opt_idx)
+
+ # collapse all metrics into one dict
+ all_progress_bar_metrics = {k: v for d in all_progress_bar_metrics for k, v in d.items()}
+ all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}
+ return all_progress_bar_metrics, all_log_metrics
+
+ ####################
+ # load and save checkpoint
+ ####################
+ def restore_weights(self, checkpoint):
+ # load model state
+ task_ref = self.get_task_ref()
+
+ for k, v in checkpoint['state_dict'].items():
+ getattr(task_ref, k).load_state_dict(v)
+
+ if self.on_gpu:
+ task_ref.cuda(self.root_gpu)
+ # load training state (affects trainer only)
+ self.best_val_results = checkpoint['checkpoint_callback_best']
+ self.global_step = checkpoint['global_step']
+ self.current_epoch = checkpoint['epoch']
+ task_ref.global_step = self.global_step
+
+ # wait for all models to restore weights
+ if self.use_ddp:
+ # wait for all processes to catch up
+ dist.barrier()
+
+ def restore_opt_state(self, checkpoint):
+ if self.testing:
+ return
+ # restore the optimizers
+ optimizer_states = checkpoint['optimizer_states']
+ for optimizer, opt_state in zip(self.optimizers, optimizer_states):
+ if optimizer is None:
+ return
+ try:
+ optimizer.load_state_dict(opt_state)
+ # move optimizer to GPU 1 weight at a time
+ if self.on_gpu:
+ for state in optimizer.state.values():
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor):
+ state[k] = v.cuda(self.root_gpu)
+ except ValueError:
+ print("| WARMING: optimizer parameters not match !!!")
+ try:
+ if dist.is_initialized() and dist.get_rank() > 0:
+ return
+ except Exception as e:
+ print(e)
+ return
+ did_restore = True
+ return did_restore
+
+ def save_checkpoint(self, epoch, logs=None):
+ monitor_op = np.less
+ ckpt_path = f'{self.work_dir}/model_ckpt_steps_{self.global_step}.ckpt'
+ logging.info(f'Epoch {epoch:05d}@{self.global_step}: saving model to {ckpt_path}')
+ self._atomic_save(ckpt_path)
+ for old_ckpt in get_all_ckpts(self.work_dir)[self.num_ckpt_keep:]:
+ pass
+ current = None
+ if logs is not None and self.monitor_key in logs:
+ current = logs[self.monitor_key]
+ if current is not None and self.save_best:
+ if monitor_op(current, self.best_val_results):
+ best_filepath = f'{self.work_dir}/model_ckpt_best.pt'
+ self.best_val_results = current
+ logging.info(
+ f'Epoch {epoch:05d}@{self.global_step}: {self.monitor_key} reached {current:0.5f}. '
+ f'Saving model to {best_filepath}')
+ self._atomic_save(best_filepath)
+
+ def _atomic_save(self, filepath):
+ checkpoint = self.dump_checkpoint()
+ tmp_path = str(filepath) + ".part"
+ torch.save(checkpoint, tmp_path, _use_new_zipfile_serialization=False)
+ os.replace(tmp_path, filepath)
+
+ def dump_checkpoint(self):
+ checkpoint = {'epoch': self.current_epoch, 'global_step': self.global_step,
+ 'checkpoint_callback_best': self.best_val_results}
+ # save optimizers
+ optimizer_states = []
+ for i, optimizer in enumerate(self.optimizers):
+ if optimizer is not None:
+ optimizer_states.append(optimizer.state_dict())
+
+ checkpoint['optimizer_states'] = optimizer_states
+ task_ref = self.get_task_ref()
+ checkpoint['state_dict'] = {
+ k: v.state_dict() for k, v in task_ref.named_children() if len(list(v.parameters())) > 0}
+ return checkpoint
+
+ ####################
+ # DDP
+ ####################
+ def configure_ddp(self, task):
+ task = DDP(task, device_ids=[self.root_gpu], find_unused_parameters=hparams.get('find_unused_parameters', True))
+ random.seed(self.seed)
+ np.random.seed(self.seed)
+ return task
+
+ def init_ddp_connection(self, proc_rank, world_size):
+ root_node = '127.0.0.1'
+ root_node = self.resolve_root_node_address(root_node)
+ os.environ['MASTER_ADDR'] = root_node
+ dist.init_process_group('nccl', rank=proc_rank, world_size=world_size)
+
+ def resolve_root_node_address(self, root_node):
+ if '[' in root_node:
+ name = root_node.split('[')[0]
+ number = root_node.split(',')[0]
+ if '-' in number:
+ number = number.split('-')[0]
+ number = re.sub('[^0-9]', '', number)
+ root_node = name + number
+ return root_node
+
+ ####################
+ # utils
+ ####################
+ def get_task_ref(self):
+ from .base_task import BaseTask
+ task: BaseTask = self.task.module if isinstance(self.task, DDP) else self.task
+ return task
+
+ def log_metrics_to_tb(self, metrics, step=None):
+ """Logs the metric dict passed in.
+
+ :param metrics:
+ """
+ # turn all tensors to scalars
+ scalar_metrics = self.metrics_to_scalars(metrics)
+
+ step = step if step is not None else self.global_step
+ # log actual metrics
+ if self.proc_rank == 0:
+ self.log_metrics(self.logger, scalar_metrics, step=step)
+
+ @staticmethod
+ def log_metrics(logger, metrics, step=None):
+ for k, v in metrics.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ logger.add_scalar(k, v, step)
+
+ def metrics_to_scalars(self, metrics):
+ new_metrics = {}
+ for k, v in metrics.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+
+ if type(v) is dict:
+ v = self.metrics_to_scalars(v)
+
+ new_metrics[k] = v
+
+ return new_metrics
+
+ def save_terminal_logs(self):
+ t = datetime.now().strftime('%Y%m%d%H%M%S')
+ os.makedirs(f'{self.work_dir}/terminal_logs', exist_ok=True)
+ Tee(f'{self.work_dir}/terminal_logs/log_{t}.txt', 'w')
+
+ def save_codes(self):
+ if len(hparams['save_codes']) > 0:
+ t = datetime.now().strftime('%Y%m%d%H%M%S')
+ code_dir = f'{self.work_dir}/codes/{t}'
+ subprocess.check_call(f'mkdir -p "{code_dir}"', shell=True)
+ for c in hparams['save_codes']:
+ if os.path.exists(c):
+ subprocess.check_call(
+ f'rsync -aR '
+ f'--include="*.py" '
+ f'--include="*.yaml" '
+ f'--exclude="__pycache__" '
+ f'--include="*/" '
+ f'--exclude="*" '
+ f'"./{c}" "{code_dir}/"',
+ shell=True)
+ print(f"| Copied codes to {code_dir}.")
diff --git a/preprocess/tools/note_transcription/utils/metrics/diagonal_metrics.py b/preprocess/tools/note_transcription/utils/metrics/diagonal_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba9807c1a594b38632c4731391e2d4fa3289037b
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/metrics/diagonal_metrics.py
@@ -0,0 +1,74 @@
+import torch
+
+
+def get_focus_rate(attn, src_padding_mask=None, tgt_padding_mask=None):
+ '''
+ attn: bs x L_t x L_s
+ '''
+ if src_padding_mask is not None:
+ attn = attn * (1 - src_padding_mask.float())[:, None, :]
+
+ if tgt_padding_mask is not None:
+ attn = attn * (1 - tgt_padding_mask.float())[:, :, None]
+
+ focus_rate = attn.max(-1).values.sum(-1)
+ focus_rate = focus_rate / attn.sum(-1).sum(-1)
+ return focus_rate
+
+
+def get_phone_coverage_rate(attn, src_padding_mask=None, src_seg_mask=None, tgt_padding_mask=None):
+ '''
+ attn: bs x L_t x L_s
+ '''
+ src_mask = attn.new(attn.size(0), attn.size(-1)).bool().fill_(False)
+ if src_padding_mask is not None:
+ src_mask |= src_padding_mask
+ if src_seg_mask is not None:
+ src_mask |= src_seg_mask
+
+ attn = attn * (1 - src_mask.float())[:, None, :]
+ if tgt_padding_mask is not None:
+ attn = attn * (1 - tgt_padding_mask.float())[:, :, None]
+
+ phone_coverage_rate = attn.max(1).values.sum(-1)
+ # phone_coverage_rate = phone_coverage_rate / attn.sum(-1).sum(-1)
+ phone_coverage_rate = phone_coverage_rate / (1 - src_mask.float()).sum(-1)
+ return phone_coverage_rate
+
+
+def get_diagonal_focus_rate(attn, attn_ks, target_len, src_padding_mask=None, tgt_padding_mask=None,
+ band_mask_factor=5, band_width=50):
+ '''
+ attn: bx x L_t x L_s
+ attn_ks: shape: tensor with shape [batch_size], input_lens/output_lens
+
+ diagonal: y=k*x (k=attn_ks, x:output, y:input)
+ 1 0 0
+ 0 1 0
+ 0 0 1
+ y>=k*(x-width) and y<=k*(x+width):1
+ else:0
+ '''
+ # width = min(target_len/band_mask_factor, 50)
+ width1 = target_len / band_mask_factor
+ width2 = target_len.new(target_len.size()).fill_(band_width)
+ width = torch.where(width1 < width2, width1, width2).float()
+ base = torch.ones(attn.size()).to(attn.device)
+ zero = torch.zeros(attn.size()).to(attn.device)
+ x = torch.arange(0, attn.size(1)).to(attn.device)[None, :, None].float() * base
+ y = torch.arange(0, attn.size(2)).to(attn.device)[None, None, :].float() * base
+ cond = (y - attn_ks[:, None, None] * x)
+ cond1 = cond + attn_ks[:, None, None] * width[:, None, None]
+ cond2 = cond - attn_ks[:, None, None] * width[:, None, None]
+ mask1 = torch.where(cond1 < 0, zero, base)
+ mask2 = torch.where(cond2 > 0, zero, base)
+ mask = mask1 * mask2
+
+ if src_padding_mask is not None:
+ attn = attn * (1 - src_padding_mask.float())[:, None, :]
+ if tgt_padding_mask is not None:
+ attn = attn * (1 - tgt_padding_mask.float())[:, :, None]
+
+ diagonal_attn = attn * mask
+ diagonal_focus_rate = diagonal_attn.sum(-1).sum(-1) / attn.sum(-1).sum(-1)
+ return diagonal_focus_rate, mask
diff --git a/preprocess/tools/note_transcription/utils/metrics/dtw.py b/preprocess/tools/note_transcription/utils/metrics/dtw.py
new file mode 100644
index 0000000000000000000000000000000000000000..829e8e160355f8729b8e478bc4a24ca8597df58e
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/metrics/dtw.py
@@ -0,0 +1,160 @@
+from numpy import array, zeros, full, argmin, inf, ndim
+from scipy.spatial.distance import cdist
+from math import isinf
+
+
+def dtw(x, y, dist, warp=1, w=inf, s=1.0):
+ """
+ Computes Dynamic Time Warping (DTW) of two sequences.
+
+ :param array x: N1*M array
+ :param array y: N2*M array
+ :param func dist: distance used as cost measure
+ :param int warp: how many shifts are computed.
+ :param int w: window size limiting the maximal distance between indices of matched entries |i,j|.
+ :param float s: weight applied on off-diagonal moves of the path. As s gets larger, the warping path is increasingly biased towards the diagonal
+ Returns the minimum distance, the cost matrix, the accumulated cost matrix, and the wrap path.
+ """
+ assert len(x)
+ assert len(y)
+ assert isinf(w) or (w >= abs(len(x) - len(y)))
+ assert s > 0
+ r, c = len(x), len(y)
+ if not isinf(w):
+ D0 = full((r + 1, c + 1), inf)
+ for i in range(1, r + 1):
+ D0[i, max(1, i - w):min(c + 1, i + w + 1)] = 0
+ D0[0, 0] = 0
+ else:
+ D0 = zeros((r + 1, c + 1))
+ D0[0, 1:] = inf
+ D0[1:, 0] = inf
+ D1 = D0[1:, 1:] # view
+ for i in range(r):
+ for j in range(c):
+ if (isinf(w) or (max(0, i - w) <= j <= min(c, i + w))):
+ D1[i, j] = dist(x[i], y[j])
+ C = D1.copy()
+ jrange = range(c)
+ for i in range(r):
+ if not isinf(w):
+ jrange = range(max(0, i - w), min(c, i + w + 1))
+ for j in jrange:
+ min_list = [D0[i, j]]
+ for k in range(1, warp + 1):
+ i_k = min(i + k, r)
+ j_k = min(j + k, c)
+ min_list += [D0[i_k, j] * s, D0[i, j_k] * s]
+ D1[i, j] += min(min_list)
+ if len(x) == 1:
+ path = zeros(len(y)), range(len(y))
+ elif len(y) == 1:
+ path = range(len(x)), zeros(len(x))
+ else:
+ path = _traceback(D0)
+ return D1[-1, -1], C, D1, path
+
+
+def accelerated_dtw(x, y, dist, warp=1):
+ """
+ Computes Dynamic Time Warping (DTW) of two sequences in a faster way.
+ Instead of iterating through each element and calculating each distance,
+ this uses the cdist function from scipy (https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html)
+
+ :param array x: N1*M array
+ :param array y: N2*M array
+ :param string or func dist: distance parameter for cdist. When string is given, cdist uses optimized functions for the distance metrics.
+ If a string is passed, the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'.
+ :param int warp: how many shifts are computed.
+ Returns the minimum distance, the cost matrix, the accumulated cost matrix, and the wrap path.
+ """
+ assert len(x)
+ assert len(y)
+ if ndim(x) == 1:
+ x = x.reshape(-1, 1)
+ if ndim(y) == 1:
+ y = y.reshape(-1, 1)
+ r, c = len(x), len(y)
+ D0 = zeros((r + 1, c + 1))
+ D0[0, 1:] = inf
+ D0[1:, 0] = inf
+ D1 = D0[1:, 1:]
+ D0[1:, 1:] = cdist(x, y, dist)
+ C = D1.copy()
+ for i in range(r):
+ for j in range(c):
+ min_list = [D0[i, j]]
+ for k in range(1, warp + 1):
+ min_list += [D0[min(i + k, r), j],
+ D0[i, min(j + k, c)]]
+ D1[i, j] += min(min_list)
+ if len(x) == 1:
+ path = zeros(len(y)), range(len(y))
+ elif len(y) == 1:
+ path = range(len(x)), zeros(len(x))
+ else:
+ path = _traceback(D0)
+ return D1[-1, -1], C, D1, path
+
+
+def _traceback(D):
+ i, j = array(D.shape) - 2
+ p, q = [i], [j]
+ while (i > 0) or (j > 0):
+ tb = argmin((D[i, j], D[i, j + 1], D[i + 1, j]))
+ if tb == 0:
+ i -= 1
+ j -= 1
+ elif tb == 1:
+ i -= 1
+ else: # (tb == 2):
+ j -= 1
+ p.insert(0, i)
+ q.insert(0, j)
+ return array(p), array(q)
+
+
+if __name__ == '__main__':
+ w = inf
+ s = 1.0
+ if 1: # 1-D numeric
+ from sklearn.metrics.pairwise import manhattan_distances
+
+ x = [0, 0, 1, 1, 2, 4, 2, 1, 2, 0]
+ y = [1, 1, 1, 2, 2, 2, 2, 3, 2, 0]
+ dist_fun = manhattan_distances
+ w = 1
+ # s = 1.2
+ elif 0: # 2-D numeric
+ from sklearn.metrics.pairwise import euclidean_distances
+
+ x = [[0, 0], [0, 1], [1, 1], [1, 2], [2, 2], [4, 3], [2, 3], [1, 1], [2, 2], [0, 1]]
+ y = [[1, 0], [1, 1], [1, 1], [2, 1], [4, 3], [4, 3], [2, 3], [3, 1], [1, 2], [1, 0]]
+ dist_fun = euclidean_distances
+ else: # 1-D list of strings
+ from nltk.metrics.distance import edit_distance
+
+ # x = ['we', 'shelled', 'clams', 'for', 'the', 'chowder']
+ # y = ['class', 'too']
+ x = ['i', 'soon', 'found', 'myself', 'muttering', 'to', 'the', 'walls']
+ y = ['see', 'drown', 'himself']
+ # x = 'we talked about the situation'.split()
+ # y = 'we talked about the situation'.split()
+ dist_fun = edit_distance
+ dist, cost, acc, path = dtw(x, y, dist_fun, w=w, s=s)
+
+ # Vizualize
+ from matplotlib import pyplot as plt
+
+ plt.imshow(cost.T, origin='lower', cmap=plt.cm.Reds, interpolation='nearest')
+ plt.plot(path[0], path[1], '-o') # relation
+ plt.xticks(range(len(x)), x)
+ plt.yticks(range(len(y)), y)
+ plt.xlabel('x')
+ plt.ylabel('y')
+ plt.axis('tight')
+ if isinf(w):
+ plt.title('Minimum distance: {}, slope weight: {}'.format(dist, s))
+ else:
+ plt.title('Minimum distance: {}, window widht: {}, slope weight: {}'.format(dist, w, s))
+ plt.show()
diff --git a/preprocess/tools/note_transcription/utils/metrics/laplace_var.py b/preprocess/tools/note_transcription/utils/metrics/laplace_var.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec6f5f8d877195e7ee512d7e9f6f8a879d3ef32c
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/metrics/laplace_var.py
@@ -0,0 +1,4 @@
+import scipy.ndimage
+
+def laplace_var(x):
+ return scipy.ndimage.laplace(x).var()
diff --git a/preprocess/tools/note_transcription/utils/metrics/pitch_distance.py b/preprocess/tools/note_transcription/utils/metrics/pitch_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bc11424a9f75270fc7eb5ef98731129e25ff715
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/metrics/pitch_distance.py
@@ -0,0 +1,102 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from numba import jit
+
+import torch
+
+
+@jit
+def time_warp(costs):
+ dtw = np.zeros_like(costs)
+ dtw[0, 1:] = np.inf
+ dtw[1:, 0] = np.inf
+ eps = 1e-4
+ for i in range(1, costs.shape[0]):
+ for j in range(1, costs.shape[1]):
+ dtw[i, j] = costs[i, j] + min(dtw[i - 1, j], dtw[i, j - 1], dtw[i - 1, j - 1])
+ return dtw
+
+
+def align_from_distances(distance_matrix, debug=False, return_mindist=False):
+ # for each position in spectrum 1, returns best match position in spectrum2
+ # using monotonic alignment
+ dtw = time_warp(distance_matrix)
+
+ i = distance_matrix.shape[0] - 1
+ j = distance_matrix.shape[1] - 1
+ results = [0] * distance_matrix.shape[0]
+ while i > 0 and j > 0:
+ results[i] = j
+ i, j = min([(i - 1, j), (i, j - 1), (i - 1, j - 1)], key=lambda x: dtw[x[0], x[1]])
+
+ if debug:
+ visual = np.zeros_like(dtw)
+ visual[range(len(results)), results] = 1
+ plt.matshow(visual)
+ plt.show()
+ if return_mindist:
+ return results, dtw[-1, -1]
+ return results
+
+
+def get_local_context(input_f, max_window=32, scale_factor=1.):
+ # input_f: [S, 1], support numpy array or torch tensor
+ # return hist: [S, max_window * 2], list of list
+ T = input_f.shape[0]
+ # max_window = int(max_window * scale_factor)
+ derivative = [[0 for _ in range(max_window * 2)] for _ in range(T)]
+
+ for t in range(T): # travel the time series
+ for feat_idx in range(-max_window, max_window):
+ if t + feat_idx < 0 or t + feat_idx >= T:
+ value = 0
+ else:
+ value = input_f[t + feat_idx]
+ derivative[t][feat_idx + max_window] = value
+ return derivative
+
+
+def cal_localnorm_dist(src, tgt, src_len, tgt_len):
+ local_src = torch.tensor(get_local_context(src))
+ local_tgt = torch.tensor(get_local_context(tgt, scale_factor=tgt_len / src_len))
+
+ local_norm_src = (local_src - local_src.mean(-1).unsqueeze(-1)) # / local_src.std(-1).unsqueeze(-1) # [T1, 32]
+ local_norm_tgt = (local_tgt - local_tgt.mean(-1).unsqueeze(-1)) # / local_tgt.std(-1).unsqueeze(-1) # [T2, 32]
+
+ dists = torch.cdist(local_norm_src[None, :, :], local_norm_tgt[None, :, :]) # [1, T1, T2]
+ return dists
+
+
+## here is API for one sample
+def LoNDTWDistance(src, tgt):
+ # src: [S]
+ # tgt: [T]
+ dists = cal_localnorm_dist(src, tgt, src.shape[0], tgt.shape[0]) # [1, S, T]
+ costs = dists.squeeze(0) # [S, T]
+ alignment, min_distance = align_from_distances(costs.T.cpu().detach().numpy(), return_mindist=True) # [T]
+ return alignment, min_distance
+
+# if __name__ == '__main__':
+# # utils from ns
+# from utils.pitch_utils import denorm_f0
+# from tasks.singing.fsinging import FastSingingDataset
+# from utils.hparams import hparams, set_hparams
+#
+# set_hparams()
+#
+# train_ds = FastSingingDataset('test')
+#
+# # Test One sample case
+# sample = train_ds[0]
+# amateur_f0 = sample['f0']
+# prof_f0 = sample['prof_f0']
+#
+# amateur_uv = sample['uv']
+# amateur_padding = sample['mel2ph'] == 0
+# prof_uv = sample['prof_uv']
+# prof_padding = sample['prof_mel2ph'] == 0
+# amateur_f0_denorm = denorm_f0(amateur_f0, amateur_uv, hparams, pitch_padding=amateur_padding)
+# prof_f0_denorm = denorm_f0(prof_f0, prof_uv, hparams, pitch_padding=prof_padding)
+# alignment, min_distance = LoNDTWDistance(amateur_f0_denorm, prof_f0_denorm)
+# print(min_distance)
+# python utils/pitch_distance.py --config egs/datasets/audio/molar/svc_ppg.yaml
diff --git a/preprocess/tools/note_transcription/utils/metrics/ssim.py b/preprocess/tools/note_transcription/utils/metrics/ssim.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb8c6a47b14fbd450a6717a21236906d6de9679f
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/metrics/ssim.py
@@ -0,0 +1,84 @@
+"""
+Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim
+"""
+
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from math import exp
+
+
+def gaussian(window_size, sigma):
+ gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
+ return gauss / gauss.sum()
+
+
+def create_window(window_size, channel):
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+ window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
+ return window
+
+
+def _ssim(img1, img2, window, window_size, channel, size_average=True):
+ mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+ mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+
+ mu1_sq = mu1.pow(2)
+ mu2_sq = mu2.pow(2)
+ mu1_mu2 = mu1 * mu2
+
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
+ sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
+
+ C1 = 0.01 ** 2
+ C2 = 0.03 ** 2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+
+ if size_average:
+ return ssim_map.mean()
+ else:
+ return ssim_map.mean(1)
+
+
+class SSIM(torch.nn.Module):
+ def __init__(self, window_size=11, size_average=True):
+ super(SSIM, self).__init__()
+ self.window_size = window_size
+ self.size_average = size_average
+ self.channel = 1
+ self.window = create_window(window_size, self.channel)
+
+ def forward(self, img1, img2):
+ (_, channel, _, _) = img1.size()
+
+ if channel == self.channel and self.window.data.type() == img1.data.type():
+ window = self.window
+ else:
+ window = create_window(self.window_size, channel)
+
+ if img1.is_cuda:
+ window = window.cuda(img1.get_device())
+ window = window.type_as(img1)
+
+ self.window = window
+ self.channel = channel
+
+ return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
+
+
+window = None
+
+
+def ssim(img1, img2, window_size=11, size_average=True):
+ (_, channel, _, _) = img1.size()
+ global window
+ if window is None:
+ window = create_window(window_size, channel)
+ if img1.is_cuda:
+ window = window.cuda(img1.get_device())
+ window = window.type_as(img1)
+ return _ssim(img1, img2, window, window_size, channel, size_average)
diff --git a/preprocess/tools/note_transcription/utils/nn/model_utils.py b/preprocess/tools/note_transcription/utils/nn/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b81200e9a2629ac4d791a37d31d5f13330aefd30
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/nn/model_utils.py
@@ -0,0 +1,14 @@
+import numpy as np
+
+
+def print_arch(model, model_name='model'):
+ print(f"| {model_name} Arch: ", model)
+ num_params(model, model_name=model_name)
+
+
+def num_params(model, print_out=True, model_name="model"):
+ parameters = filter(lambda p: p.requires_grad, model.parameters())
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
+ if print_out:
+ print(f'| {model_name} Trainable Parameters: %.3fM' % parameters)
+ return parameters
diff --git a/preprocess/tools/note_transcription/utils/nn/schedulers.py b/preprocess/tools/note_transcription/utils/nn/schedulers.py
new file mode 100644
index 0000000000000000000000000000000000000000..931cdc4336d2c3f6c111e5189277200397d777b8
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/nn/schedulers.py
@@ -0,0 +1,65 @@
+class NoneSchedule(object):
+ def __init__(self, optimizer, lr):
+ self.optimizer = optimizer
+ self.constant_lr = lr
+ self.step(0)
+
+ def step(self, num_updates):
+ self.lr = self.constant_lr
+ for param_group in self.optimizer.param_groups:
+ param_group['lr'] = self.lr
+ return self.lr
+
+ def get_lr(self):
+ return self.optimizer.param_groups[0]['lr']
+
+ def get_last_lr(self):
+ return self.get_lr()
+
+
+class RSQRTSchedule(NoneSchedule):
+ def __init__(self, optimizer, lr, warmup_updates, hidden_size, last_step=-1):
+ self.optimizer = optimizer
+ self.constant_lr = lr
+ self.warmup_updates = warmup_updates
+ self.hidden_size = hidden_size
+ self.lr = lr
+ self.last_step = last_step
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = self.lr
+ self.step()
+
+ def step(self, num_updates=None):
+ if num_updates is None:
+ self.last_step += 1
+ num_updates = self.last_step
+ constant_lr = self.constant_lr
+ warmup = min(num_updates / self.warmup_updates, 1.0)
+ rsqrt_decay = max(self.warmup_updates, num_updates) ** -0.5
+ rsqrt_hidden = self.hidden_size ** -0.5
+ self.lr = max(constant_lr * warmup * rsqrt_decay * rsqrt_hidden, 1e-7)
+ for param_group in self.optimizer.param_groups:
+ param_group['lr'] = self.lr
+ return self.lr
+
+
+class WarmupSchedule(NoneSchedule):
+ def __init__(self, optimizer, lr, warmup_updates, last_step=-1):
+ self.optimizer = optimizer
+ self.constant_lr = self.lr = lr
+ self.warmup_updates = warmup_updates
+ self.last_step = last_step
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = self.lr
+ self.step()
+
+ def step(self, num_updates=None):
+ if num_updates is None:
+ self.last_step += 1
+ num_updates = self.last_step
+ constant_lr = self.constant_lr
+ warmup = min(num_updates / self.warmup_updates, 1.0)
+ self.lr = max(constant_lr * warmup, 1e-7)
+ for param_group in self.optimizer.param_groups:
+ param_group['lr'] = self.lr
+ return self.lr
diff --git a/preprocess/tools/note_transcription/utils/nn/seq_utils.py b/preprocess/tools/note_transcription/utils/nn/seq_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1308bf7d1806a6c36de9c8af5e9d217eaefa7b56
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/nn/seq_utils.py
@@ -0,0 +1,305 @@
+from collections import defaultdict
+import torch
+import torch.nn.functional as F
+
+
+def make_positions(tensor, padding_idx):
+ """Replace non-padding symbols with their position numbers.
+
+ Position numbers begin at padding_idx+1. Padding symbols are ignored.
+ """
+ # The series of casts and type-conversions here are carefully
+ # balanced to both work with ONNX export and XLA. In particular XLA
+ # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
+ # how to handle the dtype kwarg in cumsum.
+ mask = tensor.ne(padding_idx).int()
+ return (
+ torch.cumsum(mask, dim=1).type_as(mask) * mask
+ ).long() + padding_idx
+
+
+def softmax(x, dim):
+ return F.softmax(x, dim=dim, dtype=torch.float32)
+
+
+def sequence_mask(lengths, maxlen, dtype=torch.bool):
+ if maxlen is None:
+ maxlen = lengths.max()
+ mask = ~(torch.ones((len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths).t()
+ mask.type(dtype)
+ return mask
+
+
+def weights_nonzero_speech(target):
+ # target : B x T x mel
+ # Assign weight 1.0 to all labels except for padding (id=0).
+ dim = target.size(-1)
+ return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)
+
+
+INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
+
+
+def _get_full_incremental_state_key(module_instance, key):
+ module_name = module_instance.__class__.__name__
+
+ # assign a unique ID to each module instance, so that incremental state is
+ # not shared across module instances
+ if not hasattr(module_instance, '_instance_id'):
+ INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
+ module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]
+
+ return '{}.{}.{}'.format(module_name, module_instance._instance_id, key)
+
+
+def get_incremental_state(module, incremental_state, key):
+ """Helper for getting incremental state for an nn.Module."""
+ full_key = _get_full_incremental_state_key(module, key)
+ if incremental_state is None or full_key not in incremental_state:
+ return None
+ return incremental_state[full_key]
+
+
+def set_incremental_state(module, incremental_state, key, value):
+ """Helper for setting incremental state for an nn.Module."""
+ if incremental_state is not None:
+ full_key = _get_full_incremental_state_key(module, key)
+ incremental_state[full_key] = value
+
+
+def fill_with_neg_inf(t):
+ """FP16-compatible function that fills a tensor with -inf."""
+ return t.float().fill_(float('-inf')).type_as(t)
+
+
+def fill_with_neg_inf2(t):
+ """FP16-compatible function that fills a tensor with -inf."""
+ return t.float().fill_(-1e8).type_as(t)
+
+
+def select_attn(attn_logits, type='best'):
+ """
+
+ :param attn_logits: [n_layers, B, n_head, T_sp, T_txt]
+ :return:
+ """
+ encdec_attn = torch.stack(attn_logits, 0).transpose(1, 2)
+ # [n_layers * n_head, B, T_sp, T_txt]
+ encdec_attn = (encdec_attn.reshape([-1, *encdec_attn.shape[2:]])).softmax(-1)
+ if type == 'best':
+ indices = encdec_attn.max(-1).values.sum(-1).argmax(0)
+ encdec_attn = encdec_attn.gather(
+ 0, indices[None, :, None, None].repeat(1, 1, encdec_attn.size(-2), encdec_attn.size(-1)))[0]
+ return encdec_attn
+ elif type == 'mean':
+ return encdec_attn.mean(0)
+
+
+def make_pad_mask(lengths, xs=None, length_dim=-1):
+ """Make mask tensor containing indices of padded part.
+ Args:
+ lengths (LongTensor or List): Batch of lengths (B,).
+ xs (Tensor, optional): The reference tensor.
+ If set, masks will be the same shape as this tensor.
+ length_dim (int, optional): Dimension indicator of the above tensor.
+ See the example.
+ Returns:
+ Tensor: Mask tensor containing indices of padded part.
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+ Examples:
+ With only lengths.
+ >>> lengths = [5, 3, 2]
+ >>> make_non_pad_mask(lengths)
+ masks = [[0, 0, 0, 0 ,0],
+ [0, 0, 0, 1, 1],
+ [0, 0, 1, 1, 1]]
+ With the reference tensor.
+ >>> xs = torch.zeros((3, 2, 4))
+ >>> make_pad_mask(lengths, xs)
+ tensor([[[0, 0, 0, 0],
+ [0, 0, 0, 0]],
+ [[0, 0, 0, 1],
+ [0, 0, 0, 1]],
+ [[0, 0, 1, 1],
+ [0, 0, 1, 1]]], dtype=torch.uint8)
+ >>> xs = torch.zeros((3, 2, 6))
+ >>> make_pad_mask(lengths, xs)
+ tensor([[[0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1]],
+ [[0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1]],
+ [[0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
+ With the reference tensor and dimension indicator.
+ >>> xs = torch.zeros((3, 6, 6))
+ >>> make_pad_mask(lengths, xs, 1)
+ tensor([[[0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [1, 1, 1, 1, 1, 1]],
+ [[0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1]],
+ [[0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
+ >>> make_pad_mask(lengths, xs, 2)
+ tensor([[[0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1]],
+ [[0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1]],
+ [[0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
+ """
+ if length_dim == 0:
+ raise ValueError("length_dim cannot be 0: {}".format(length_dim))
+
+ if not isinstance(lengths, list):
+ lengths = lengths.tolist()
+ bs = int(len(lengths))
+ if xs is None:
+ maxlen = int(max(lengths))
+ else:
+ maxlen = xs.size(length_dim)
+
+ seq_range = torch.arange(0, maxlen, dtype=torch.int64)
+ seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
+ seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
+ mask = seq_range_expand >= seq_length_expand
+
+ if xs is not None:
+ assert xs.size(0) == bs, (xs.size(0), bs)
+
+ if length_dim < 0:
+ length_dim = xs.dim() + length_dim
+ # ind = (:, None, ..., None, :, , None, ..., None)
+ ind = tuple(
+ slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
+ )
+ mask = mask[ind].expand_as(xs).to(xs.device)
+ return mask
+
+
+def make_non_pad_mask(lengths, xs=None, length_dim=-1):
+ """Make mask tensor containing indices of non-padded part.
+ Args:
+ lengths (LongTensor or List): Batch of lengths (B,).
+ xs (Tensor, optional): The reference tensor.
+ If set, masks will be the same shape as this tensor.
+ length_dim (int, optional): Dimension indicator of the above tensor.
+ See the example.
+ Returns:
+ ByteTensor: mask tensor containing indices of padded part.
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+ Examples:
+ With only lengths.
+ >>> lengths = [5, 3, 2]
+ >>> make_non_pad_mask(lengths)
+ masks = [[1, 1, 1, 1 ,1],
+ [1, 1, 1, 0, 0],
+ [1, 1, 0, 0, 0]]
+ With the reference tensor.
+ >>> xs = torch.zeros((3, 2, 4))
+ >>> make_non_pad_mask(lengths, xs)
+ tensor([[[1, 1, 1, 1],
+ [1, 1, 1, 1]],
+ [[1, 1, 1, 0],
+ [1, 1, 1, 0]],
+ [[1, 1, 0, 0],
+ [1, 1, 0, 0]]], dtype=torch.uint8)
+ >>> xs = torch.zeros((3, 2, 6))
+ >>> make_non_pad_mask(lengths, xs)
+ tensor([[[1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0]],
+ [[1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0]],
+ [[1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
+ With the reference tensor and dimension indicator.
+ >>> xs = torch.zeros((3, 6, 6))
+ >>> make_non_pad_mask(lengths, xs, 1)
+ tensor([[[1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [0, 0, 0, 0, 0, 0]],
+ [[1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0]],
+ [[1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
+ >>> make_non_pad_mask(lengths, xs, 2)
+ tensor([[[1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0]],
+ [[1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0]],
+ [[1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
+ """
+ return ~make_pad_mask(lengths, xs, length_dim)
+
+
+def get_mask_from_lengths(lengths):
+ max_len = torch.max(lengths).item()
+ ids = torch.arange(0, max_len).to(lengths.device)
+ mask = (ids < lengths.unsqueeze(1)).bool()
+ return mask
+
+
+def group_hidden_by_segs(h, seg_ids, max_len):
+ """
+
+ :param h: [B, T, H]
+ :param seg_ids: [B, T]
+ :return: h_ph: [B, T_ph, H]
+ """
+ B, T, H = h.shape
+ h_gby_segs = h.new_zeros([B, max_len + 1, H]).scatter_add_(1, seg_ids[:, :, None].repeat([1, 1, H]), h)
+ all_ones = h.new_ones(h.shape[:2])
+ cnt_gby_segs = h.new_zeros([B, max_len + 1]).scatter_add_(1, seg_ids, all_ones).contiguous()
+ h_gby_segs = h_gby_segs[:, 1:]
+ cnt_gby_segs = cnt_gby_segs[:, 1:]
+ h_gby_segs = h_gby_segs / torch.clamp(cnt_gby_segs[:, :, None], min=1)
+ return h_gby_segs, cnt_gby_segs
diff --git a/preprocess/tools/note_transcription/utils/os_utils.py b/preprocess/tools/note_transcription/utils/os_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..df4cf8e2163ae8e20b7a80c5a5b3ed2e0bac8318
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/os_utils.py
@@ -0,0 +1,21 @@
+import os
+import subprocess
+from pathlib import Path
+
+def link_file(from_file, to_file):
+ subprocess.check_call(
+ f'ln -s "`realpath --relative-to="{os.path.dirname(to_file)}" "{from_file}"`" "{to_file}"', shell=True)
+
+
+def move_file(from_file, to_file):
+ subprocess.check_call(f'mv "{from_file}" "{to_file}"', shell=True)
+
+
+def copy_file(from_file, to_file):
+ subprocess.check_call(f'cp -r "{from_file}" "{to_file}"', shell=True)
+
+
+def safe_path(path):
+ os.makedirs(Path(path).parent, exist_ok=True)
+ return path
+
diff --git a/preprocess/tools/note_transcription/utils/plot/plot.py b/preprocess/tools/note_transcription/utils/plot/plot.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d7fc02cef69fa5517228437156e687ca054efc8
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/plot/plot.py
@@ -0,0 +1,51 @@
+import matplotlib
+
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+LINE_COLORS = ['w', 'r', 'orange', 'k', 'cyan', 'm', 'b', 'lime', 'g', 'brown', 'navy']
+
+
+def spec_to_figure(spec, vmin=None, vmax=None, title='', f0s=None, dur_info=None):
+ if isinstance(spec, torch.Tensor):
+ spec = spec.cpu().numpy()
+ H = spec.shape[1] // 2
+ fig = plt.figure(figsize=(12, 6))
+ plt.title(title)
+ plt.pcolor(spec.T, vmin=vmin, vmax=vmax)
+ if dur_info is not None:
+ assert isinstance(dur_info, dict)
+ txt = dur_info['txt']
+ dur_gt = dur_info['dur_gt']
+ if isinstance(dur_gt, torch.Tensor):
+ dur_gt = dur_gt.cpu().numpy()
+ dur_gt = np.cumsum(dur_gt).astype(int)
+ for i in range(len(dur_gt)):
+ shift = (i % 8) + 1
+ plt.text(dur_gt[i], shift * 4, txt[i])
+ plt.vlines(dur_gt[i], 0, H // 2, colors='b') # blue is gt
+ plt.xlim(0, dur_gt[-1])
+ if 'dur_pred' in dur_info:
+ dur_pred = dur_info['dur_pred']
+ if isinstance(dur_pred, torch.Tensor):
+ dur_pred = dur_pred.cpu().numpy()
+ dur_pred = np.cumsum(dur_pred).astype(int)
+ for i in range(len(dur_pred)):
+ shift = (i % 8) + 1
+ plt.text(dur_pred[i], H + shift * 4, txt[i])
+ plt.vlines(dur_pred[i], H, H * 1.5, colors='r') # red is pred
+ plt.xlim(0, max(dur_gt[-1], dur_pred[-1]))
+ if f0s is not None:
+ ax = plt.gca()
+ ax2 = ax.twinx()
+ if not isinstance(f0s, dict):
+ f0s = {'f0': f0s}
+ for i, (k, f0) in enumerate(f0s.items()):
+ if isinstance(f0, torch.Tensor):
+ f0 = f0.cpu().numpy()
+ ax2.plot(f0, label=k, c=LINE_COLORS[i], linewidth=1, alpha=0.5)
+ ax2.set_ylim(0, 1000)
+ ax2.legend()
+ return fig
diff --git a/preprocess/tools/note_transcription/utils/rosvot_utils.py b/preprocess/tools/note_transcription/utils/rosvot_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b03adbcf6b4288e0c91356feb86022367249e77
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/rosvot_utils.py
@@ -0,0 +1,111 @@
+import numpy as np
+import torch
+
+def regulate_real_note_itv(note_itv, note_bd, word_bd, word_durs, hop_size, audio_sample_rate):
+ # regulate note_itv in seconds according to the correspondence between note_bd and word_bd
+ assert note_itv.shape[0] == np.sum(note_bd) + 1
+ assert np.sum(word_bd) <= np.sum(note_bd)
+ assert word_durs.shape[0] == np.sum(word_bd) + 1, f"{word_durs.shape[0]} {np.sum(word_bd) + 1}"
+ word_bd = np.cumsum(word_bd) * word_bd # [0,1,0,0,1,0,0,0] -> [0,1,0,0,2,0,0,0]
+ word_itv = np.zeros((word_durs.shape[0], 2))
+ word_offsets = np.cumsum(word_durs)
+ note2words = np.zeros(note_itv.shape[0], dtype=int)
+ for idx in range(len(word_offsets) - 1):
+ word_itv[idx, 1] = word_itv[idx + 1, 0] = word_offsets[idx]
+ word_itv[-1, 1] = word_offsets[-1]
+ note_itv_secs = note_itv * hop_size / audio_sample_rate
+ for idx, itv in enumerate(note_itv):
+ start_idx, end_idx = itv
+ if word_bd[start_idx] > 0:
+ word_dur_idx = word_bd[start_idx]
+ note_itv_secs[idx, 0] = word_itv[word_dur_idx, 0]
+ note2words[idx] = word_dur_idx
+ if word_bd[end_idx] > 0:
+ word_dur_idx = word_bd[end_idx] - 1
+ note_itv_secs[idx, 1] = word_itv[word_dur_idx, 1]
+ note2words[idx] = word_dur_idx
+ note2words += 1 # mel2ph fashion: start from 1
+ return note_itv_secs, note2words
+
+def regulate_ill_slur(notes, note_itv, note2words):
+ res_note2words = []
+ res_note_itv = []
+ res_notes = []
+ note_idx = 0
+ note_idx_end = 0
+ while True:
+ if note_idx > len(notes) - 1:
+ break
+ while note_idx <= note_idx_end < len(notes) and note2words[note_idx] == note2words[note_idx_end]:
+ note_idx_end += 1
+ res_note2words.append(note2words[note_idx])
+ res_note_itv.append(note_itv[note_idx].tolist())
+ res_notes.append(notes[note_idx])
+ for idx in range(note_idx+1, note_idx_end):
+ if notes[idx] == notes[idx-1]:
+ res_note_itv[-1][1] = note_itv[idx][1]
+ else:
+ res_note_itv.append(note_itv[idx].tolist())
+ res_note2words.append(note2words[idx])
+ res_notes.append(notes[idx])
+ note_idx = note_idx_end
+ res_notes = np.array(res_notes, dtype=notes.dtype)
+ res_note_itv = np.array(res_note_itv, dtype=note_itv.dtype)
+ res_note2words = np.array(res_note2words, dtype=note2words.dtype)
+ return res_notes, res_note_itv, res_note2words
+
+def bd_to_idxs(bd):
+ # bd [T]
+ idxs = []
+ for idx in range(len(bd)):
+ if bd[idx] == 1:
+ idxs.append(idx)
+ return idxs
+
+def bd_to_durs(bd):
+ # bd [T]
+ last_idx = 0
+ durs = []
+ for idx in range(len(bd)):
+ if bd[idx] == 1:
+ durs.append(idx - last_idx)
+ last_idx = idx
+ durs.append(len(bd) - last_idx)
+ return durs
+
+def get_mel_len(wav_len, hop_size):
+ return (wav_len + hop_size - 1) // hop_size
+
+def mel2token_to_dur(mel2token, T_txt=None, max_dur=None):
+ is_torch = isinstance(mel2token, torch.Tensor)
+ has_batch_dim = True
+ if not is_torch:
+ mel2token = torch.LongTensor(mel2token)
+ if T_txt is None:
+ T_txt = mel2token.max()
+ if len(mel2token.shape) == 1:
+ mel2token = mel2token[None, ...]
+ has_batch_dim = False
+ B, _ = mel2token.shape
+ dur = mel2token.new_zeros(B, T_txt + 1).scatter_add(1, mel2token, torch.ones_like(mel2token))
+ dur = dur[:, 1:]
+ if max_dur is not None:
+ dur = dur.clamp(max=max_dur)
+ if not is_torch:
+ dur = dur.numpy()
+ if not has_batch_dim:
+ dur = dur[0]
+ return dur
+
+def align_word(word_durs, mel_len, hop_size, audio_sample_rate):
+ mel2word = np.zeros([mel_len], int)
+ start_time = 0
+ for i_word in range(len(word_durs)):
+ start_frame = int(start_time * audio_sample_rate / hop_size + 0.5)
+ end_frame = int((start_time + word_durs[i_word]) * audio_sample_rate / hop_size + 0.5)
+ mel2word[start_frame:end_frame] = i_word + 1
+ start_time = start_time + word_durs[i_word]
+
+ dur_word = mel2token_to_dur(mel2word)
+
+ return mel2word, dur_word.tolist()
\ No newline at end of file
diff --git a/preprocess/tools/note_transcription/utils/text/encoding.py b/preprocess/tools/note_transcription/utils/text/encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..f09f514613fd44a27450fe7c04cbdf5ebfbe78a8
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/text/encoding.py
@@ -0,0 +1,9 @@
+import chardet
+
+
+def get_encoding(file):
+ with open(file, 'rb') as f:
+ encoding = chardet.detect(f.read())['encoding']
+ if encoding == 'GB2312':
+ encoding = 'GB18030'
+ return encoding
diff --git a/preprocess/tools/note_transcription/utils/text/text_encoder.py b/preprocess/tools/note_transcription/utils/text/text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..09555af09720382a795712f0fdd9b711c5b19e02
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/text/text_encoder.py
@@ -0,0 +1,263 @@
+import json
+import re
+import six
+from six.moves import range # pylint: disable=redefined-builtin
+
+PAD = ""
+EOS = ""
+UNK = ""
+SEG = "|"
+PUNCS = '!,.?;:'
+RESERVED_TOKENS = [PAD, EOS, UNK]
+NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
+PAD_ID = RESERVED_TOKENS.index(PAD) # Normally 0
+EOS_ID = RESERVED_TOKENS.index(EOS) # Normally 1
+UNK_ID = RESERVED_TOKENS.index(UNK) # Normally 2
+
+if six.PY2:
+ RESERVED_TOKENS_BYTES = RESERVED_TOKENS
+else:
+ RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")]
+
+# Regular expression for unescaping token strings.
+# '\u' is converted to '_'
+# '\\' is converted to '\'
+# '\213;' is converted to unichr(213)
+_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
+_ESCAPE_CHARS = set(u"\\_u;0123456789")
+
+
+def strip_ids(ids, ids_to_strip):
+ """Strip ids_to_strip from the end ids."""
+ ids = list(ids)
+ while ids and ids[-1] in ids_to_strip:
+ ids.pop()
+ return ids
+
+
+class TextEncoder(object):
+ """Base class for converting from ints to/from human readable strings."""
+
+ def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
+ self._num_reserved_ids = num_reserved_ids
+
+ @property
+ def num_reserved_ids(self):
+ return self._num_reserved_ids
+
+ def encode(self, s):
+ """Transform a human-readable string into a sequence of int ids.
+
+ The ids should be in the range [num_reserved_ids, vocab_size). Ids [0,
+ num_reserved_ids) are reserved.
+
+ EOS is not appended.
+
+ Args:
+ s: human-readable string to be converted.
+
+ Returns:
+ ids: list of integers
+ """
+ return [int(w) + self._num_reserved_ids for w in s.split()]
+
+ def decode(self, ids, strip_extraneous=False):
+ """Transform a sequence of int ids into a human-readable string.
+
+ EOS is not expected in ids.
+
+ Args:
+ ids: list of integers to be converted.
+ strip_extraneous: bool, whether to strip off extraneous tokens
+ (EOS and PAD).
+
+ Returns:
+ s: human-readable string.
+ """
+ if strip_extraneous:
+ ids = strip_ids(ids, list(range(self._num_reserved_ids or 0)))
+ return " ".join(self.decode_list(ids))
+
+ def decode_list(self, ids):
+ """Transform a sequence of int ids into a their string versions.
+
+ This method supports transforming individual input/output ids to their
+ string versions so that sequence to/from text conversions can be visualized
+ in a human readable format.
+
+ Args:
+ ids: list of integers to be converted.
+
+ Returns:
+ strs: list of human-readable string.
+ """
+ decoded_ids = []
+ for id_ in ids:
+ if 0 <= id_ < self._num_reserved_ids:
+ decoded_ids.append(RESERVED_TOKENS[int(id_)])
+ else:
+ decoded_ids.append(id_ - self._num_reserved_ids)
+ return [str(d) for d in decoded_ids]
+
+ @property
+ def vocab_size(self):
+ raise NotImplementedError()
+
+
+class TokenTextEncoder(TextEncoder):
+ """Encoder based on a user-supplied vocabulary (file or list)."""
+
+ def __init__(self,
+ vocab_filename,
+ reverse=False,
+ vocab_list=None,
+ replace_oov=None,
+ num_reserved_ids=NUM_RESERVED_TOKENS):
+ """Initialize from a file or list, one token per line.
+
+ Handling of reserved tokens works as follows:
+ - When initializing from a list, we add reserved tokens to the vocab.
+ - When initializing from a file, we do not add reserved tokens to the vocab.
+ - When saving vocab files, we save reserved tokens to the file.
+
+ Args:
+ vocab_filename: If not None, the full filename to read vocab from. If this
+ is not None, then vocab_list should be None.
+ reverse: Boolean indicating if tokens should be reversed during encoding
+ and decoding.
+ vocab_list: If not None, a list of elements of the vocabulary. If this is
+ not None, then vocab_filename should be None.
+ replace_oov: If not None, every out-of-vocabulary token seen when
+ encoding will be replaced by this string (which must be in vocab).
+ num_reserved_ids: Number of IDs to save for reserved tokens like .
+ """
+ super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
+ self._reverse = reverse
+ self._replace_oov = replace_oov
+ if vocab_filename:
+ self._init_vocab_from_file(vocab_filename)
+ else:
+ assert vocab_list is not None
+ self._init_vocab_from_list(vocab_list)
+ self.pad_index = self.token_to_id[PAD]
+ self.eos_index = self.token_to_id[EOS]
+ self.unk_index = self.token_to_id[UNK]
+ self.seg_index = self.token_to_id[SEG] if SEG in self.token_to_id else self.eos_index
+
+ def encode(self, s):
+ """Converts a space-separated string of tokens to a list of ids."""
+ sentence = s
+ tokens = sentence.strip().split()
+ if self._replace_oov is not None:
+ tokens = [t if t in self.token_to_id else self._replace_oov
+ for t in tokens]
+ ret = [self.token_to_id[tok] for tok in tokens]
+ return ret[::-1] if self._reverse else ret
+
+ def decode(self, ids, strip_eos=False, strip_padding=False):
+ if strip_padding and self.pad() in list(ids):
+ pad_pos = list(ids).index(self.pad())
+ ids = ids[:pad_pos]
+ if strip_eos and self.eos() in list(ids):
+ eos_pos = list(ids).index(self.eos())
+ ids = ids[:eos_pos]
+ return " ".join(self.decode_list(ids))
+
+ def decode_list(self, ids):
+ seq = reversed(ids) if self._reverse else ids
+ return [self._safe_id_to_token(i) for i in seq]
+
+ @property
+ def vocab_size(self):
+ return len(self.id_to_token)
+
+ def __len__(self):
+ return self.vocab_size
+
+ def _safe_id_to_token(self, idx):
+ return self.id_to_token.get(idx, "ID_%d" % idx)
+
+ def _init_vocab_from_file(self, filename):
+ """Load vocab from a file.
+
+ Args:
+ filename: The file to load vocabulary from.
+ """
+ with open(filename) as f:
+ tokens = [token.strip() for token in f.readlines()]
+
+ def token_gen():
+ for token in tokens:
+ yield token
+
+ self._init_vocab(token_gen(), add_reserved_tokens=False)
+
+ def _init_vocab_from_list(self, vocab_list):
+ """Initialize tokens from a list of tokens.
+
+ It is ok if reserved tokens appear in the vocab list. They will be
+ removed. The set of tokens in vocab_list should be unique.
+
+ Args:
+ vocab_list: A list of tokens.
+ """
+
+ def token_gen():
+ for token in vocab_list:
+ if token not in RESERVED_TOKENS:
+ yield token
+
+ self._init_vocab(token_gen())
+
+ def _init_vocab(self, token_generator, add_reserved_tokens=True):
+ """Initialize vocabulary with tokens from token_generator."""
+
+ self.id_to_token = {}
+ non_reserved_start_index = 0
+
+ if add_reserved_tokens:
+ self.id_to_token.update(enumerate(RESERVED_TOKENS))
+ non_reserved_start_index = len(RESERVED_TOKENS)
+
+ self.id_to_token.update(
+ enumerate(token_generator, start=non_reserved_start_index))
+
+ # _token_to_id is the reverse of _id_to_token
+ self.token_to_id = dict((v, k) for k, v in six.iteritems(self.id_to_token))
+
+ def pad(self):
+ return self.pad_index
+
+ def eos(self):
+ return self.eos_index
+
+ def unk(self):
+ return self.unk_index
+
+ def seg(self):
+ return self.seg_index
+
+ def store_to_file(self, filename):
+ """Write vocab file to disk.
+
+ Vocab files have one token per line. The file ends in a newline. Reserved
+ tokens are written to the vocab file as well.
+
+ Args:
+ filename: Full path of the file to store the vocab to.
+ """
+ with open(filename, "w") as f:
+ for i in range(len(self.id_to_token)):
+ f.write(self.id_to_token[i] + "\n")
+
+ def sil_phonemes(self):
+ return [p for p in self.id_to_token.values() if is_sil_phoneme(p)]
+
+
+def build_token_encoder(token_list_file):
+ token_list = json.load(open(token_list_file))
+ return TokenTextEncoder(None, vocab_list=token_list, replace_oov='')
+
+
+def is_sil_phoneme(p):
+ return p == '' or not p[0].isalpha()
diff --git a/preprocess/tools/note_transcription/utils/text/textgrid.py b/preprocess/tools/note_transcription/utils/text/textgrid.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fc2ec3c12ab41f9337fef81ca6df090f180848f
--- /dev/null
+++ b/preprocess/tools/note_transcription/utils/text/textgrid.py
@@ -0,0 +1,90 @@
+from collections import OrderedDict
+import re
+import json
+
+
+def remove_empty_lines(text):
+ """remove empty lines"""
+ assert (len(text) > 0)
+ assert (isinstance(text, list))
+ text = [t.strip() for t in text]
+ if "" in text:
+ text.remove("")
+ return text
+
+
+class TextGrid(object):
+ def __init__(self, text):
+ text = remove_empty_lines(text)
+ self.text = text
+ self.line_count = 0
+ self._get_type()
+ self._get_time_intval()
+ self._get_size()
+ self.tier_list = []
+ self._get_item_list()
+
+ def _extract_pattern(self, pattern, inc):
+ """
+ Parameters
+ ----------
+ pattern : regex to extract pattern
+ inc : increment of line count after extraction
+ Returns
+ -------
+ group : extracted info
+ """
+ try:
+ group = re.match(pattern, self.text[self.line_count]).group(1)
+ self.line_count += inc
+ except AttributeError:
+ raise ValueError("File format error at line %d:%s" % (self.line_count, self.text[self.line_count]))
+ return group
+
+ def _get_type(self):
+ self.file_type = self._extract_pattern(r"File type = \"(.*)\"", 2)
+
+ def _get_time_intval(self):
+ self.xmin = self._extract_pattern(r"xmin = (.*)", 1)
+ self.xmax = self._extract_pattern(r"xmax = (.*)", 2)
+
+ def _get_size(self):
+ self.size = int(self._extract_pattern(r"size = (.*)", 2))
+
+ def _get_item_list(self):
+ """Only supports IntervalTier currently"""
+ for itemIdx in range(1, self.size + 1):
+ tier = OrderedDict()
+ item_list = []
+ tier_idx = self._extract_pattern(r"item \[(.*)\]:", 1)
+ tier_class = self._extract_pattern(r"class = \"(.*)\"", 1)
+ if tier_class != "IntervalTier":
+ raise NotImplementedError("Only IntervalTier class is supported currently")
+ tier_name = self._extract_pattern(r"name = \"(.*)\"", 1)
+ tier_xmin = self._extract_pattern(r"xmin = (.*)", 1)
+ tier_xmax = self._extract_pattern(r"xmax = (.*)", 1)
+ tier_size = self._extract_pattern(r"intervals: size = (.*)", 1)
+ for i in range(int(tier_size)):
+ item = OrderedDict()
+ item["idx"] = self._extract_pattern(r"intervals \[(.*)\]", 1)
+ item["xmin"] = self._extract_pattern(r"xmin = (.*)", 1)
+ item["xmax"] = self._extract_pattern(r"xmax = (.*)", 1)
+ item["text"] = self._extract_pattern(r"text = \"(.*)\"", 1)
+ item_list.append(item)
+ tier["idx"] = tier_idx
+ tier["class"] = tier_class
+ tier["name"] = tier_name
+ tier["xmin"] = tier_xmin
+ tier["xmax"] = tier_xmax
+ tier["size"] = tier_size
+ tier["items"] = item_list
+ self.tier_list.append(tier)
+
+ def toJson(self):
+ _json = OrderedDict()
+ _json["file_type"] = self.file_type
+ _json["xmin"] = self.xmin
+ _json["xmax"] = self.xmax
+ _json["size"] = self.size
+ _json["tiers"] = self.tier_list
+ return json.dumps(_json, ensure_ascii=False, indent=2)
\ No newline at end of file
diff --git a/preprocess/tools/vocal_detection.py b/preprocess/tools/vocal_detection.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac4a751d9229889bbcdcd4aad7ee07c85403f120
--- /dev/null
+++ b/preprocess/tools/vocal_detection.py
@@ -0,0 +1,320 @@
+import os
+import time
+from dataclasses import dataclass
+from typing import List, Optional
+
+import librosa
+import numpy as np
+from soundfile import write
+
+
+@dataclass(frozen=True)
+class VocalDetectionConfig:
+ hop_ms: int = 20
+ smooth_ms: int = 200
+ start_ms: int = 120
+ end_ms: int = 200
+ prepad_ms: int = 80
+ postpad_ms: int = 120
+ min_len_ms: int = 1000
+ max_len_ms: int = 20000
+ short_seg_merge_gap_ms: int = 8000
+ small_gap_ms: int = 500
+ lookback_ms: int = 200
+ lookahead_ms: int = 100
+
+
+def _moving_average(x: np.ndarray, win: int) -> np.ndarray:
+ if win <= 1:
+ return x
+ kernel = np.ones(win, dtype=np.float32) / float(win)
+ return np.convolve(x, kernel, mode="same")
+
+
+def _merge_short_segments(
+ segments_ms: List[List[int]],
+ *,
+ min_len_ms: int,
+ max_len_ms: int,
+ short_seg_merge_gap_ms: int,
+ small_gap_ms: int,
+) -> List[List[int]]:
+ if not segments_ms:
+ return []
+
+ merged: List[List[int]] = []
+ cur_start, cur_end = segments_ms[0]
+
+ for next_start, next_end in segments_ms[1:]:
+ cur_len = cur_end - cur_start
+ gap_ms = next_start - cur_end
+ merged_len = next_end - cur_start
+
+ should_merge = (
+ (cur_len < min_len_ms and gap_ms < short_seg_merge_gap_ms)
+ or (gap_ms < small_gap_ms and merged_len < max_len_ms)
+ )
+
+ if should_merge:
+ cur_end = next_end
+ continue
+
+ if (cur_end - cur_start) >= min_len_ms:
+ merged.append([cur_start, cur_end])
+
+ cur_start, cur_end = next_start, next_end
+
+ if (cur_end - cur_start) >= min_len_ms:
+ merged.append([cur_start, cur_end])
+
+ if not merged:
+ return segments_ms
+
+ return merged
+
+
+def _voiced_to_segments(
+ voiced: np.ndarray,
+ *,
+ hop_ms: int,
+ smooth_ms: int,
+ start_ms: int,
+ end_ms: int,
+ prepad_ms: int,
+ postpad_ms: int,
+ max_len_ms: int,
+) -> List[List[int]]:
+ smooth_frames = max(1, int(round(smooth_ms / hop_ms)))
+ smooth_voiced = _moving_average(voiced.astype(np.float32), smooth_frames)
+ active = smooth_voiced >= 0.5
+
+ segments: List[List[int]] = []
+ start_idx = None
+ start_frames = max(1, int(round(start_ms / hop_ms)))
+ end_frames = max(1, int(round(end_ms / hop_ms)))
+ prepad_frames = max(0, int(round(prepad_ms / hop_ms)))
+ postpad_frames = max(0, int(round(postpad_ms / hop_ms)))
+ active_count = 0
+ inactive_count = 0
+
+ for i, flag in enumerate(active):
+ if flag:
+ active_count += 1
+ inactive_count = 0
+ else:
+ inactive_count += 1
+ active_count = 0
+
+ if start_idx is None:
+ if active_count >= start_frames:
+ start_idx = max(0, i - start_frames + 1 - prepad_frames)
+ else:
+ if inactive_count >= end_frames:
+ end_idx = min(len(active) - 1, i - end_frames + 1 + postpad_frames)
+ start_ms_val = start_idx * hop_ms
+ end_ms_val = end_idx * hop_ms + hop_ms
+ if end_ms_val > start_ms_val:
+ segments.append([int(start_ms_val), int(end_ms_val)])
+ start_idx = None
+
+ if start_idx is not None:
+ start_ms_val = start_idx * hop_ms
+ end_idx = min(len(active) - 1, len(active) - 1 + postpad_frames)
+ end_ms_val = end_idx * hop_ms + hop_ms
+ if end_ms_val > start_ms_val:
+ segments.append([int(start_ms_val), int(end_ms_val)])
+
+ def _split_segment(seg: List[int]) -> List[List[int]]:
+ start_ms_val, end_ms_val = seg
+ start_frame = int(start_ms_val // hop_ms)
+ end_frame = int((end_ms_val - 1) // hop_ms)
+ end_frame = max(start_frame, min(end_frame, len(active) - 1))
+
+ best_start = None
+ best_len = 0
+ cur_start = None
+ cur_len = 0
+ for idx in range(start_frame, end_frame + 1):
+ if not active[idx]:
+ if cur_start is None:
+ cur_start = idx
+ cur_len = 1
+ else:
+ cur_len += 1
+ else:
+ if cur_start is not None and cur_len > best_len:
+ best_start, best_len = cur_start, cur_len
+ cur_start = None
+ cur_len = 0
+ if cur_start is not None and cur_len > best_len:
+ best_start, best_len = cur_start, cur_len
+
+ if best_start is None:
+ split_frame = (start_frame + end_frame) // 2
+ else:
+ split_frame = best_start + best_len // 2
+
+ split_ms = split_frame * hop_ms
+ if split_ms <= start_ms_val:
+ split_ms = start_ms_val + hop_ms
+ if split_ms >= end_ms_val:
+ split_ms = end_ms_val - hop_ms
+
+ if split_ms <= start_ms_val or split_ms >= end_ms_val:
+ return [seg]
+
+ return [[start_ms_val, int(split_ms)], [int(split_ms), end_ms_val]]
+
+ queue = segments[:]
+ segments = []
+ while queue:
+ seg = queue.pop(0)
+ if (seg[1] - seg[0]) <= max_len_ms:
+ segments.append(seg)
+ continue
+ parts = _split_segment(seg)
+ if len(parts) == 1:
+ segments.append(seg)
+ else:
+ queue = parts + queue
+
+ return segments
+
+
+class VocalDetector:
+ """Detect vocal segments based on f0 voiced decisions.
+
+ This component consumes a precomputed ``*_f0.npy`` track and
+ produces vocal segments (and cuts wav files) for downstream
+ transcription or singing voice tasks.
+ """
+ def __init__(
+ self,
+ cut_wavs_output_dir: str = "cut_wavs",
+ config: VocalDetectionConfig | None = None,
+ *,
+ verbose: bool = True,
+ ):
+ """Initialize the vocal detector.
+
+ Args:
+ cut_wavs_output_dir: Directory to save cut wav segments.
+ config: Detection configuration; uses :class:`VocalDetectionConfig` by default.
+ verbose: Whether to print verbose logs.
+ """
+ self.cut_wavs_output_dir = cut_wavs_output_dir
+ self.config = config or VocalDetectionConfig()
+ self.verbose = verbose
+
+ if self.verbose:
+ print(
+ "[vocal detection] init success:",
+ f"cut_wavs_output_dir={self.cut_wavs_output_dir}",
+ f"hop_ms={self.config.hop_ms}",
+ )
+
+ def process(self, audio_path: str, f0: np.ndarray, *, verbose: Optional[bool] = None) -> List[dict]:
+ """Run vocal detection on a single wav.
+
+ Args:
+ audio_path: Path to the input wav file.
+ f0: The f0 contour to use for vocal detection.
+ verbose: Override instance-level verbose flag for this call.
+
+ Returns:
+ A list of segment metadata dicts with fields like
+ ``item_name``, ``wav_fn``, ``start_time_ms``, ``end_time_ms``.
+ """
+ verbose = self.verbose if verbose is None else verbose
+ if verbose:
+ print(f"[vocal detection] process: start: {audio_path}")
+ t0 = time.time()
+
+ os.makedirs(self.cut_wavs_output_dir, exist_ok=True)
+
+ base_name = os.path.basename(audio_path)
+ base_name_no_ext = os.path.splitext(base_name)[0]
+
+ voiced = f0 > 0
+
+ segments_ms = _voiced_to_segments(
+ voiced,
+ hop_ms=self.config.hop_ms,
+ smooth_ms=self.config.smooth_ms,
+ start_ms=self.config.start_ms,
+ end_ms=self.config.end_ms,
+ prepad_ms=self.config.prepad_ms,
+ postpad_ms=self.config.postpad_ms,
+ max_len_ms=self.config.max_len_ms,
+ )
+
+ if verbose:
+ print(f"[vocal detection] segments(before_merge)={len(segments_ms)}")
+
+ segments_ms = _merge_short_segments(
+ segments_ms,
+ min_len_ms=self.config.min_len_ms,
+ max_len_ms=self.config.max_len_ms,
+ short_seg_merge_gap_ms=self.config.short_seg_merge_gap_ms,
+ small_gap_ms=self.config.small_gap_ms,
+ )
+
+ if verbose:
+ print(f"[vocal detection] segments(after_merge)={len(segments_ms)}")
+
+ y, sr = librosa.load(audio_path, sr=None, mono=True)
+
+ # Apply global lookback/lookahead in milliseconds
+ lookback_ms = self.config.lookback_ms
+ lookahead_ms = self.config.lookahead_ms
+
+ adjusted_segments: List[List[int]] = []
+ prev_end = 0
+ for start_ms, end_ms in segments_ms:
+ start_ms = max(0, start_ms - lookback_ms)
+ end_ms = min(end_ms + lookahead_ms, int(y.shape[0] / sr * 1000))
+
+ # Enforce non-overlap with previous segment, move backward the previous one.
+ if start_ms < prev_end and len(adjusted_segments) > 0:
+ adjusted_segments[-1][1] = start_ms
+
+ adjusted_segments.append([start_ms, end_ms])
+ prev_end = end_ms
+
+ segment_infos = []
+ for idx, (start_ms, end_ms) in enumerate(adjusted_segments):
+ if end_ms - start_ms > self.config.max_len_ms:
+ start_ms = end_ms - self.config.max_len_ms
+
+ key = f"{base_name_no_ext}_{idx}"
+ start_sample = librosa.time_to_samples(start_ms / 1000, sr=sr)
+ end_sample = librosa.time_to_samples(end_ms / 1000, sr=sr)
+ segment = y[start_sample:end_sample]
+
+ write(f"{self.cut_wavs_output_dir}/{key}.wav", segment, sr)
+ segment_infos.append(
+ {
+ "item_name": key,
+ "wav_fn": f"{self.cut_wavs_output_dir}/{key}.wav",
+ "start_time_ms": int(start_sample * 1000 / sr),
+ "end_time_ms": int(end_sample * 1000 / sr),
+ "origin_wav_fn": audio_path,
+ "duration": int((end_sample - start_sample) * 1000 / sr),
+ }
+ )
+
+ if verbose:
+ dt = time.time() - t0
+ print(
+ "[vocal detection] process: done:",
+ f"n_segments={len(segment_infos)}",
+ f"time={dt:.3f}s",
+ )
+
+ return segment_infos
+
+
+if __name__ == "__main__":
+ m = VocalDetector(cut_wavs_output_dir="outputs/transcription/cut_wavs")
+ segment_infos = m.process("./outputs/transcription/test.wav")
+ print(segment_infos)
diff --git a/preprocess/tools/vocal_separation/__init__.py b/preprocess/tools/vocal_separation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/preprocess/tools/vocal_separation/model.py b/preprocess/tools/vocal_separation/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce6679c6daa3ba51fa60037f51e8930748dea05b
--- /dev/null
+++ b/preprocess/tools/vocal_separation/model.py
@@ -0,0 +1,225 @@
+# https://github.com/ZFTurbo/Music-Source-Separation-Training
+# https://huggingface.co/becruily/mel-band-roformer-karaoke/blob/main/mel_band_roformer_karaoke_becruily.ckpt
+# https://huggingface.co/anvuew/dereverb_mel_band_roformer/blob/main/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Tuple
+
+import librosa
+import sys
+import os
+import time
+import torch
+import numpy as np
+
+from .utils.audio_utils import normalize_audio, denormalize_audio
+from .utils.settings import get_model_from_config, parse_args_inference
+from .utils.model_utils import demix
+from .utils.model_utils import prefer_target_instrument, apply_tta, load_start_checkpoint
+
+
+def process(mix, model, args, config, device):
+
+ instruments = prefer_target_instrument(config)[:]
+
+ # If mono audio we must adjust it depending on model
+ if len(mix.shape) == 1:
+ mix = np.expand_dims(mix, axis=0)
+ if 'num_channels' in config.audio:
+ if config.audio['num_channels'] == 2:
+ # print(f'Convert mono track to stereo...')
+ mix = np.concatenate([mix, mix], axis=0)
+
+ if 'normalize' in config.inference:
+ if config.inference['normalize'] is True:
+ mix, norm_params = normalize_audio(mix)
+
+ waveforms_orig = demix(config, model, mix, device, model_type=args.model_type, pbar=not args.disable_detailed_pbar)
+
+ instr = 'vocals' if 'vocals' in instruments else instruments[0]
+ estimates = waveforms_orig[instr]
+ if 'normalize' in config.inference:
+ if config.inference['normalize'] is True:
+ estimates = denormalize_audio(estimates, norm_params)
+
+ return estimates
+
+
+def build_model(args):
+ model, config = get_model_from_config(args.model_type, args.config_path)
+
+ load_start_checkpoint(args, model, None, type_='inference')
+
+ return model, config
+
+
+def build_models(dict_args):
+ args = parse_args_inference(dict_args)
+
+ ########## load model ##########
+ torch.backends.cudnn.benchmark = True
+
+ args.config_path = args.sep_config_path
+ args.start_check_point = args.sep_start_check_point
+
+ sep_model, sep_config = build_model(args)
+
+ args.config_path = args.der_config_path
+ args.start_check_point = args.der_start_check_point
+
+ dereverb_model, dereverb_config = build_model(args)
+
+ sep_model = sep_model
+ dereverb_model = dereverb_model
+
+ return sep_model, sep_config, dereverb_model, dereverb_config, args
+
+def main(args, sep_model=None, sep_config=None, dereverb_model=None, dereverb_config=None, device=None):
+
+ ######## process data ##########
+ sample_rate = getattr(sep_config.audio, 'sample_rate', 44100)
+ path = args.input_path
+
+ mix, _ = librosa.load(path, sr=sample_rate, mono=False)
+ vocals = process(mix, sep_model, args, sep_config, device)
+ dereverbed_vocals = process(vocals.mean(0), dereverb_model, args, dereverb_config, device)
+ accompaniment = mix - dereverbed_vocals
+
+ return mix, vocals, dereverbed_vocals, accompaniment, sample_rate
+
+@dataclass
+class VocalSeparationOutputs:
+ """Vocal extraction output container."""
+
+ mix: np.ndarray
+ vocals: np.ndarray
+ vocals_dereverbed: np.ndarray
+ accompaniment: np.ndarray
+ sample_rate: int
+
+
+class VocalSeparator:
+ """Vocal separation and dereverb wrapper.
+
+ Wraps the karaoke separation and dereverb models from the
+ ZFTurbo Music Source Separation project and exposes a simple
+ :py:meth:`process` API that returns mix/vocals/dereverbed/accompaniment.
+ """
+ def __init__(
+ self,
+ sep_model_path: str,
+ sep_config_path: str,
+ der_model_path: str,
+ der_config_path: str,
+ *,
+ model_type: str = "mel_band_roformer",
+ disable_detailed_pbar: bool = True,
+ device: str = "cuda",
+ verbose: bool = True,
+ ):
+ """Initialize the vocal separator.
+
+ Args:
+ device: Torch device string, e.g. ``"cuda:0"``.
+ model_type: Separation model type key.
+ sep_config_path: Config path for separation model.
+ sep_start_check_point: Checkpoint path for separation model.
+ der_config_path: Config path for dereverb model.
+ der_start_check_point: Checkpoint path for dereverb model.
+ disable_detailed_pbar: Disable detailed progress bars in underlying utils.
+ verbose: Whether to print verbose logs.
+ """
+
+ # Match original script args schema
+ args_dict: Dict[str, Any] = {
+ "model_type": model_type,
+ "disable_detailed_pbar": disable_detailed_pbar,
+ "sep_config_path": sep_config_path,
+ "sep_start_check_point": sep_model_path,
+ "der_config_path": der_config_path,
+ "der_start_check_point": der_model_path,
+ }
+
+ if verbose:
+ print("[vocal extraction] init: start")
+
+ sep_model, sep_config, dereverb_model, dereverb_config, args = build_models(args_dict)
+
+ sep_model = sep_model.to(device)
+ dereverb_model = dereverb_model.to(device)
+
+ self.sep_model = sep_model
+ self.sep_config = sep_config
+ self.dereverb_model = dereverb_model
+ self.dereverb_config = dereverb_config
+ self.device = device
+ self.args = args
+ self.verbose = verbose
+
+ if verbose:
+ print(
+ "[vocal extraction] init success: sep=loaded, dereverb=loaded, device=",
+ device,
+ )
+
+ def process(self, input_path: str, *, verbose: Optional[bool] = None) -> VocalSeparationOutputs:
+ """Separate a single audio file into sources.
+
+ Args:
+ input_path: Path to the mixture wav.
+ verbose: Override instance-level verbose flag for this call.
+
+ Returns:
+ :class:`VocalSeparationOutputs` containing mix, vocals,
+ dereverbed vocals, accompaniment and sample rate.
+ """
+ verbose = self.verbose if verbose is None else verbose
+ if verbose:
+ print(f"[vocal extraction] process_file: start: {input_path}")
+ t0 = time.time()
+
+ self.args.input_path = input_path
+
+ mix, vocals, dereverbed, accompaniment, sample_rate = main(
+ self.args,
+ self.sep_model,
+ self.sep_config,
+ self.dereverb_model,
+ self.dereverb_config,
+ torch.device(self.device) if not isinstance(self.device, torch.device) else self.device,
+ )
+
+ if verbose:
+ dt = time.time() - t0
+ print(
+ "[vocal extraction] process_file: done:",
+ f"sr={sample_rate}",
+ f"mix={getattr(mix, 'shape', None)}",
+ f"vocals={getattr(vocals, 'shape', None)}",
+ f"dereverbed={getattr(dereverbed, 'shape', None)}",
+ f"acc={getattr(accompaniment, 'shape', None)}",
+ f"time={dt:.3f}s",
+ )
+
+ return VocalSeparationOutputs(
+ mix=mix,
+ vocals=vocals,
+ vocals_dereverbed=dereverbed,
+ accompaniment=accompaniment,
+ sample_rate=sample_rate,
+ )
+
+
+if __name__ == "__main__":
+
+ m = VocalSeparator(
+ sep_model_path="pretrained_models/mel-band-roformer-karaoke/mel_band_roformer_karaoke_becruily.ckpt",
+ sep_config_path="pretrained_models/mel-band-roformer-karaoke/config_karaoke_becruily.yaml",
+ der_model_path="pretrained_models/dereverb_mel_band_roformer/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt",
+ der_config_path="pretrained_models/dereverb_mel_band_roformer/dereverb_mel_band_roformer_anvuew.yaml",
+ device="cuda"
+ )
+
+ out = m.process("example/test/separation_test.mp3")
+ print(out.vocals_dereverbed.shape)
diff --git a/preprocess/tools/vocal_separation/modules/bs_roformer/__init__.py b/preprocess/tools/vocal_separation/modules/bs_roformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b670717e56526cae3a698103db86a9cb9ab05f9
--- /dev/null
+++ b/preprocess/tools/vocal_separation/modules/bs_roformer/__init__.py
@@ -0,0 +1,2 @@
+from .bs_roformer import BSRoformer
+from .mel_band_roformer import MelBandRoformer
diff --git a/preprocess/tools/vocal_separation/modules/bs_roformer/attend.py b/preprocess/tools/vocal_separation/modules/bs_roformer/attend.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6dc4b3079cff5b3c8c90cea8df2301afd18918b
--- /dev/null
+++ b/preprocess/tools/vocal_separation/modules/bs_roformer/attend.py
@@ -0,0 +1,126 @@
+from functools import wraps
+from packaging import version
+from collections import namedtuple
+
+import os
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+from einops import rearrange, reduce
+
+# constants
+
+FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
+
+# helpers
+
+def exists(val):
+ return val is not None
+
+def default(v, d):
+ return v if exists(v) else d
+
+def once(fn):
+ called = False
+ @wraps(fn)
+ def inner(x):
+ nonlocal called
+ if called:
+ return
+ called = True
+ return fn(x)
+ return inner
+
+print_once = once(print)
+
+# main class
+
+class Attend(nn.Module):
+ def __init__(
+ self,
+ dropout = 0.,
+ flash = False,
+ scale = None
+ ):
+ super().__init__()
+ self.scale = scale
+ self.dropout = dropout
+ self.attn_dropout = nn.Dropout(dropout)
+
+ self.flash = flash
+ assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
+
+ # determine efficient attention configs for cuda and cpu
+
+ self.cpu_config = FlashAttentionConfig(True, True, True)
+ self.cuda_config = None
+
+ if not torch.cuda.is_available() or not flash:
+ return
+
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
+ device_version = version.parse(f'{device_properties.major}.{device_properties.minor}')
+
+ if device_version >= version.parse('8.0'):
+ if os.name == 'nt':
+ print_once('Windows OS detected, using math or mem efficient attention if input tensor is on cuda')
+ self.cuda_config = FlashAttentionConfig(False, True, True)
+ else:
+ print_once('GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda')
+ self.cuda_config = FlashAttentionConfig(True, False, False)
+ else:
+ print_once('GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda')
+ self.cuda_config = FlashAttentionConfig(False, True, True)
+
+ def flash_attn(self, q, k, v):
+ _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
+
+ if exists(self.scale):
+ default_scale = q.shape[-1] ** -0.5
+ q = q * (self.scale / default_scale)
+
+ # Check if there is a compatible device for flash attention
+
+ config = self.cuda_config if is_cuda else self.cpu_config
+
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
+
+ with torch.backends.cuda.sdp_kernel(**config._asdict()):
+ out = F.scaled_dot_product_attention(
+ q, k, v,
+ dropout_p = self.dropout if self.training else 0.
+ )
+
+ return out
+
+ def forward(self, q, k, v):
+ """
+ einstein notation
+ b - batch
+ h - heads
+ n, i, j - sequence length (base sequence length, source, target)
+ d - feature dimension
+ """
+
+ q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
+
+ scale = default(self.scale, q.shape[-1] ** -0.5)
+
+ if self.flash:
+ return self.flash_attn(q, k, v)
+
+ # similarity
+
+ sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
+
+ # attention
+
+ attn = sim.softmax(dim=-1)
+ attn = self.attn_dropout(attn)
+
+ # aggregate values
+
+ out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
+
+ return out
diff --git a/preprocess/tools/vocal_separation/modules/bs_roformer/attend_sage.py b/preprocess/tools/vocal_separation/modules/bs_roformer/attend_sage.py
new file mode 100644
index 0000000000000000000000000000000000000000..f07b932a9f3e9b4b3ad0a2c595e43f45150dbd7b
--- /dev/null
+++ b/preprocess/tools/vocal_separation/modules/bs_roformer/attend_sage.py
@@ -0,0 +1,145 @@
+from functools import wraps
+from packaging import version
+from collections import namedtuple
+
+import os
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+from einops import rearrange, reduce
+
+def _print_once(msg):
+ printed = False
+ @wraps(print)
+ def inner():
+ nonlocal printed
+ if not printed:
+ print(msg)
+ printed = True
+ return inner
+
+try:
+ from sageattention import sageattn
+ _has_sage_attention = True
+ # _print_sage_found = _print_once("SageAttention found. Will be used when flash=True.")
+ # _print_sage_found()
+except ImportError:
+ _has_sage_attention = False
+ _print_sage_not_found = _print_once("SageAttention not found. Will fall back to PyTorch SDPA (if available) or manual einsum.")
+ _print_sage_not_found()
+
+# helpers
+def exists(val):
+ return val is not None
+
+def default(v, d):
+ return v if exists(v) else d
+
+# main class
+class Attend(nn.Module):
+ def __init__(
+ self,
+ dropout = 0.,
+ flash = False, # If True, attempts to use SageAttention or PyTorch SDPA
+ scale = None
+ ):
+ super().__init__()
+ self.scale = scale # Store the scale if needed for einsum path
+ self.dropout = dropout # Store dropout if needed for einsum/SDPA path
+
+ # Determine which attention mechanism to *try* first
+ self.use_sage = flash and _has_sage_attention
+ self.use_pytorch_sdpa = False
+ self._sdpa_checked = False # Flag to check PyTorch version only once
+
+ if flash and not self.use_sage:
+ # Only consider PyTorch SDPA if Sage isn't available/chosen
+ if not self._sdpa_checked:
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ self.use_pytorch_sdpa = True
+ _print_sdpa_used = _print_once("Using PyTorch SDPA backend (FlashAttention-2, Memory-Efficient, or Math).")
+ _print_sdpa_used()
+ else:
+ _print_fallback_einsum = _print_once("Flash attention requested but Pytorch < 2.0 and SageAttention not found. Falling back to einsum.")
+ _print_fallback_einsum()
+ self._sdpa_checked = True
+
+ # Dropout layer for manual einsum implementation ONLY
+ # SDPA and SageAttention handle dropout differently (or not at all in Sage's base API)
+ self.attn_dropout = nn.Dropout(dropout)
+
+ def forward(self, q, k, v):
+ """
+ einstein notation
+ b - batch
+ h - heads
+ n, i, j - sequence length (base sequence length, source, target)
+ d - feature dimension
+
+ Input tensors q, k, v expected in shape: (batch, heads, seq_len, dim_head) -> HND layout
+ """
+ q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
+
+ # --- Priority 1: SageAttention ---
+ if self.use_sage:
+ # Assumes q, k, v are FP16/BF16 (handled by autocast upstream)
+ # Assumes scale is handled internally by sageattn
+ # Assumes dropout is NOT handled by sageattn kernel
+ # is_causal=False based on how Attend is called in mel_band_roformer
+ out = sageattn(q, k, v, tensor_layout='HND', is_causal=False)
+ return out
+ try:
+ return out
+ # print("Attempting SageAttention") # Optional: for debugging
+ out = sageattn(q, k, v, tensor_layout='HND', is_causal=False)
+ return out
+ except Exception as e:
+ print(f"SageAttention failed with error: {e}. Falling back.")
+ self.use_sage = False # Don't try Sage again if it failed once
+ # Decide fallback: Check if PyTorch SDPA is an option
+ if not self._sdpa_checked:
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
+ self.use_pytorch_sdpa = True
+ _print_sdpa_fallback = _print_once("Falling back to PyTorch SDPA.")
+ _print_sdpa_fallback()
+ else:
+ _print_einsum_fallback = _print_once("Falling back to einsum.")
+ _print_einsum_fallback()
+ self._sdpa_checked = True
+
+
+ # --- Priority 2: PyTorch SDPA ---
+ if self.use_pytorch_sdpa:
+ # Use PyTorch's Scaled Dot Product Attention (SDPA)
+ # It handles scaling and dropout internally.
+ try:
+ # print("Attempting PyTorch SDPA") # Optional: for debugging
+ # Let PyTorch choose the best backend (Flash V2, Mem Efficient, Math)
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
+ out = F.scaled_dot_product_attention(
+ q, k, v,
+ attn_mask=None, # Assuming no explicit mask needed here
+ dropout_p = self.dropout if self.training else 0.,
+ is_causal=False # Assuming not needed based on usage context
+ )
+ return out
+ except Exception as e:
+ print(f"PyTorch SDPA failed with error: {e}. Falling back to einsum.")
+ self.use_pytorch_sdpa = False # Fallback to einsum on error
+
+
+ # Calculate scale
+ scale = default(self.scale, q.shape[-1] ** -0.5)
+
+ # similarity
+ sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
+
+ # attention
+ attn = sim.softmax(dim=-1)
+ attn = self.attn_dropout(attn) # Apply dropout ONLY in einsum path
+
+ # aggregate values
+ out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
+
+ return out
diff --git a/preprocess/tools/vocal_separation/modules/bs_roformer/bs_roformer.py b/preprocess/tools/vocal_separation/modules/bs_roformer/bs_roformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b78827a30013a157ca69ba73fda52544ec0255a
--- /dev/null
+++ b/preprocess/tools/vocal_separation/modules/bs_roformer/bs_roformer.py
@@ -0,0 +1,658 @@
+from functools import partial
+
+import torch
+from torch import nn, einsum, Tensor
+from torch.nn import Module, ModuleList
+import torch.nn.functional as F
+
+from .attend import Attend
+try:
+ from .attend_sage import Attend as AttendSage
+except:
+ pass
+from torch.utils.checkpoint import checkpoint
+
+from beartype.typing import Tuple, Optional, List, Callable
+from beartype import beartype
+
+from rotary_embedding_torch import RotaryEmbedding
+
+from einops import rearrange, pack, unpack
+from einops.layers.torch import Rearrange
+
+# helper functions
+
+def exists(val):
+ return val is not None
+
+
+def default(v, d):
+ return v if exists(v) else d
+
+
+def pack_one(t, pattern):
+ return pack([t], pattern)
+
+
+def unpack_one(t, ps, pattern):
+ return unpack(t, ps, pattern)[0]
+
+
+# norm
+
+def l2norm(t):
+ return F.normalize(t, dim = -1, p = 2)
+
+
+class RMSNorm(Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.scale = dim ** 0.5
+ self.gamma = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
+
+
+# attention
+
+class FeedForward(Module):
+ def __init__(
+ self,
+ dim,
+ mult=4,
+ dropout=0.
+ ):
+ super().__init__()
+ dim_inner = int(dim * mult)
+ self.net = nn.Sequential(
+ RMSNorm(dim),
+ nn.Linear(dim, dim_inner),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(dim_inner, dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class Attention(Module):
+ def __init__(
+ self,
+ dim,
+ heads=8,
+ dim_head=64,
+ dropout=0.,
+ rotary_embed=None,
+ flash=True,
+ sage_attention=False,
+ ):
+ super().__init__()
+ self.heads = heads
+ self.scale = dim_head ** -0.5
+ dim_inner = heads * dim_head
+
+ self.rotary_embed = rotary_embed
+
+ if sage_attention:
+ self.attend = AttendSage(flash=flash, dropout=dropout)
+ else:
+ self.attend = Attend(flash=flash, dropout=dropout)
+
+ self.norm = RMSNorm(dim)
+ self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
+
+ self.to_gates = nn.Linear(dim, heads)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(dim_inner, dim, bias=False),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x):
+ x = self.norm(x)
+
+ q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
+
+ if exists(self.rotary_embed):
+ q = self.rotary_embed.rotate_queries_or_keys(q)
+ k = self.rotary_embed.rotate_queries_or_keys(k)
+
+ out = self.attend(q, k, v)
+
+ gates = self.to_gates(x)
+ out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
+
+ out = rearrange(out, 'b h n d -> b n (h d)')
+ return self.to_out(out)
+
+
+class LinearAttention(Module):
+ """
+ this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
+ """
+
+ @beartype
+ def __init__(
+ self,
+ *,
+ dim,
+ dim_head=32,
+ heads=8,
+ scale=8,
+ flash=False,
+ dropout=0.,
+ sage_attention=False,
+ ):
+ super().__init__()
+ dim_inner = dim_head * heads
+ self.norm = RMSNorm(dim)
+
+ self.to_qkv = nn.Sequential(
+ nn.Linear(dim, dim_inner * 3, bias=False),
+ Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
+ )
+
+ self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
+
+ if sage_attention:
+ self.attend = AttendSage(
+ scale=scale,
+ dropout=dropout,
+ flash=flash
+ )
+ else:
+ self.attend = Attend(
+ scale=scale,
+ dropout=dropout,
+ flash=flash
+ )
+
+ self.to_out = nn.Sequential(
+ Rearrange('b h d n -> b n (h d)'),
+ nn.Linear(dim_inner, dim, bias=False)
+ )
+
+ def forward(
+ self,
+ x
+ ):
+ x = self.norm(x)
+
+ q, k, v = self.to_qkv(x)
+
+ q, k = map(l2norm, (q, k))
+ q = q * self.temperature.exp()
+
+ out = self.attend(q, k, v)
+
+ return self.to_out(out)
+
+
+class Transformer(Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ depth,
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.,
+ ff_dropout=0.,
+ ff_mult=4,
+ norm_output=True,
+ rotary_embed=None,
+ flash_attn=True,
+ linear_attn=False,
+ sage_attention=False,
+ ):
+ super().__init__()
+ self.layers = ModuleList([])
+
+ for _ in range(depth):
+ if linear_attn:
+ attn = LinearAttention(
+ dim=dim,
+ dim_head=dim_head,
+ heads=heads,
+ dropout=attn_dropout,
+ flash=flash_attn,
+ sage_attention=sage_attention
+ )
+ else:
+ attn = Attention(
+ dim=dim,
+ dim_head=dim_head,
+ heads=heads,
+ dropout=attn_dropout,
+ rotary_embed=rotary_embed,
+ flash=flash_attn,
+ sage_attention=sage_attention
+ )
+
+ self.layers.append(ModuleList([
+ attn,
+ FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
+ ]))
+
+ self.norm = RMSNorm(dim) if norm_output else nn.Identity()
+
+ def forward(self, x):
+
+ for attn, ff in self.layers:
+ x = attn(x) + x
+ x = ff(x) + x
+
+ return self.norm(x)
+
+
+# bandsplit module
+
+class BandSplit(Module):
+ @beartype
+ def __init__(
+ self,
+ dim,
+ dim_inputs: Tuple[int, ...]
+ ):
+ super().__init__()
+ self.dim_inputs = dim_inputs
+ self.to_features = ModuleList([])
+
+ for dim_in in dim_inputs:
+ net = nn.Sequential(
+ RMSNorm(dim_in),
+ nn.Linear(dim_in, dim)
+ )
+
+ self.to_features.append(net)
+
+ def forward(self, x):
+ x = x.split(self.dim_inputs, dim=-1)
+
+ outs = []
+ for split_input, to_feature in zip(x, self.to_features):
+ split_output = to_feature(split_input)
+ outs.append(split_output)
+
+ return torch.stack(outs, dim=-2)
+
+
+def MLP(
+ dim_in,
+ dim_out,
+ dim_hidden=None,
+ depth=1,
+ activation=nn.Tanh
+):
+ dim_hidden = default(dim_hidden, dim_in)
+
+ net = []
+ dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
+
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
+ is_last = ind == (len(dims) - 2)
+
+ net.append(nn.Linear(layer_dim_in, layer_dim_out))
+
+ if is_last:
+ continue
+
+ net.append(activation())
+
+ return nn.Sequential(*net)
+
+
+class MaskEstimator(Module):
+ @beartype
+ def __init__(
+ self,
+ dim,
+ dim_inputs: Tuple[int, ...],
+ depth,
+ mlp_expansion_factor=4
+ ):
+ super().__init__()
+ self.dim_inputs = dim_inputs
+ self.to_freqs = ModuleList([])
+ dim_hidden = dim * mlp_expansion_factor
+
+ for dim_in in dim_inputs:
+ net = []
+
+ mlp = nn.Sequential(
+ MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
+ nn.GLU(dim=-1)
+ )
+
+ self.to_freqs.append(mlp)
+
+ def forward(self, x):
+ x = x.unbind(dim=-2)
+
+ outs = []
+
+ for band_features, mlp in zip(x, self.to_freqs):
+ freq_out = mlp(band_features)
+ outs.append(freq_out)
+
+ return torch.cat(outs, dim=-1)
+
+
+# main class
+
+DEFAULT_FREQS_PER_BANDS = (
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2,
+ 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
+ 12, 12, 12, 12, 12, 12, 12, 12,
+ 24, 24, 24, 24, 24, 24, 24, 24,
+ 48, 48, 48, 48, 48, 48, 48, 48,
+ 128, 129,
+)
+
+
+class BSRoformer(Module):
+
+ @beartype
+ def __init__(
+ self,
+ dim,
+ *,
+ depth,
+ stereo=False,
+ num_stems=1,
+ time_transformer_depth=2,
+ freq_transformer_depth=2,
+ linear_transformer_depth=0,
+ freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
+ # in the paper, they divide into ~60 bands, test with 1 for starters
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.,
+ ff_dropout=0.,
+ flash_attn=True,
+ dim_freqs_in=1025,
+ stft_n_fft=2048,
+ stft_hop_length=512,
+ # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
+ stft_win_length=2048,
+ stft_normalized=False,
+ stft_window_fn: Optional[Callable] = None,
+ mask_estimator_depth=2,
+ multi_stft_resolution_loss_weight=1.,
+ multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
+ multi_stft_hop_size=147,
+ multi_stft_normalized=False,
+ multi_stft_window_fn: Callable = torch.hann_window,
+ mlp_expansion_factor=4,
+ use_torch_checkpoint=False,
+ skip_connection=False,
+ sage_attention=False,
+ ):
+ super().__init__()
+
+ self.stereo = stereo
+ self.audio_channels = 2 if stereo else 1
+ self.num_stems = num_stems
+ self.use_torch_checkpoint = use_torch_checkpoint
+ self.skip_connection = skip_connection
+
+ self.layers = ModuleList([])
+
+ if sage_attention:
+ print("Use Sage Attention")
+
+ transformer_kwargs = dict(
+ dim=dim,
+ heads=heads,
+ dim_head=dim_head,
+ attn_dropout=attn_dropout,
+ ff_dropout=ff_dropout,
+ flash_attn=flash_attn,
+ norm_output=False,
+ sage_attention=sage_attention,
+ )
+
+ time_rotary_embed = RotaryEmbedding(dim=dim_head)
+ freq_rotary_embed = RotaryEmbedding(dim=dim_head)
+
+ for _ in range(depth):
+ tran_modules = []
+ if linear_transformer_depth > 0:
+ tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
+ tran_modules.append(
+ Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
+ )
+ tran_modules.append(
+ Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
+ )
+ self.layers.append(nn.ModuleList(tran_modules))
+
+ self.final_norm = RMSNorm(dim)
+
+ self.stft_kwargs = dict(
+ n_fft=stft_n_fft,
+ hop_length=stft_hop_length,
+ win_length=stft_win_length,
+ normalized=stft_normalized
+ )
+
+ self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
+
+ freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True).shape[1]
+
+ assert len(freqs_per_bands) > 1
+ assert sum(
+ freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'
+
+ freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)
+
+ self.band_split = BandSplit(
+ dim=dim,
+ dim_inputs=freqs_per_bands_with_complex
+ )
+
+ self.mask_estimators = nn.ModuleList([])
+
+ for _ in range(num_stems):
+ mask_estimator = MaskEstimator(
+ dim=dim,
+ dim_inputs=freqs_per_bands_with_complex,
+ depth=mask_estimator_depth,
+ mlp_expansion_factor=mlp_expansion_factor,
+ )
+
+ self.mask_estimators.append(mask_estimator)
+
+ # for the multi-resolution stft loss
+
+ self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
+ self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
+ self.multi_stft_n_fft = stft_n_fft
+ self.multi_stft_window_fn = multi_stft_window_fn
+
+ self.multi_stft_kwargs = dict(
+ hop_length=multi_stft_hop_size,
+ normalized=multi_stft_normalized
+ )
+
+ def forward(
+ self,
+ raw_audio,
+ target=None,
+ return_loss_breakdown=False
+ ):
+ """
+ einops
+
+ b - batch
+ f - freq
+ t - time
+ s - audio channel (1 for mono, 2 for stereo)
+ n - number of 'stems'
+ c - complex (2)
+ d - feature dimension
+ """
+
+ device = raw_audio.device
+
+ # defining whether model is loaded on MPS (MacOS GPU accelerator)
+ x_is_mps = True if device.type == "mps" else False
+
+ if raw_audio.ndim == 2:
+ raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
+
+ channels = raw_audio.shape[1]
+ assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
+
+ # to stft
+
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
+
+ stft_window = self.stft_window_fn(device=device)
+
+ # RuntimeError: FFT operations are only supported on MacOS 14+
+ # Since it's tedious to define whether we're on correct MacOS version - simple try-catch is used
+ try:
+ stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
+ except:
+ stft_repr = torch.stft(raw_audio.cpu() if x_is_mps else raw_audio, **self.stft_kwargs,
+ window=stft_window.cpu() if x_is_mps else stft_window, return_complex=True).to(
+ device)
+ stft_repr = torch.view_as_real(stft_repr)
+
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
+
+ # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
+ stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c')
+
+ x = rearrange(stft_repr, 'b f t c -> b t (f c)')
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(self.band_split, x, use_reentrant=False)
+ else:
+ x = self.band_split(x)
+
+ # axial / hierarchical attention
+
+ store = [None] * len(self.layers)
+ for i, transformer_block in enumerate(self.layers):
+
+ if len(transformer_block) == 3:
+ linear_transformer, time_transformer, freq_transformer = transformer_block
+
+ x, ft_ps = pack([x], 'b * d')
+ if self.use_torch_checkpoint:
+ x = checkpoint(linear_transformer, x, use_reentrant=False)
+ else:
+ x = linear_transformer(x)
+ x, = unpack(x, ft_ps, 'b * d')
+ else:
+ time_transformer, freq_transformer = transformer_block
+
+ if self.skip_connection:
+ # Sum all previous
+ for j in range(i):
+ x = x + store[j]
+
+ x = rearrange(x, 'b t f d -> b f t d')
+ x, ps = pack([x], '* t d')
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(time_transformer, x, use_reentrant=False)
+ else:
+ x = time_transformer(x)
+
+ x, = unpack(x, ps, '* t d')
+ x = rearrange(x, 'b f t d -> b t f d')
+ x, ps = pack([x], '* f d')
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(freq_transformer, x, use_reentrant=False)
+ else:
+ x = freq_transformer(x)
+
+ x, = unpack(x, ps, '* f d')
+
+ if self.skip_connection:
+ store[i] = x
+
+ x = self.final_norm(x)
+
+ num_stems = len(self.mask_estimators)
+
+ if self.use_torch_checkpoint:
+ mask = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
+ else:
+ mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
+ mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
+
+ # modulate frequency representation
+
+ stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
+
+ # complex number multiplication
+
+ stft_repr = torch.view_as_complex(stft_repr)
+ mask = torch.view_as_complex(mask)
+
+ stft_repr = stft_repr * mask
+
+ # istft
+
+ stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
+
+ # same as torch.stft() fix for MacOS MPS above
+ try:
+ recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1])
+ except:
+ recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False, length=raw_audio.shape[-1]).to(device)
+
+ recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
+
+ if num_stems == 1:
+ recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
+
+ # if a target is passed in, calculate loss for learning
+
+ if not exists(target):
+ return recon_audio
+
+ if self.num_stems > 1:
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
+
+ if target.ndim == 2:
+ target = rearrange(target, '... t -> ... 1 t')
+
+ target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
+
+ loss = F.l1_loss(recon_audio, target)
+
+ multi_stft_resolution_loss = 0.
+
+ for window_size in self.multi_stft_resolutions_window_sizes:
+ res_stft_kwargs = dict(
+ n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
+ win_length=window_size,
+ return_complex=True,
+ window=self.multi_stft_window_fn(window_size, device=device),
+ **self.multi_stft_kwargs,
+ )
+
+ recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
+ target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
+
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
+
+ weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
+
+ total_loss = loss + weighted_multi_resolution_loss
+
+ if not return_loss_breakdown:
+ return total_loss
+
+ return total_loss, (loss, multi_stft_resolution_loss)
\ No newline at end of file
diff --git a/preprocess/tools/vocal_separation/modules/bs_roformer/mel_band_roformer.py b/preprocess/tools/vocal_separation/modules/bs_roformer/mel_band_roformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..44565244bffd7062949f79b1f27fd3f530c7e99e
--- /dev/null
+++ b/preprocess/tools/vocal_separation/modules/bs_roformer/mel_band_roformer.py
@@ -0,0 +1,703 @@
+from functools import partial
+
+import torch
+from torch import nn, einsum, Tensor
+from torch.nn import Module, ModuleList
+import torch.nn.functional as F
+
+from .attend import Attend
+try:
+ from .attend_sage import Attend as AttendSage
+except:
+ pass
+from torch.utils.checkpoint import checkpoint
+
+from beartype.typing import Tuple, Optional, List, Callable
+from beartype import beartype
+
+from rotary_embedding_torch import RotaryEmbedding
+
+from einops import rearrange, pack, unpack, reduce, repeat
+from einops.layers.torch import Rearrange
+
+from librosa import filters
+
+
+# helper functions
+
+def exists(val):
+ return val is not None
+
+
+def default(v, d):
+ return v if exists(v) else d
+
+
+def pack_one(t, pattern):
+ return pack([t], pattern)
+
+
+def unpack_one(t, ps, pattern):
+ return unpack(t, ps, pattern)[0]
+
+
+def pad_at_dim(t, pad, dim=-1, value=0.):
+ dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+ zeros = ((0, 0) * dims_from_right)
+ return F.pad(t, (*zeros, *pad), value=value)
+
+
+def l2norm(t):
+ return F.normalize(t, dim=-1, p=2)
+
+
+# norm
+
+class RMSNorm(Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.scale = dim ** 0.5
+ self.gamma = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
+
+
+# attention
+
+class FeedForward(Module):
+ def __init__(
+ self,
+ dim,
+ mult=4,
+ dropout=0.
+ ):
+ super().__init__()
+ dim_inner = int(dim * mult)
+ self.net = nn.Sequential(
+ RMSNorm(dim),
+ nn.Linear(dim, dim_inner),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(dim_inner, dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class Attention(Module):
+ def __init__(
+ self,
+ dim,
+ heads=8,
+ dim_head=64,
+ dropout=0.,
+ rotary_embed=None,
+ flash=True,
+ sage_attention=False,
+ ):
+ super().__init__()
+ self.heads = heads
+ self.scale = dim_head ** -0.5
+ dim_inner = heads * dim_head
+
+ self.rotary_embed = rotary_embed
+
+ if sage_attention:
+ self.attend = AttendSage(flash=flash, dropout=dropout)
+ else:
+ self.attend = Attend(flash=flash, dropout=dropout)
+ self.norm = RMSNorm(dim)
+ self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
+
+ self.to_gates = nn.Linear(dim, heads)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(dim_inner, dim, bias=False),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x):
+ x = self.norm(x)
+
+ q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
+
+ if exists(self.rotary_embed):
+ q = self.rotary_embed.rotate_queries_or_keys(q)
+ k = self.rotary_embed.rotate_queries_or_keys(k)
+
+ out = self.attend(q, k, v)
+
+ gates = self.to_gates(x)
+ out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
+
+ out = rearrange(out, 'b h n d -> b n (h d)')
+ return self.to_out(out)
+
+
+class LinearAttention(Module):
+ """
+ this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
+ """
+
+ @beartype
+ def __init__(
+ self,
+ *,
+ dim,
+ dim_head=32,
+ heads=8,
+ scale=8,
+ flash=False,
+ dropout=0.,
+ sage_attention=False
+ ):
+ super().__init__()
+ dim_inner = dim_head * heads
+ self.norm = RMSNorm(dim)
+
+ self.to_qkv = nn.Sequential(
+ nn.Linear(dim, dim_inner * 3, bias=False),
+ Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
+ )
+
+ self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
+
+ if sage_attention:
+ self.attend = AttendSage(
+ scale=scale,
+ dropout=dropout,
+ flash=flash
+ )
+ else:
+ self.attend = Attend(
+ scale=scale,
+ dropout=dropout,
+ flash=flash
+ )
+
+ self.to_out = nn.Sequential(
+ Rearrange('b h d n -> b n (h d)'),
+ nn.Linear(dim_inner, dim, bias=False)
+ )
+
+ def forward(
+ self,
+ x
+ ):
+ x = self.norm(x)
+
+ q, k, v = self.to_qkv(x)
+
+ q, k = map(l2norm, (q, k))
+ q = q * self.temperature.exp()
+
+ out = self.attend(q, k, v)
+
+ return self.to_out(out)
+
+
+class Transformer(Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ depth,
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.,
+ ff_dropout=0.,
+ ff_mult=4,
+ norm_output=True,
+ rotary_embed=None,
+ flash_attn=True,
+ linear_attn=False,
+ sage_attention=False,
+ ):
+ super().__init__()
+ self.layers = ModuleList([])
+
+ for _ in range(depth):
+ if linear_attn:
+ attn = LinearAttention(
+ dim=dim,
+ dim_head=dim_head,
+ heads=heads,
+ dropout=attn_dropout,
+ flash=flash_attn,
+ sage_attention=sage_attention
+ )
+ else:
+ attn = Attention(
+ dim=dim,
+ dim_head=dim_head,
+ heads=heads,
+ dropout=attn_dropout,
+ rotary_embed=rotary_embed,
+ flash=flash_attn,
+ sage_attention=sage_attention
+ )
+
+ self.layers.append(ModuleList([
+ attn,
+ FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
+ ]))
+
+ self.norm = RMSNorm(dim) if norm_output else nn.Identity()
+
+ def forward(self, x):
+
+ for attn, ff in self.layers:
+ x = attn(x) + x
+ x = ff(x) + x
+
+ return self.norm(x)
+
+
+# bandsplit module
+
+class BandSplit(Module):
+ @beartype
+ def __init__(
+ self,
+ dim,
+ dim_inputs: Tuple[int, ...]
+ ):
+ super().__init__()
+ self.dim_inputs = dim_inputs
+ self.to_features = ModuleList([])
+
+ for dim_in in dim_inputs:
+ net = nn.Sequential(
+ RMSNorm(dim_in),
+ nn.Linear(dim_in, dim)
+ )
+
+ self.to_features.append(net)
+
+ def forward(self, x):
+ x = x.split(self.dim_inputs, dim=-1)
+
+ outs = []
+ for split_input, to_feature in zip(x, self.to_features):
+ split_output = to_feature(split_input)
+ outs.append(split_output)
+
+ return torch.stack(outs, dim=-2)
+
+
+def MLP(
+ dim_in,
+ dim_out,
+ dim_hidden=None,
+ depth=1,
+ activation=nn.Tanh
+):
+ dim_hidden = default(dim_hidden, dim_in)
+
+ net = []
+ dims = (dim_in, *((dim_hidden,) * depth), dim_out)
+
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
+ is_last = ind == (len(dims) - 2)
+
+ net.append(nn.Linear(layer_dim_in, layer_dim_out))
+
+ if is_last:
+ continue
+
+ net.append(activation())
+
+ return nn.Sequential(*net)
+
+
+class MaskEstimator(Module):
+ @beartype
+ def __init__(
+ self,
+ dim,
+ dim_inputs: Tuple[int, ...],
+ depth,
+ mlp_expansion_factor=4
+ ):
+ super().__init__()
+ self.dim_inputs = dim_inputs
+ self.to_freqs = ModuleList([])
+ dim_hidden = dim * mlp_expansion_factor
+
+ for dim_in in dim_inputs:
+ net = []
+
+ mlp = nn.Sequential(
+ MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
+ nn.GLU(dim=-1)
+ )
+
+ self.to_freqs.append(mlp)
+
+ def forward(self, x):
+ x = x.unbind(dim=-2)
+
+ outs = []
+
+ for band_features, mlp in zip(x, self.to_freqs):
+ freq_out = mlp(band_features)
+ outs.append(freq_out)
+
+ return torch.cat(outs, dim=-1)
+
+
+# main class
+
+class MelBandRoformer(Module):
+
+ @beartype
+ def __init__(
+ self,
+ dim,
+ *,
+ depth,
+ stereo=False,
+ num_stems=1,
+ time_transformer_depth=2,
+ freq_transformer_depth=2,
+ linear_transformer_depth=0,
+ num_bands=60,
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.1,
+ ff_dropout=0.1,
+ flash_attn=True,
+ dim_freqs_in=1025,
+ sample_rate=44100, # needed for mel filter bank from librosa
+ stft_n_fft=2048,
+ stft_hop_length=512,
+ # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
+ stft_win_length=2048,
+ stft_normalized=False,
+ stft_window_fn: Optional[Callable] = None,
+ mask_estimator_depth=1,
+ multi_stft_resolution_loss_weight=1.,
+ multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
+ multi_stft_hop_size=147,
+ multi_stft_normalized=False,
+ multi_stft_window_fn: Callable = torch.hann_window,
+ match_input_audio_length=False, # if True, pad output tensor to match length of input tensor
+ mlp_expansion_factor=4,
+ use_torch_checkpoint=False,
+ skip_connection=False,
+ sage_attention=False,
+ ):
+ super().__init__()
+
+ self.stereo = stereo
+ self.audio_channels = 2 if stereo else 1
+ self.num_stems = num_stems
+ self.use_torch_checkpoint = use_torch_checkpoint
+ self.skip_connection = skip_connection
+
+ self.layers = ModuleList([])
+
+ if sage_attention:
+ print("Use Sage Attention")
+
+ transformer_kwargs = dict(
+ dim=dim,
+ heads=heads,
+ dim_head=dim_head,
+ attn_dropout=attn_dropout,
+ ff_dropout=ff_dropout,
+ flash_attn=flash_attn,
+ sage_attention=sage_attention,
+ )
+
+ time_rotary_embed = RotaryEmbedding(dim=dim_head)
+ freq_rotary_embed = RotaryEmbedding(dim=dim_head)
+
+ for _ in range(depth):
+ tran_modules = []
+ if linear_transformer_depth > 0:
+ tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
+ tran_modules.append(
+ Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
+ )
+ tran_modules.append(
+ Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
+ )
+ self.layers.append(nn.ModuleList(tran_modules))
+
+ self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
+
+ self.stft_kwargs = dict(
+ n_fft=stft_n_fft,
+ hop_length=stft_hop_length,
+ win_length=stft_win_length,
+ normalized=stft_normalized
+ )
+
+ freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True).shape[1]
+
+ # create mel filter bank
+ # with librosa.filters.mel as in section 2 of paper
+
+ mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)
+
+ mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
+
+ # for some reason, it doesn't include the first freq? just force a value for now
+
+ mel_filter_bank[0][0] = 1.
+
+ # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
+ # so let's force a positive value
+
+ mel_filter_bank[-1, -1] = 1.
+
+ # binary as in paper (then estimated masks are averaged for overlapping regions)
+
+ freqs_per_band = mel_filter_bank > 0
+ assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'
+
+ repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
+ freq_indices = repeated_freq_indices[freqs_per_band]
+
+ if stereo:
+ freq_indices = repeat(freq_indices, 'f -> f s', s=2)
+ freq_indices = freq_indices * 2 + torch.arange(2)
+ freq_indices = rearrange(freq_indices, 'f s -> (f s)')
+
+ self.register_buffer('freq_indices', freq_indices, persistent=False)
+ self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)
+
+ num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
+ num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')
+
+ self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
+ self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)
+
+ # band split and mask estimator
+
+ freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())
+
+ self.band_split = BandSplit(
+ dim=dim,
+ dim_inputs=freqs_per_bands_with_complex
+ )
+
+ self.mask_estimators = nn.ModuleList([])
+
+ for _ in range(num_stems):
+ mask_estimator = MaskEstimator(
+ dim=dim,
+ dim_inputs=freqs_per_bands_with_complex,
+ depth=mask_estimator_depth,
+ mlp_expansion_factor=mlp_expansion_factor,
+ )
+
+ self.mask_estimators.append(mask_estimator)
+
+ # for the multi-resolution stft loss
+
+ self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
+ self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
+ self.multi_stft_n_fft = stft_n_fft
+ self.multi_stft_window_fn = multi_stft_window_fn
+
+ self.multi_stft_kwargs = dict(
+ hop_length=multi_stft_hop_size,
+ normalized=multi_stft_normalized
+ )
+
+ self.match_input_audio_length = match_input_audio_length
+
+ def forward(
+ self,
+ raw_audio,
+ target=None,
+ return_loss_breakdown=False
+ ):
+ """
+ einops
+
+ b - batch
+ f - freq
+ t - time
+ s - audio channel (1 for mono, 2 for stereo)
+ n - number of 'stems'
+ c - complex (2)
+ d - feature dimension
+ """
+
+ device = raw_audio.device
+
+ if raw_audio.ndim == 2:
+ raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
+
+ batch, channels, raw_audio_length = raw_audio.shape
+
+ istft_length = raw_audio_length if self.match_input_audio_length else None
+
+ assert (not self.stereo and channels == 1) or (
+ self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
+
+ # to stft
+
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
+
+ stft_window = self.stft_window_fn(device=device)
+
+ stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
+ stft_repr = torch.view_as_real(stft_repr)
+
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
+
+ # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
+ stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c')
+
+ # index out all frequencies for all frequency ranges across bands ascending in one go
+
+ batch_arange = torch.arange(batch, device=device)[..., None]
+
+ # account for stereo
+
+ x = stft_repr[batch_arange, self.freq_indices]
+
+ # fold the complex (real and imag) into the frequencies dimension
+
+ x = rearrange(x, 'b f t c -> b t (f c)')
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(self.band_split, x, use_reentrant=False)
+ else:
+ x = self.band_split(x)
+
+ # axial / hierarchical attention
+
+ store = [None] * len(self.layers)
+ for i, transformer_block in enumerate(self.layers):
+
+ if len(transformer_block) == 3:
+ linear_transformer, time_transformer, freq_transformer = transformer_block
+
+ x, ft_ps = pack([x], 'b * d')
+ if self.use_torch_checkpoint:
+ x = checkpoint(linear_transformer, x, use_reentrant=False)
+ else:
+ x = linear_transformer(x)
+ x, = unpack(x, ft_ps, 'b * d')
+ else:
+ time_transformer, freq_transformer = transformer_block
+
+ if self.skip_connection:
+ # Sum all previous
+ for j in range(i):
+ x = x + store[j]
+
+ x = rearrange(x, 'b t f d -> b f t d')
+ x, ps = pack([x], '* t d')
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(time_transformer, x, use_reentrant=False)
+ else:
+ x = time_transformer(x)
+
+ x, = unpack(x, ps, '* t d')
+ x = rearrange(x, 'b f t d -> b t f d')
+ x, ps = pack([x], '* f d')
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(freq_transformer, x, use_reentrant=False)
+ else:
+ x = freq_transformer(x)
+
+ x, = unpack(x, ps, '* f d')
+
+ if self.skip_connection:
+ store[i] = x
+
+ num_stems = len(self.mask_estimators)
+ if self.use_torch_checkpoint:
+ masks = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
+ else:
+ masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
+ masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
+
+ # modulate frequency representation
+
+ stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
+
+ # complex number multiplication
+
+ stft_repr = torch.view_as_complex(stft_repr)
+ masks = torch.view_as_complex(masks)
+
+ masks = masks.type(stft_repr.dtype)
+
+ # need to average the estimated mask for the overlapped frequencies
+
+ scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
+
+ stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
+ masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
+
+ denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
+
+ masks_averaged = masks_summed / denom.clamp(min=1e-8)
+
+ # modulate stft repr with estimated mask
+
+ stft_repr = stft_repr * masks_averaged
+
+ # istft
+
+ stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
+
+ recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
+ length=istft_length)
+
+ recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
+
+ if num_stems == 1:
+ recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
+
+ # if a target is passed in, calculate loss for learning
+
+ if not exists(target):
+ return recon_audio
+
+ if self.num_stems > 1:
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
+
+ if target.ndim == 2:
+ target = rearrange(target, '... t -> ... 1 t')
+
+ target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
+
+ loss = F.l1_loss(recon_audio, target)
+
+ multi_stft_resolution_loss = 0.
+
+ for window_size in self.multi_stft_resolutions_window_sizes:
+ res_stft_kwargs = dict(
+ n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
+ win_length=window_size,
+ return_complex=True,
+ window=self.multi_stft_window_fn(window_size, device=device),
+ **self.multi_stft_kwargs,
+ )
+
+ recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
+ target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
+
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
+
+ weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
+
+ total_loss = loss + weighted_multi_resolution_loss
+
+ if not return_loss_breakdown:
+ return total_loss
+
+ return total_loss, (loss, multi_stft_resolution_loss)
\ No newline at end of file
diff --git a/preprocess/tools/vocal_separation/utils/audio_utils.py b/preprocess/tools/vocal_separation/utils/audio_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d81280b12e0d566e0ee51431d2d66752b56f65bf
--- /dev/null
+++ b/preprocess/tools/vocal_separation/utils/audio_utils.py
@@ -0,0 +1,129 @@
+
+import numpy as np
+import os
+import soundfile as sf
+import matplotlib.pyplot as plt
+from typing import Dict, Tuple, Optional
+
+import torch.distributed as dist
+
+
+def read_audio_transposed(path: str, instr: Optional[str] = None, skip_err: bool = False) -> Tuple[Optional[np.ndarray], Optional[int]]:
+ """
+ Read an audio file and return transposed waveform data with channels first.
+
+ Loads the audio file from `path`, converts mono signals to 2D format, and
+ transposes the array so that its shape is (channels, length). In case of
+ errors, either raises an exception or skips gracefully depending on
+ `skip_err`.
+
+ Args:
+ path (str): Path to the audio file to load.
+ instr (Optional[str], optional): Instrument name, used for informative
+ messages when `skip_err` is True. Defaults to None.
+ skip_err (bool, optional): If True, skip files with read errors and
+ return `(None, None)` instead of raising. Defaults to False.
+
+ Returns:
+ Tuple[Optional[np.ndarray], Optional[int]]: A tuple containing:
+ - NumPy array of shape (channels, length), or None if skipped.
+ - Sampling rate as an integer, or None if skipped.
+ """
+
+ should_print = not dist.is_initialized() or dist.get_rank() == 0
+
+ try:
+ mix, sr = sf.read(path)
+ except Exception as e:
+ if skip_err:
+ if should_print:
+ print(f"No stem {instr}: skip!")
+ return None, None
+ else:
+ raise RuntimeError(f"Error reading the file at {path}: {e}")
+ else:
+ if len(mix.shape) == 1: # For mono audio
+ mix = np.expand_dims(mix, axis=-1)
+ return mix.T, sr
+
+
+def normalize_audio(audio: np.ndarray) -> Tuple[np.ndarray, Dict[str, float]]:
+ """
+ Normalize an audio signal using mean and standard deviation.
+
+ Computes the mean and standard deviation from the mono mix of the input
+ signal, then applies normalization to each channel.
+
+ Args:
+ audio (np.ndarray): Input audio array of shape (channels, time) or (time,).
+
+ Returns:
+ Tuple[np.ndarray, Dict[str, float]]: A tuple containing:
+ - Normalized audio with the same shape as the input.
+ - A dictionary with keys "mean" and "std" from the original audio.
+ """
+
+ mono = audio.mean(0)
+ mean, std = mono.mean(), mono.std()
+ return (audio - mean) / std, {"mean": mean, "std": std}
+
+
+def denormalize_audio(audio: np.ndarray, norm_params: Dict[str, float]) -> np.ndarray:
+ """
+ Reverse normalization on an audio signal.
+
+ Applies the stored mean and standard deviation to restore the original
+ scale of a previously normalized signal.
+
+ Args:
+ audio (np.ndarray): Normalized audio array to be denormalized.
+ norm_params (Dict[str, float]): Dictionary containing the keys
+ "mean" and "std" used during normalization.
+
+ Returns:
+ np.ndarray: Denormalized audio with the same shape as the input.
+ """
+
+ return audio * norm_params["std"] + norm_params["mean"]
+
+
+def draw_spectrogram(waveform: np.ndarray, sample_rate: int, length: float, output_file: str) -> None:
+ """
+ Generate and save a spectrogram image from an audio waveform.
+
+ Converts the provided waveform into a mono signal, computes its Short-Time
+ Fourier Transform (STFT), converts the amplitude spectrogram to dB scale,
+ and plots it using a plasma colormap.
+
+ Args:
+ waveform (np.ndarray): Input audio waveform array of shape (time, channels)
+ or (time,).
+ sample_rate (int): Sampling rate of the waveform in Hz.
+ length (float): Duration (in seconds) of the waveform to include in the
+ spectrogram.
+ output_file (str): Path to save the resulting spectrogram image.
+
+ Returns:
+ None
+ """
+
+ import librosa.display
+
+ # Cut only required part of spectorgram
+ x = waveform[:int(length * sample_rate), :]
+ X = librosa.stft(x.mean(axis=-1)) # perform short-term fourier transform on mono signal
+ Xdb = librosa.amplitude_to_db(np.abs(X), ref=np.max) # convert an amplitude spectrogram to dB-scaled spectrogram.
+ fig, ax = plt.subplots()
+ # plt.figure(figsize=(30, 10)) # initialize the fig size
+ img = librosa.display.specshow(
+ Xdb,
+ cmap='plasma',
+ sr=sample_rate,
+ x_axis='time',
+ y_axis='linear',
+ ax=ax
+ )
+ ax.set(title='File: ' + os.path.basename(output_file))
+ fig.colorbar(img, ax=ax, format="%+2.f dB")
+ if output_file is not None:
+ plt.savefig(output_file)
\ No newline at end of file
diff --git a/preprocess/tools/vocal_separation/utils/metrics.py b/preprocess/tools/vocal_separation/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ba25918c11e86d8fbb29cdd178dd331e87f0efe
--- /dev/null
+++ b/preprocess/tools/vocal_separation/utils/metrics.py
@@ -0,0 +1,421 @@
+import numpy as np
+import torch
+import librosa
+import torch.nn.functional as F
+from typing import Dict, List, Tuple
+
+def sdr(references: np.ndarray, estimates: np.ndarray) -> float:
+ """
+ Compute Signal-to-Distortion Ratio (SDR) for one or more audio tracks.
+
+ SDR is a measure of how well the predicted source (estimate) matches the reference source.
+ It is calculated as the ratio of the energy of the reference signal to the energy of the error (difference between reference and estimate).
+ Return SDR in decibels (dB)
+ Parameters:
+ ----------
+ references : np.ndarray
+ A 3D numpy array of shape (num_sources, num_channels, num_samples), where num_sources is the number of sources,
+ num_channels is the number of channels (e.g., 1 for mono, 2 for stereo), and num_samples is the length of the audio signal.
+
+ estimates : np.ndarray
+ A 3D numpy array of shape (num_sources, num_channels, num_samples) representing the estimated sources.
+
+ Returns:
+ -------
+ np.ndarray
+ A 1D numpy array containing the SDR values for each source.
+ """
+ eps = 1e-8 # to avoid numerical errors
+ num = np.sum(np.square(references), axis=(1, 2))
+ den = np.sum(np.square(references - estimates), axis=(1, 2))
+ num += eps
+ den += eps
+ return 10 * np.log10(num / den)
+
+
+def si_sdr(reference: np.ndarray, estimate: np.ndarray) -> float:
+ """
+ Compute Scale-Invariant Signal-to-Distortion Ratio (SI-SDR) for one or more audio tracks.
+
+ SI-SDR is a variant of the SDR metric that is invariant to the scaling of the estimate relative to the reference.
+ It is calculated by scaling the estimate to match the reference signal and then computing the SDR.
+
+ Parameters:
+ ----------
+ reference : np.ndarray
+ A 3D numpy array of shape (num_sources, num_channels, num_samples), where num_sources is the number of sources,
+ num_channels is the number of channels (e.g., 1 for mono, 2 for stereo), and num_samples is the length of the audio signal.
+
+ estimate : np.ndarray
+ A 3D numpy array of shape (num_sources, num_channels, num_samples) representing the estimated sources.
+
+ Returns:
+ -------
+ float
+ The SI-SDR value for the source. It is a scalar representing the Signal-to-Distortion Ratio in decibels (dB).
+ """
+ eps = 1e-8 # To avoid numerical errors
+ scale = np.sum(estimate * reference + eps, axis=(0, 1)) / np.sum(reference ** 2 + eps, axis=(0, 1))
+ scale = np.expand_dims(scale, axis=(0, 1)) # Reshape to [num_sources, 1]
+
+ reference = reference * scale
+ si_sdr = np.mean(10 * np.log10(
+ np.sum(reference ** 2, axis=(0, 1)) / (np.sum((reference - estimate) ** 2, axis=(0, 1)) + eps) + eps))
+
+ return si_sdr
+
+
+def L1Freq_metric(
+ reference: np.ndarray,
+ estimate: np.ndarray,
+ fft_size: int = 2048,
+ hop_size: int = 1024,
+ device: str = 'cpu'
+) -> float:
+ """
+ Compute the L1 Frequency Metric between the reference and estimated audio signals.
+
+ This metric compares the magnitude spectrograms of the reference and estimated audio signals
+ using the Short-Time Fourier Transform (STFT) and calculates the L1 loss between them. The result
+ is scaled to the range [0, 100] where a higher value indicates better performance.
+
+ Parameters:
+ ----------
+ reference : np.ndarray
+ A 2D numpy array of shape (num_channels, num_samples) representing the reference (ground truth) audio signal.
+
+ estimate : np.ndarray
+ A 2D numpy array of shape (num_channels, num_samples) representing the estimated (predicted) audio signal.
+
+ fft_size : int, optional
+ The size of the FFT (Short-Time Fourier Transform). Default is 2048.
+
+ hop_size : int, optional
+ The hop size between STFT frames. Default is 1024.
+
+ device : str, optional
+ The device to run the computation on ('cpu' or 'cuda'). Default is 'cpu'.
+
+ Returns:
+ -------
+ float
+ The L1 Frequency Metric in the range [0, 100], where higher values indicate better performance.
+ """
+
+ reference = torch.from_numpy(reference).to(device)
+ estimate = torch.from_numpy(estimate).to(device)
+
+ reference_stft = torch.stft(reference, fft_size, hop_size, return_complex=True)
+ estimated_stft = torch.stft(estimate, fft_size, hop_size, return_complex=True)
+
+ reference_mag = torch.abs(reference_stft)
+ estimate_mag = torch.abs(estimated_stft)
+
+ loss = 10 * F.l1_loss(estimate_mag, reference_mag)
+
+ ret = 100 / (1. + float(loss.cpu().numpy()))
+
+ return ret
+
+
+def LogWMSE_metric(
+ reference: np.ndarray,
+ estimate: np.ndarray,
+ mixture: np.ndarray,
+ device: str = 'cpu',
+) -> float:
+ """
+ Calculate the Log-WMSE (Logarithmic Weighted Mean Squared Error) between the reference, estimate, and mixture signals.
+
+ This metric evaluates the quality of the estimated signal compared to the reference signal in the
+ context of audio source separation. The result is given in logarithmic scale, which helps in evaluating
+ signals with large amplitude differences.
+
+ Parameters:
+ ----------
+ reference : np.ndarray
+ The ground truth audio signal of shape (channels, time), where channels is the number of audio channels
+ (e.g., 1 for mono, 2 for stereo) and time is the length of the audio in samples.
+
+ estimate : np.ndarray
+ The estimated audio signal of shape (channels, time).
+
+ mixture : np.ndarray
+ The mixed audio signal of shape (channels, time).
+
+ device : str, optional
+ The device to run the computation on, either 'cpu' or 'cuda'. Default is 'cpu'.
+
+ Returns:
+ -------
+ float
+ The Log-WMSE value, which quantifies the difference between the reference and estimated signal on a logarithmic scale.
+ """
+ from torch_log_wmse import LogWMSE
+ log_wmse = LogWMSE(
+ audio_length=reference.shape[-1] / 44100, # audio length in seconds
+ sample_rate=44100, # sample rate of 44100 Hz
+ return_as_loss=False, # return as loss (False means return as metric)
+ bypass_filter=False, # bypass frequency filtering (False means apply filter)
+ )
+
+ reference = torch.from_numpy(reference).unsqueeze(0).unsqueeze(0).to(device)
+ estimate = torch.from_numpy(estimate).unsqueeze(0).unsqueeze(0).to(device)
+ mixture = torch.from_numpy(mixture).unsqueeze(0).to(device)
+
+ res = log_wmse(mixture, reference, estimate)
+ return float(res.cpu().numpy())
+
+
+def AuraSTFT_metric(
+ reference: np.ndarray,
+ estimate: np.ndarray,
+ device: str = 'cpu',
+) -> float:
+ """
+ Calculate the AuraSTFT metric, which evaluates the spectral difference between the reference and estimated
+ audio signals using Short-Time Fourier Transform (STFT) loss.
+
+ The AuraSTFT metric computes the STFT loss in both logarithmic and linear magnitudes, and it is commonly used
+ to assess the quality of audio separation tasks. The result is returned as a value scaled to the range [0, 100].
+
+ Parameters:
+ ----------
+ reference : np.ndarray
+ The ground truth audio signal of shape (channels, time), where channels is the number of audio channels
+ (e.g., 1 for mono, 2 for stereo) and time is the length of the audio in samples.
+
+ estimate : np.ndarray
+ The estimated audio signal of shape (channels, time).
+
+ device : str, optional
+ The device to run the computation on, either 'cpu' or 'cuda'. Default is 'cpu'.
+
+ Returns:
+ -------
+ float
+ The AuraSTFT metric value, scaled to the range [0, 100], which quantifies the difference between
+ the reference and estimated signal in the spectral domain.
+ """
+
+ from auraloss.freq import STFTLoss
+
+ stft_loss = STFTLoss(
+ w_log_mag=1.0, # weight for log magnitude
+ w_lin_mag=0.0, # weight for linear magnitude
+ w_sc=1.0, # weight for spectral centroid
+ device=device,
+ )
+
+ reference = torch.from_numpy(reference).unsqueeze(0).to(device)
+ estimate = torch.from_numpy(estimate).unsqueeze(0).to(device)
+
+ res = 100 / (1. + 10 * stft_loss(reference, estimate))
+ return float(res.cpu().numpy())
+
+
+def AuraMRSTFT_metric(
+ reference: np.ndarray,
+ estimate: np.ndarray,
+ device: str = 'cpu',
+) -> float:
+ """
+ Calculate the AuraMRSTFT metric, which evaluates the spectral difference between the reference and estimated
+ audio signals using Multi-Resolution Short-Time Fourier Transform (STFT) loss.
+
+ The AuraMRSTFT metric uses multi-resolution STFT analysis, which allows better representation of both
+ low- and high-frequency components in the audio signals. The result is returned as a value scaled to the range [0, 100].
+
+ Parameters:
+ ----------
+ reference : np.ndarray
+ The ground truth audio signal of shape (channels, time), where channels is the number of audio channels
+ (e.g., 1 for mono, 2 for stereo) and time is the length of the audio in samples.
+
+ estimate : np.ndarray
+ The estimated audio signal of shape (channels, time).
+
+ device : str, optional
+ The device to run the computation on, either 'cpu' or 'cuda'. Default is 'cpu'.
+
+ Returns:
+ -------
+ float
+ The AuraMRSTFT metric value, scaled to the range [0, 100], which quantifies the difference between
+ the reference and estimated signal in the multi-resolution spectral domain.
+ """
+
+ from auraloss.freq import MultiResolutionSTFTLoss
+
+ mrstft_loss = MultiResolutionSTFTLoss(
+ fft_sizes=[1024, 2048, 4096],
+ hop_sizes=[256, 512, 1024],
+ win_lengths=[1024, 2048, 4096],
+ scale="mel", # mel scale for frequency resolution
+ n_bins=128, # number of bins for mel scale
+ sample_rate=44100,
+ perceptual_weighting=True, # apply perceptual weighting
+ device=device
+ )
+
+ reference = torch.from_numpy(reference).unsqueeze(0).float().to(device)
+ estimate = torch.from_numpy(estimate).unsqueeze(0).float().to(device)
+
+ res = 100 / (1. + 10 * mrstft_loss(reference, estimate))
+ return float(res.cpu().numpy())
+
+
+def bleed_full(
+ reference: np.ndarray,
+ estimate: np.ndarray,
+ sr: int = 44100,
+ n_fft: int = 4096,
+ hop_length: int = 1024,
+ n_mels: int = 512,
+ device: str = 'cpu',
+) -> Tuple[float, float]:
+ """
+ Calculate the 'bleed' and 'fullness' metrics between a reference and an estimated audio signal.
+
+ The 'bleed' metric measures how much the estimated signal bleeds into the reference signal,
+ while the 'fullness' metric measures how much the estimated signal retains its distinctiveness
+ in relation to the reference signal, both using mel spectrograms and decibel scaling.
+
+ Parameters:
+ ----------
+ reference : np.ndarray
+ The reference audio signal, shape (channels, time), where channels is the number of audio channels
+ (e.g., 1 for mono, 2 for stereo) and time is the length of the audio in samples.
+
+ estimate : np.ndarray
+ The estimated audio signal, shape (channels, time).
+
+ sr : int, optional
+ The sample rate of the audio signals. Default is 44100 Hz.
+
+ n_fft : int, optional
+ The FFT size used to compute the STFT. Default is 4096.
+
+ hop_length : int, optional
+ The hop length for STFT computation. Default is 1024.
+
+ n_mels : int, optional
+ The number of mel frequency bins. Default is 512.
+
+ device : str, optional
+ The device for computation, either 'cpu' or 'cuda'. Default is 'cpu'.
+
+ Returns:
+ -------
+ tuple
+ A tuple containing two values:
+ - `bleedless` (float): A score indicating how much 'bleeding' the estimated signal has (higher is better).
+ - `fullness` (float): A score indicating how 'full' the estimated signal is (higher is better).
+ """
+
+ from torchaudio.transforms import AmplitudeToDB
+
+ reference = torch.from_numpy(reference).float().to(device)
+ estimate = torch.from_numpy(estimate).float().to(device)
+
+ window = torch.hann_window(n_fft).to(device)
+
+ # Compute STFTs with the Hann window
+ D1 = torch.abs(torch.stft(reference, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True,
+ pad_mode="constant"))
+ D2 = torch.abs(torch.stft(estimate, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True,
+ pad_mode="constant"))
+
+ mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)
+ mel_filter_bank = torch.from_numpy(mel_basis).to(device)
+
+ S1_mel = torch.matmul(mel_filter_bank, D1)
+ S2_mel = torch.matmul(mel_filter_bank, D2)
+
+ S1_db = AmplitudeToDB(stype="magnitude", top_db=80)(S1_mel)
+ S2_db = AmplitudeToDB(stype="magnitude", top_db=80)(S2_mel)
+
+ diff = S2_db - S1_db
+
+ positive_diff = diff[diff > 0]
+ negative_diff = diff[diff < 0]
+
+ average_positive = torch.mean(positive_diff) if positive_diff.numel() > 0 else torch.tensor(0.0).to(device)
+ average_negative = torch.mean(negative_diff) if negative_diff.numel() > 0 else torch.tensor(0.0).to(device)
+
+ bleedless = 100 * 1 / (average_positive + 1)
+ fullness = 100 * 1 / (-average_negative + 1)
+
+ return bleedless.cpu().numpy(), fullness.cpu().numpy()
+
+
+def get_metrics(
+ metrics: List[str],
+ reference: np.ndarray,
+ estimate: np.ndarray,
+ mix: np.ndarray,
+ device: str = 'cpu',
+) -> Dict[str, float]:
+ """
+ Calculate a list of metrics to evaluate the performance of audio source separation models.
+
+ The function computes the specified metrics based on the reference, estimate, and mixture.
+
+ Parameters:
+ ----------
+ metrics : List[str]
+ A list of metric names to compute (e.g., ['sdr', 'si_sdr', 'l1_freq']).
+
+ reference : np.ndarray
+ The reference audio (true signal) with shape (channels, length).
+
+ estimate : np.ndarray
+ The estimated audio (predicted signal) with shape (channels, length).
+
+ mix : np.ndarray
+ The mixed audio signal with shape (channels, length).
+
+ device : str, optional, default='cpu'
+ The device ('cpu' or 'cuda') to perform the calculations on.
+
+ Returns:
+ -------
+ Dict[str, float]
+ A dictionary containing the computed metric values.
+ """
+ result = dict()
+
+ # Adjust the length to be the same across all inputs
+ min_length = min(reference.shape[1], estimate.shape[1])
+ reference = reference[..., :min_length]
+ estimate = estimate[..., :min_length]
+ mix = mix[..., :min_length]
+
+ if 'sdr' in metrics:
+ references = np.expand_dims(reference, axis=0)
+ estimates = np.expand_dims(estimate, axis=0)
+ result['sdr'] = float(sdr(references, estimates))
+
+ if 'si_sdr' in metrics:
+ result['si_sdr'] = float(si_sdr(reference, estimate))
+
+ if 'l1_freq' in metrics:
+ result['l1_freq'] = L1Freq_metric(reference, estimate, device=device)
+
+ if 'log_wmse' in metrics:
+ result['log_wmse'] = LogWMSE_metric(reference, estimate, mix, device)
+
+ if 'aura_stft' in metrics:
+ result['aura_stft'] = AuraSTFT_metric(reference, estimate, device)
+
+ if 'aura_mrstft' in metrics:
+ result['aura_mrstft'] = AuraMRSTFT_metric(reference, estimate, device)
+
+ if 'bleedless' in metrics or 'fullness' in metrics:
+ bleedless, fullness = bleed_full(reference, estimate, device=device)
+ if 'bleedless' in metrics:
+ result['bleedless'] = float(bleedless)
+ if 'fullness' in metrics:
+ result['fullness'] = float(fullness)
+
+ return result
diff --git a/preprocess/tools/vocal_separation/utils/model_utils.py b/preprocess/tools/vocal_separation/utils/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4149dec1385d43673d87e197c47a6acea9cd88f
--- /dev/null
+++ b/preprocess/tools/vocal_separation/utils/model_utils.py
@@ -0,0 +1,777 @@
+# coding: utf-8
+__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+from ml_collections import ConfigDict
+from torch.optim import Adam, AdamW, SGD, RAdam, RMSprop
+from tqdm.auto import tqdm
+from typing import Dict, List, Tuple, Any, Union, Optional
+import loralib as lora
+from .muon import SingleDeviceMuonWithAuxAdam
+import torch.distributed as dist
+
+def demix(
+ config: ConfigDict,
+ model: torch.nn.Module,
+ mix: torch.Tensor,
+ device: torch.device,
+ model_type: str,
+ pbar: bool = False
+) -> Union[Dict[str, np.ndarray], np.ndarray]:
+ """
+ Perform audio source separation with a given model.
+
+ Supports both Demucs-specific and generic processing modes, including
+ overlapping chunk-based inference with optional progress bar display.
+ Handles padding, fading, and batching to reduce artifacts during separation.
+
+ Args:
+ config (ConfigDict): Configuration object with audio and inference
+ parameters (chunk size, overlap, batch size, etc.).
+ model (torch.nn.Module): Source separation model for inference.
+ mix (torch.Tensor): Input audio tensor of shape (channels, time).
+ device (torch.device): Device on which to run inference (CPU or CUDA).
+ model_type (str): Type of model (e.g., 'htdemucs', 'mdx23c') that
+ determines processing mode.
+ pbar (bool, optional): If True, show a progress bar during chunk
+ processing. Defaults to False.
+
+ Returns:
+ Union[Dict[str, np.ndarray], np.ndarray]:
+ - Dictionary mapping instrument names to separated waveforms if
+ multiple instruments are predicted.
+ - NumPy array of separated audio if only a single instrument is
+ present (Demucs mode).
+ """
+
+ should_print = not dist.is_initialized() or dist.get_rank() == 0
+
+ mix = torch.tensor(mix, dtype=torch.float32)
+
+ if model_type == 'htdemucs':
+ mode = 'demucs'
+ else:
+ mode = 'generic'
+ # Define processing parameters based on the mode
+ if mode == 'demucs':
+ chunk_size = config.training.samplerate * config.training.segment
+ num_instruments = len(config.training.instruments)
+ num_overlap = config.inference.num_overlap
+ step = chunk_size // num_overlap
+ else:
+ if 'chunk_size' in config.inference:
+ chunk_size = config.inference.chunk_size
+ else:
+ chunk_size = config.audio.chunk_size
+ num_instruments = len(prefer_target_instrument(config))
+ num_overlap = config.inference.num_overlap
+
+ fade_size = chunk_size // 10
+ step = chunk_size // num_overlap
+ border = chunk_size - step
+ length_init = mix.shape[-1]
+ windowing_array = _getWindowingArray(chunk_size, fade_size)
+ # Add padding for generic mode to handle edge artifacts
+ if length_init > 2 * border and border > 0:
+ mix = nn.functional.pad(mix, (border, border), mode="reflect")
+
+ batch_size = config.inference.batch_size
+
+ use_amp = getattr(config.training, 'use_amp', True)
+
+ with torch.cuda.amp.autocast(enabled=use_amp):
+ with torch.inference_mode():
+ # Initialize result and counter tensors
+ req_shape = (num_instruments,) + mix.shape
+ result = torch.zeros(req_shape, dtype=torch.float32)
+ counter = torch.zeros(req_shape, dtype=torch.float32)
+
+ i = 0
+ batch_data = []
+ batch_locations = []
+ if pbar and should_print:
+ progress_bar = tqdm(
+ total=mix.shape[1], desc="Processing audio chunks", leave=False
+ )
+ else:
+ progress_bar = None
+
+ while i < mix.shape[1]:
+ # Extract chunk and apply padding if necessary
+ part = mix[:, i:i + chunk_size].to(device)
+ chunk_len = part.shape[-1]
+ if mode == "generic" and chunk_len > chunk_size // 2:
+ pad_mode = "reflect"
+ else:
+ pad_mode = "constant"
+ part = nn.functional.pad(part, (0, chunk_size - chunk_len), mode=pad_mode, value=0)
+
+ batch_data.append(part)
+ batch_locations.append((i, chunk_len))
+ i += step
+
+ # Process batch if it's full or the end is reached
+ if len(batch_data) >= batch_size or i >= mix.shape[1]:
+ arr = torch.stack(batch_data, dim=0)
+ x = model(arr)
+
+ if mode == "generic":
+ window = windowing_array.clone() # using clone() fixes the clicks at chunk edges when using batch_size=1
+ if i - step == 0: # First audio chunk, no fadein
+ window[:fade_size] = 1
+ elif i >= mix.shape[1]: # Last audio chunk, no fadeout
+ window[-fade_size:] = 1
+
+ for j, (start, seg_len) in enumerate(batch_locations):
+ if mode == "generic":
+ result[..., start:start + seg_len] += x[j, ..., :seg_len].cpu() * window[..., :seg_len]
+ counter[..., start:start + seg_len] += window[..., :seg_len]
+ else:
+ result[..., start:start + seg_len] += x[j, ..., :seg_len].cpu()
+ counter[..., start:start + seg_len] += 1.0
+
+ batch_data.clear()
+ batch_locations.clear()
+
+ if progress_bar:
+ progress_bar.update(step)
+
+ if progress_bar:
+ progress_bar.close()
+
+
+ """
+ # mix: B, 2, T
+ # req_shape = (num_instruments,) + mix.shape
+ req_shape = (num_instruments,) + mix.shape
+ result = torch.zeros(req_shape, dtype=torch.float32)
+ counter = torch.zeros(req_shape, dtype=torch.float32)
+
+ # prev_i = 0
+ i = 0
+ batch_data = []
+ batch_locations = []
+
+ while i < mix.shape[-1]:
+ part = mix[:, :, i:i + chunk_size].to(device)
+ chunk_len = part.shape[-1]
+ if mode == "generic" and chunk_len > chunk_size // 2:
+ pad_mode = "reflect"
+ else:
+ pad_mode = "constant"
+ part = nn.functional.pad(part, (0, chunk_size - chunk_len), mode=pad_mode, value=0)
+ # batch_locations.append((i, chunk_len))
+ # prev_i = i
+ batch_location = i, i + chunk_len
+ i += step
+
+ # print(part.shape)
+ x = model(part)
+ x = x.transpose(0, 1)
+ # print(x.shape)
+
+ if mode == "generic":
+ window = windowing_array.clone() # using clone() fixes the clicks at chunk edges when using batch_size=1
+ if i - step == 0: # First audio chunk, no fadein
+ window[:fade_size] = 1
+ elif i >= mix.shape[1]: # Last audio chunk, no fadeout
+ window[-fade_size:] = 1
+
+ # for j, (start, seg_len) in enumerate(batch_locations):
+ # l = chunk_len if chunk_len < chunk_size else chunk_size
+ # print(l, x.shape, result.shape, counter.shape, window.shape)
+ # print(result[..., batch_location[0]: batch_location[1]].shape, x[..., :chunk_len].cpu().shape, window[..., :chunk_len].shape)
+ if mode == "generic":
+ result[..., batch_location[0]: batch_location[1]] += x[..., :chunk_len].cpu() * window[..., :chunk_len]
+ counter[..., batch_location[0]: batch_location[1]] += window[..., :chunk_len]
+ else:
+ result[..., batch_location[0]: batch_location[1]] += x[..., :chunk_len].cpu()
+ counter[..., batch_location[0]: batch_location[1]] += 1.0
+
+ batch_data.clear()
+ batch_locations.clear()
+ """
+ # Compute final estimated sources
+ estimated_sources = result / counter
+ estimated_sources = estimated_sources.cpu().numpy()
+ np.nan_to_num(estimated_sources, copy=False, nan=0.0)
+
+ # Remove padding for generic mode
+ if mode == "generic":
+ if length_init > 2 * border and border > 0:
+ estimated_sources = estimated_sources[..., border:-border]
+
+ # Return the result as a dictionary or a single array
+ if mode == "demucs":
+ instruments = config.training.instruments
+ else:
+ instruments = prefer_target_instrument(config)
+
+ ret_data = {k: v for k, v in zip(instruments, estimated_sources)}
+
+ if mode == "demucs" and num_instruments <= 1:
+ return estimated_sources
+ else:
+ return ret_data
+
+
+def initialize_model_and_device(model: torch.nn.Module, device_ids: List[int]) -> Tuple[Union[torch.device, str], torch.nn.Module]:
+ """
+ Move a model to the correct computation device and wrap with DataParallel if needed.
+
+ Selects GPU(s) if CUDA is available; otherwise defaults to CPU. If multiple
+ GPU IDs are provided, wraps the model with `nn.DataParallel` for multi-GPU
+ execution.
+
+ Args:
+ model (torch.nn.Module): PyTorch model to be initialized.
+ device_ids (List[int]): List of GPU device IDs to use. If length > 1,
+ the model will be wrapped with DataParallel.
+
+ Returns:
+ Tuple[Union[torch.device, str], torch.nn.Module]: A tuple containing:
+ - The computation device (`torch.device` or "cpu").
+ - The model moved to that device (wrapped in DataParallel if applicable).
+ """
+
+ if torch.cuda.is_available():
+ if len(device_ids) <= 1:
+ device = torch.device(f'cuda:{device_ids[0]}')
+ model = model.to(device)
+ else:
+ device = torch.device(f'cuda:{device_ids[0]}')
+ model = nn.DataParallel(model, device_ids=device_ids).to(device)
+ else:
+ device = 'cpu'
+ model = model.to(device)
+ print("CUDA is not available. Running on CPU.")
+
+ return device, model
+
+
+def get_optimizer(config: ConfigDict, model: torch.nn.Module) -> torch.optim.Optimizer:
+ """
+ Create and configure an optimizer for training.
+
+ Selects the optimizer type based on `config.training.optimizer` and applies
+ the corresponding parameters, including support for advanced optimizers
+ such as Muon, Prodigy, and 8-bit AdamW. Handles parameter group separation
+ for specialized optimizers (e.g., Muon vs. Adam parameters).
+
+ Args:
+ config (ConfigDict): Training configuration containing optimizer type,
+ learning rate, and optional optimizer-specific parameters.
+ model (torch.nn.Module): Model whose parameters will be optimized.
+
+ Returns:
+ torch.optim.Optimizer: Initialized optimizer ready for training.
+
+ Raises:
+ ValueError: If required optimizer configuration is missing (e.g., for Muon).
+ SystemExit: If an unknown optimizer name is encountered.
+ """
+
+ should_print = not dist.is_initialized() or dist.get_rank() == 0
+ optim_params = dict()
+ if 'optimizer' in config:
+ optim_params = dict(config['optimizer'])
+ if config.training.optimizer != 'muon' and should_print:
+ print(f'Optimizer params from config:\n{optim_params}')
+
+ name_optimizer = getattr(config.training, 'optimizer',
+ 'No optimizer in config')
+
+ if name_optimizer == 'adam':
+ optimizer = Adam(model.parameters(), lr=config.training.lr, **optim_params)
+ elif name_optimizer == 'adamw':
+ optimizer = AdamW(model.parameters(), lr=config.training.lr, **optim_params)
+ elif name_optimizer == 'radam':
+ optimizer = RAdam(model.parameters(), lr=config.training.lr, **optim_params)
+ elif name_optimizer == 'rmsprop':
+ optimizer = RMSprop(model.parameters(), lr=config.training.lr, **optim_params)
+ elif name_optimizer == 'prodigy':
+ from prodigyopt import Prodigy
+ # you can choose weight decay value based on your problem, 0 by default
+ # We recommend using lr=1.0 (default) for all networks.
+ optimizer = Prodigy(model.parameters(), lr=config.training.lr, **optim_params)
+ elif name_optimizer == 'adamw8bit':
+ import bitsandbytes as bnb
+ optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.training.lr, **optim_params)
+ elif name_optimizer == 'muon':
+ if should_print:
+ print("Using Muon optimizer (Single-Device) with AdamW for auxiliary parameters.")
+
+ muon_params = [p for p in model.parameters() if p.ndim >= 2]
+ adam_params = [p for p in model.parameters() if p.ndim < 2]
+
+ if not hasattr(config, 'optimizer') or 'muon_group' not in config.optimizer or 'adam_group' not in config.optimizer:
+ raise ValueError("For the 'muon' optimizer, the config must have an 'optimizer' section "
+ "with 'muon_group' and 'adam_group' dictionaries.")
+
+ muon_group_config = dict(config.optimizer.muon_group)
+ adam_group_config = dict(config.optimizer.adam_group)
+
+ if should_print:
+ print(f"Muon group params: {muon_group_config}")
+ print(f"Adam group params: {adam_group_config}")
+
+ param_groups = [
+ dict(params=muon_params, use_muon=True, **muon_group_config),
+ dict(params=adam_params, use_muon=False, **adam_group_config),
+ ]
+ optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
+ elif name_optimizer == 'sgd':
+ if should_print:
+ print('Use SGD optimizer')
+ optimizer = SGD(model.parameters(), lr=config.training.lr, **optim_params)
+ else:
+ if should_print:
+ print(f'Unknown optimizer: {name_optimizer}')
+ exit()
+ return optimizer
+
+
+def normalize_batch(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Apply mean-variance normalization to a pair of tensors.
+
+ Computes the mean and standard deviation from `x` and normalizes both `x`
+ and `y` using those statistics. This ensures the two tensors are scaled
+ consistently.
+
+ Args:
+ x (torch.Tensor): Input tensor used to compute normalization statistics.
+ y (torch.Tensor): Input tensor normalized using the same statistics as `x`.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Normalized tensors `(x, y)`.
+ """
+
+ mean = x.mean()
+ std = x.std()
+ if std != 0:
+ x = (x - mean) / std
+ y = (y - mean) / std
+ return x, y
+
+
+def apply_tta(
+ config,
+ model: torch.nn.Module,
+ mix: torch.Tensor,
+ waveforms_orig: Dict[str, torch.Tensor],
+ device: torch.device,
+ model_type: str
+) -> Dict[str, torch.Tensor]:
+ """
+ Enhance source separation results using Test-Time Augmentation (TTA).
+
+ Applies augmentations such as channel reversal and polarity inversion to
+ the input mixture, reprocesses with the model, and combines the results
+ with the original predictions by averaging.
+
+ Args:
+ config: Configuration object with model and inference parameters.
+ model (torch.nn.Module): Trained source separation model.
+ mix (torch.Tensor): Input mixture tensor of shape (channels, time).
+ waveforms_orig (Dict[str, torch.Tensor]): Dictionary of separated
+ sources before augmentation.
+ device (torch.device): Computation device (CPU or CUDA).
+ model_type (str): Model type identifier used for demixing.
+
+ Returns:
+ Dict[str, torch.Tensor]: Dictionary of separated sources after applying TTA.
+ """
+
+ # Create augmentations: channel inversion and polarity inversion
+ track_proc_list = [mix[::-1].copy(), -1.0 * mix.copy()]
+
+ # Process each augmented mixture
+ for i, augmented_mix in enumerate(track_proc_list):
+ waveforms = demix(config, model, augmented_mix, device, model_type=model_type)
+ for el in waveforms:
+ if i == 0:
+ waveforms_orig[el] += waveforms[el][::-1].copy()
+ else:
+ waveforms_orig[el] -= waveforms[el]
+
+ # Average the results across augmentations
+ for el in waveforms_orig:
+ waveforms_orig[el] /= len(track_proc_list) + 1
+
+ return waveforms_orig
+
+
+def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor:
+ """
+ Generate a windowing array with a linear fade-in at the beginning and a fade-out at the end.
+
+ This function creates a window of size `window_size` where the first `fade_size` elements
+ linearly increase from 0 to 1 (fade-in) and the last `fade_size` elements linearly decrease
+ from 1 to 0 (fade-out). The middle part of the window is filled with ones.
+
+ Parameters:
+ ----------
+ window_size : int
+ The total size of the window.
+ fade_size : int
+ The size of the fade-in and fade-out regions.
+
+ Returns:
+ -------
+ torch.Tensor
+ A tensor of shape (window_size,) containing the generated windowing array.
+
+ Example:
+ -------
+ If `window_size=10` and `fade_size=3`, the output will be:
+ tensor([0.0000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.5000, 0.0000])
+ """
+
+ fadein = torch.linspace(0, 1, fade_size)
+ fadeout = torch.linspace(1, 0, fade_size)
+
+ window = torch.ones(window_size)
+ window[-fade_size:] = fadeout
+ window[:fade_size] = fadein
+ return window
+
+
+def prefer_target_instrument(config: ConfigDict) -> List[str]:
+ """
+ Return the list of target instruments based on the configuration.
+ If a specific target instrument is specified in the configuration,
+ it returns a list with that instrument. Otherwise, it returns the list of instruments.
+
+ Parameters:
+ ----------
+ config : ConfigDict
+ Configuration object containing the list of instruments or the target instrument.
+
+ Returns:
+ -------
+ List[str]
+ A list of target instruments.
+ """
+ if getattr(config.training, 'target_instrument', None):
+ return [config.training.target_instrument]
+ else:
+ return config.training.instruments
+
+
+def load_not_compatible_weights(model: torch.nn.Module, old_model: dict, verbose: bool = False) -> None:
+ """
+ Load a possibly incompatible state dict into `model` with best-effort matching.
+
+ Accepts either a raw state_dict or a checkpoint dict with weights under "state" or "state_dict".
+ For each param/buffer in `model`: if the name exists and shapes match → copy;
+ if ndim matches but shapes differ → zero-pad/crop the source to fit the target;
+ if the name is missing or ndim differs → skip. Optional logging on rank 0 when `verbose=True`.
+
+ Args:
+ model: Target PyTorch module.
+ old_model: Source weights (state_dict or checkpoint dict).
+ verbose: Print brief load decisions.
+
+ Returns:
+ None
+ """
+
+ should_print = verbose and (not dist.is_initialized() or dist.get_rank() == 0)
+
+ new_model = model.state_dict()
+
+ if 'state' in old_model:
+ # Fix for htdemucs weights loading
+ old_model = old_model['state']
+ if 'state_dict' in old_model:
+ # Fix for apollo weights loading
+ old_model = old_model['state_dict']
+ if 'model_state_dict' in old_model:
+ # Fix for full_check_point
+ old_model = old_model['model_state_dict']
+
+ for el in new_model:
+ if el in old_model:
+ if should_print:
+ print(f'Match found for {el}!')
+ if new_model[el].shape == old_model[el].shape:
+ if should_print:
+ print('Action: Just copy weights!')
+ new_model[el] = old_model[el]
+ else:
+ if len(new_model[el].shape) != len(old_model[el].shape) and should_print:
+ print('Action: Different dimension! Too lazy to write the code... Skip it')
+ else:
+ if should_print:
+ print(f'Shape is different: {tuple(new_model[el].shape)} != {tuple(old_model[el].shape)}')
+ ln = len(new_model[el].shape)
+ max_shape = []
+ slices_old = []
+ slices_new = []
+ for i in range(ln):
+ max_shape.append(max(new_model[el].shape[i], old_model[el].shape[i]))
+ slices_old.append(slice(0, old_model[el].shape[i]))
+ slices_new.append(slice(0, new_model[el].shape[i]))
+ # print(max_shape)
+ # print(slices_old, slices_new)
+ slices_old = tuple(slices_old)
+ slices_new = tuple(slices_new)
+ max_matrix = np.zeros(max_shape, dtype=np.float32)
+ for i in range(ln):
+ max_matrix[slices_old] = old_model[el].cpu().numpy()
+ max_matrix = torch.from_numpy(max_matrix)
+ new_model[el] = max_matrix[slices_new]
+ else:
+ if should_print:
+ print(f'Match not found for {el}!')
+ model.load_state_dict(
+ new_model
+ )
+
+
+def load_lora_weights(model: torch.nn.Module, lora_path: str, device: str = 'cpu') -> None:
+ """
+ Load LoRA weights into a model.
+ This function updates the given model with LoRA-specific weights from the specified checkpoint file.
+ It does not require the checkpoint to match the model's full state dictionary, as only LoRA layers are updated.
+
+ Parameters:
+ ----------
+ model : Module
+ The PyTorch model into which the LoRA weights will be loaded.
+ lora_path : str
+ Path to the LoRA checkpoint file.
+ device : str, optional
+ The device to load the weights onto, by default 'cpu'. Common values are 'cpu' or 'cuda'.
+
+ Returns:
+ -------
+ None
+ The model is updated in place.
+ """
+ lora_state_dict = torch.load(lora_path, map_location=device)
+ model.load_state_dict(lora_state_dict, strict=False)
+
+
+def load_start_checkpoint(args: argparse.Namespace,
+ model: torch.nn.Module,
+ old_model: None,
+ type_: str = 'train') -> None:
+ """
+ Load an initial checkpoint into `model`.
+
+ For `type_ == "train"`, performs a tolerant load using `old_model` (a state dict or a
+ checkpoint dict) via `load_not_compatible_weights`, allowing partial shape mismatches.
+ For other modes, loads a strict state dict from `args.start_check_point`, with special
+ handling for HTDemucs/Apollo checkpoints (keys under "state"/"state_dict"). If
+ `args.lora_checkpoint` is set, LoRA weights are applied after the base load.
+
+ Args:
+ args: Namespace with at least `start_check_point`, `model_type`, and optionally `lora_checkpoint`.
+ model: Target PyTorch module to receive weights.
+ old_model: Source weights for tolerant loading in train mode (state dict or checkpoint dict).
+ type_: Loading strategy; "train" uses tolerant loading, otherwise strict loading from path.
+
+ Returns:
+ None
+ """
+ should_print = not dist.is_initialized() or dist.get_rank() == 0
+
+ if should_print:
+ print(f'Start from checkpoint: {args.start_check_point}')
+ if type_ in ['train']:
+ if 1:
+ load_not_compatible_weights(model, old_model, verbose=False)
+ else:
+ model.load_state_dict(torch.load(args.start_check_point))
+ else:
+ device='cpu'
+ if args.model_type in ['htdemucs', 'apollo']:
+ state_dict = torch.load(args.start_check_point, map_location=device, weights_only=False)
+ # Fix for htdemucs pretrained models
+ if 'state' in state_dict:
+ state_dict = state_dict['state']
+ # Fix for apollo pretrained models
+ if 'state_dict' in state_dict:
+ state_dict = state_dict['state_dict']
+ else:
+ state_dict = torch.load(args.start_check_point, map_location=device, weights_only=True)
+ model.load_state_dict(state_dict)
+
+ if args.lora_checkpoint:
+ if should_print:
+ print(f"Loading LoRA weights from: {args.lora_checkpoint}")
+ load_lora_weights(model, args.lora_checkpoint)
+
+
+def bind_lora_to_model(config: Dict[str, Any], model: nn.Module) -> nn.Module:
+ """
+ Replaces specific layers in the model with LoRA-extended versions.
+
+ Parameters:
+ ----------
+ config : Dict[str, Any]
+ Configuration containing parameters for LoRA. It should include a 'lora' key with parameters for `MergedLinear`.
+ model : nn.Module
+ The original model in which the layers will be replaced.
+
+ Returns:
+ -------
+ nn.Module
+ The modified model with the replaced layers.
+ """
+
+ if 'lora' not in config:
+ raise ValueError("Configuration must contain the 'lora' key with parameters for LoRA.")
+
+ replaced_layers = 0 # Counter for replaced layers
+ should_print = not dist.is_initialized() or dist.get_rank() == 0
+
+ for name, module in model.named_modules():
+ hierarchy = name.split('.')
+ layer_name = hierarchy[-1]
+
+ # Check if this is the target layer to replace (and layer_name == 'to_qkv')
+ if isinstance(module, nn.Linear):
+ try:
+ # Get the parent module
+ parent_module = model
+ for submodule_name in hierarchy[:-1]:
+ parent_module = getattr(parent_module, submodule_name)
+
+ # Replace the module with LoRA-enabled layer
+ setattr(
+ parent_module,
+ layer_name,
+ lora.MergedLinear(
+ in_features=module.in_features,
+ out_features=module.out_features,
+ bias=module.bias is not None,
+ **config['lora']
+ )
+ )
+ replaced_layers += 1 # Increment the counter
+
+ except Exception as e:
+ if should_print:
+ print(f"Error replacing layer {name}: {e}")
+
+ if replaced_layers == 0 and should_print:
+ print("Warning: No layers were replaced. Check the model structure and configuration.")
+ elif should_print:
+ print(f"Number of layers replaced with LoRA: {replaced_layers}")
+
+ return model
+
+
+def save_weights(
+ store_path: str,
+ model: nn.Module,
+ device_ids: List[int],
+ optimizer: torch.optim.Optimizer,
+ epoch: int,
+ all_time_all_metrics,
+ best_metric: float,
+ scheduler: Optional[torch.optim.lr_scheduler.ReduceLROnPlateau] = None,
+ train_lora: bool = False
+) -> None:
+ """
+ Save a training checkpoint containing model weights, optimizer/scheduler states, and metadata.
+
+ Behavior:
+ - In Distributed Data Parallel (DDP), only rank 0 writes the file to avoid conflicts.
+ - If `train_lora` is True, saves only LoRA adapter weights (`lora_state_dict`); otherwise saves the full model.
+ - Uses `model.module.state_dict()` when the model is wrapped by DDP/DataParallel.
+ - Stores `epoch` and `best_metric` alongside optimizer/scheduler states.
+
+ Args:
+ store_path: Destination file path for the checkpoint (will be overwritten).
+ model: The model whose weights are being saved (may be wrapped by DDP/DataParallel).
+ device_ids: List of GPU device IDs used during training (used to detect DP wrapping in non-DDP runs).
+ optimizer: Optimizer whose state will be saved.
+ epoch: Current training epoch to record in the checkpoint.
+ all_time_all_metrics:
+ best_metric: Best validation metric achieved so far.
+ scheduler: Optional learning rate scheduler; its state is saved if provided.
+ train_lora: If True, save only LoRA adapter weights instead of the full model.
+
+ Returns:
+ None
+ """
+
+ checkpoint: Dict[str, Any] = {
+ "epoch": epoch,
+ "optimizer_name": optimizer.__class__.__name__,
+ "optimizer_state_dict": optimizer.state_dict(),
+ "scheduler_state_dict": scheduler.state_dict() if scheduler else None,
+ "best_metric": best_metric,
+ "all_metrics": all_time_all_metrics
+ }
+
+ # Save model weights
+ if train_lora:
+ checkpoint["model_state_dict"] = lora.lora_state_dict(model)
+ else:
+ if dist.is_initialized():
+ # In DDP, use .module
+ checkpoint["model_state_dict"] = model.module.state_dict()
+ else:
+ checkpoint["model_state_dict"] = (
+ model.state_dict() if len(device_ids) <= 1 else model.module.state_dict()
+ )
+
+ # Save only on rank 0 (or if not using DDP)
+ if not dist.is_initialized() or dist.get_rank() == 0:
+ torch.save(checkpoint, store_path)
+
+
+def save_last_weights(
+ args: argparse.Namespace,
+ model: nn.Module,
+ device_ids: List[int],
+ optimizer: torch.optim.Optimizer,
+ epoch: int,
+ all_time_all_metrics,
+ best_metric: float,
+ scheduler: Optional[torch.optim.lr_scheduler.ReduceLROnPlateau] = None,
+) -> None:
+ """
+ Save the latest training checkpoint for continuation or recovery.
+
+ The checkpoint is always written to:
+ {args.results_path}/last_{args.model_type}.ckpt
+
+ This wraps `save_weights` and ensures the latest model/optimizer/scheduler
+ states are recorded, along with the current epoch and best metric. In DDP,
+ only rank 0 performs the save. Supports both standard and LoRA training.
+
+ Args:
+ all_time_all_metrics:
+ args: Training arguments. Must define `results_path`, `model_type`,
+ and `train_lora`.
+ model: Model instance (may be wrapped by DDP/DataParallel).
+ device_ids: List of GPU IDs used for training.
+ optimizer: Optimizer whose state will be saved.
+ epoch: Current training epoch.
+ best_metric: Current best validation metric.
+ scheduler: Optional learning rate scheduler to save state for.
+
+ Returns:
+ None
+ """
+ store_path = f"{args.results_path}/last_{args.model_type}.ckpt"
+ save_weights(
+ store_path,
+ model,
+ device_ids,
+ optimizer,
+ epoch,
+ all_time_all_metrics,
+ best_metric,
+ scheduler,
+ args.train_lora,
+ )
diff --git a/preprocess/tools/vocal_separation/utils/muon.py b/preprocess/tools/vocal_separation/utils/muon.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f117329d5f1d646749a2ec6b80d434610104c4a
--- /dev/null
+++ b/preprocess/tools/vocal_separation/utils/muon.py
@@ -0,0 +1,286 @@
+import torch
+import torch.distributed as dist
+
+
+def zeropower_via_newtonschulz5(G, steps: int):
+ """
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+ performance at all relative to UV^T, where USV^T = G is the SVD.
+ """
+ assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
+ a, b, c = (3.4445, -4.7750, 2.0315)
+ X = G.bfloat16()
+ if G.size(-2) > G.size(-1):
+ X = X.mT
+
+ # Ensure spectral norm is at most 1
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
+ # Perform the NS iterations
+ for _ in range(steps):
+ A = X @ X.mT
+ B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+ X = a * X + B @ X
+
+ if G.size(-2) > G.size(-1):
+ X = X.mT
+ return X
+
+
+def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
+ momentum.lerp_(grad, 1 - beta)
+ update = grad.lerp_(momentum, beta) if nesterov else momentum
+ if update.ndim == 4: # for the case of conv filters
+ update = update.view(len(update), -1)
+ update = zeropower_via_newtonschulz5(update, steps=ns_steps)
+ update *= max(1, grad.size(-2) / grad.size(-1))**0.5
+ return update
+
+
+class Muon(torch.optim.Optimizer):
+ """
+ Muon - MomentUm Orthogonalized by Newton-schulz
+
+ https://kellerjordan.github.io/posts/muon/
+
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
+ matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the
+ advantage that it can be stably run in bfloat16 on the GPU.
+
+ Muon should only be used for hidden weight layers. The input embedding, final output layer,
+ and any internal gains or biases should be optimized using a standard method such as AdamW.
+ Hidden convolutional weights can be trained using Muon by viewing them as 2D and then
+ collapsing their last 3 dimensions.
+
+ Arguments:
+ lr: The learning rate, in units of spectral norm per update.
+ weight_decay: The AdamW-style weight decay.
+ momentum: The momentum. A value of 0.95 here is usually fine.
+ """
+ def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
+ defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
+ assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
+ params = sorted(params, key=lambda x: x.size(), reverse=True)
+ super().__init__(params, defaults)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ params = group["params"]
+ params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
+ for base_i in range(len(params))[::dist.get_world_size()]:
+ if base_i + dist.get_rank() < len(params):
+ p = params[base_i + dist.get_rank()]
+ if p.grad is None:
+ # continue
+ p.grad = torch.zeros_like(p) # Force synchronization
+ state = self.state[p]
+ if len(state) == 0:
+ state["momentum_buffer"] = torch.zeros_like(p)
+ update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
+ p.mul_(1 - group["lr"] * group["weight_decay"])
+ p.add_(update.reshape(p.shape), alpha=-group["lr"])
+ dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
+
+ return loss
+
+
+class SingleDeviceMuon(torch.optim.Optimizer):
+ """
+ Muon variant for usage in non-distributed settings.
+ """
+ def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
+ defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
+ super().__init__(params, defaults)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.grad is None:
+ # continue
+ p.grad = torch.zeros_like(p) # Force synchronization
+ state = self.state[p]
+ if len(state) == 0:
+ state["momentum_buffer"] = torch.zeros_like(p)
+ update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
+ p.mul_(1 - group["lr"] * group["weight_decay"])
+ p.add_(update.reshape(p.shape), alpha=-group["lr"])
+
+ return loss
+
+
+def adam_update(grad, buf1, buf2, step, betas, eps):
+ buf1.lerp_(grad, 1 - betas[0])
+ buf2.lerp_(grad.square(), 1 - betas[1])
+ buf1c = buf1 / (1 - betas[0]**step)
+ buf2c = buf2 / (1 - betas[1]**step)
+ return buf1c / (buf2c.sqrt() + eps)
+
+
+class MuonWithAuxAdam(torch.optim.Optimizer):
+ """
+ Distributed Muon variant that can be used for all parameters in the network, since it runs an
+ internal AdamW for the parameters that are not compatible with Muon. The user must manually
+ specify which parameters shall be optimized with Muon and which with Adam by passing in a
+ list of param_groups with the `use_muon` flag set.
+
+ The point of this class is to allow the user to have a single optimizer in their code, rather
+ than having both a Muon and an Adam which each need to be stepped.
+
+ You can see an example usage below:
+
+ https://github.com/KellerJordan/modded-nanogpt/blob/master/records/052525_MuonWithAuxAdamExample/b01550f9-03d8-4a9c-86fe-4ab434f1c5e0.txt#L470
+ ```
+ hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+ embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+ scalar_params = [p for p in model.parameters() if p.ndim < 2]
+ head_params = [model.lm_head.weight]
+
+ from muon import MuonWithAuxAdam
+ adam_groups = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)]
+ adam_groups = [dict(**g, betas=(0.8, 0.95), eps=1e-10, use_muon=False) for g in adam_groups]
+ muon_group = dict(params=hidden_matrix_params, lr=0.05, momentum=0.95, use_muon=True)
+ param_groups = [*adam_groups, muon_group]
+ optimizer = MuonWithAuxAdam(param_groups)
+ ```
+ """
+ def __init__(self, param_groups):
+ for group in param_groups:
+ assert "use_muon" in group
+ if group["use_muon"]:
+ group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True)
+ # defaults
+ group["lr"] = group.get("lr", 0.02)
+ group["momentum"] = group.get("momentum", 0.95)
+ group["weight_decay"] = group.get("weight_decay", 0)
+ assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
+ else:
+ # defaults
+ group["lr"] = group.get("lr", 3e-4)
+ group["betas"] = group.get("betas", (0.9, 0.95))
+ group["eps"] = group.get("eps", 1e-10)
+ group["weight_decay"] = group.get("weight_decay", 0)
+ assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
+ super().__init__(param_groups, dict())
+
+ @torch.no_grad()
+ def step(self, closure=None):
+
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ if group["use_muon"]:
+ params = group["params"]
+ params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
+ for base_i in range(len(params))[::dist.get_world_size()]:
+ if base_i + dist.get_rank() < len(params):
+ p = params[base_i + dist.get_rank()]
+ if p.grad is None:
+ # continue
+ p.grad = torch.zeros_like(p) # Force synchronization
+ state = self.state[p]
+ if len(state) == 0:
+ state["momentum_buffer"] = torch.zeros_like(p)
+ update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
+ p.mul_(1 - group["lr"] * group["weight_decay"])
+ p.add_(update.reshape(p.shape), alpha=-group["lr"])
+ dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
+ else:
+ for p in group["params"]:
+ if p.grad is None:
+ # continue
+ p.grad = torch.zeros_like(p) # Force synchronization
+ state = self.state[p]
+ if len(state) == 0:
+ state["exp_avg"] = torch.zeros_like(p)
+ state["exp_avg_sq"] = torch.zeros_like(p)
+ state["step"] = 0
+ state["step"] += 1
+ update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
+ state["step"], group["betas"], group["eps"])
+ p.mul_(1 - group["lr"] * group["weight_decay"])
+ p.add_(update, alpha=-group["lr"])
+
+ return loss
+
+
+class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
+ """
+ Non-distributed variant of MuonWithAuxAdam.
+ """
+ def __init__(self, param_groups):
+ for group in param_groups:
+ assert "use_muon" in group
+ if group["use_muon"]:
+ # defaults
+ group["lr"] = group.get("lr", 0.02)
+ group["momentum"] = group.get("momentum", 0.95)
+ group["weight_decay"] = group.get("weight_decay", 0)
+ assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
+ else:
+ # defaults
+ group["lr"] = group.get("lr", 3e-4)
+ group["betas"] = group.get("betas", (0.9, 0.95))
+ group["eps"] = group.get("eps", 1e-10)
+ group["weight_decay"] = group.get("weight_decay", 0)
+ assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
+ super().__init__(param_groups, dict())
+
+ @torch.no_grad()
+ def step(self, closure=None):
+
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ if group["use_muon"]:
+ for p in group["params"]:
+ if p.grad is None:
+ # continue
+ p.grad = torch.zeros_like(p) # Force synchronization
+ state = self.state[p]
+ if len(state) == 0:
+ state["momentum_buffer"] = torch.zeros_like(p)
+ update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
+ p.mul_(1 - group["lr"] * group["weight_decay"])
+ p.add_(update.reshape(p.shape), alpha=-group["lr"])
+ else:
+ for p in group["params"]:
+ if p.grad is None:
+ # continue
+ p.grad = torch.zeros_like(p) # Force synchronization
+ state = self.state[p]
+ if len(state) == 0:
+ state["exp_avg"] = torch.zeros_like(p)
+ state["exp_avg_sq"] = torch.zeros_like(p)
+ state["step"] = 0
+ state["step"] += 1
+ update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
+ state["step"], group["betas"], group["eps"])
+ p.mul_(1 - group["lr"] * group["weight_decay"])
+ p.add_(update, alpha=-group["lr"])
+
+ return loss
diff --git a/preprocess/tools/vocal_separation/utils/settings.py b/preprocess/tools/vocal_separation/utils/settings.py
new file mode 100644
index 0000000000000000000000000000000000000000..127c90126b4aee7b6b05a70ccec3793c5dea3e98
--- /dev/null
+++ b/preprocess/tools/vocal_separation/utils/settings.py
@@ -0,0 +1,501 @@
+import os
+import random
+import time
+import yaml
+import wandb
+import numpy as np
+import torch
+import argparse
+from typing import Dict, List, Tuple, Union
+from omegaconf import OmegaConf
+from ml_collections import ConfigDict
+import torch.distributed as dist
+from torch import nn
+
+
+def parse_args_train(dict_args: Union[Dict, None]) -> argparse.Namespace:
+ """
+ Parse command-line arguments for training configuration.
+
+ This function constructs an argument parser for model, dataset, training, and logging
+ options, merges overrides from a provided dictionary (if any), and returns the parsed
+ arguments. If `dict_args` is None, the arguments are parsed from `sys.argv`.
+
+ Args:
+ dict_args (Dict | None): Optional dictionary of argument overrides. Keys should
+ match the defined CLI options.
+
+ Returns:
+ argparse.Namespace: Parsed arguments namespace containing all configuration
+ values required for training.
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_type", type=str, default='mdx23c',
+ help="One of mdx23c, htdemucs, segm_models, mel_band_roformer, bs_roformer, swin_upernet, bandit")
+ parser.add_argument("--config_path", type=str, help="path to config file")
+ parser.add_argument("--start_check_point", type=str, default='', help="Initial checkpoint to start training")
+ parser.add_argument("--load_optimizer", action='store_true', help="Load optimizer state from checkpoint (if available)")
+ parser.add_argument("--load_scheduler", action='store_true', help="Load scheduler state from checkpoint (if available)")
+ parser.add_argument("--load_epoch", action='store_true', help="Load epoch number from checkpoint (if available)")
+ parser.add_argument("--load_best_metric", action='store_true', help="Load best metric from checkpoint (if available)")
+ parser.add_argument("--load_all_metrics", action='store_true', help="Load all metrics from checkpoint (if available)")
+ parser.add_argument("--results_path", type=str,
+ help="path to folder where results will be stored (weights, metadata)")
+ parser.add_argument("--data_path", nargs="+", type=str, help="Dataset data paths. You can provide several folders.")
+ parser.add_argument("--dataset_type", type=int, default=1,
+ help="Dataset type. Must be one of: 1, 2, 3 or 4. Details here: https://github.com/ZFTurbo/Music-Source-Separation-Training/blob/main/docs/dataset_types.md")
+ parser.add_argument("--valid_path", nargs="+", type=str,
+ help="validation data paths. You can provide several folders.")
+ parser.add_argument("--num_workers", type=int, default=0, help="dataloader num_workers")
+ parser.add_argument("--pin_memory", action='store_true', help="dataloader pin_memory")
+ parser.add_argument("--seed", type=int, default=0, help="random seed")
+ parser.add_argument("--device_ids", nargs='+', type=int, default=[0], help='list of gpu ids')
+ parser.add_argument("--loss", type=str, nargs='+', choices=['masked_loss', 'mse_loss', 'l1_loss',
+ 'multistft_loss', 'spec_masked_loss', 'spec_rmse_loss', 'log_wmse_loss'],
+ default=['masked_loss'], help="List of loss functions to use")
+ parser.add_argument("--masked_loss_coef", type=float, default=1., help="Coef for loss")
+ parser.add_argument("--mse_loss_coef", type=float, default=1., help="Coef for loss")
+ parser.add_argument("--l1_loss_coef", type=float, default=1., help="Coef for loss")
+ parser.add_argument("--log_wmse_loss_coef", type=float, default=1., help="Coef for loss")
+ parser.add_argument("--multistft_loss_coef", type=float, default=0.001, help="Coef for loss")
+ parser.add_argument("--spec_masked_loss_coef", type=float, default=1, help="Coef for loss")
+ parser.add_argument("--spec_rmse_loss_coef", type=float, default=1, help="Coef for loss")
+ parser.add_argument("--wandb_key", type=str, default='', help='wandb API Key')
+ parser.add_argument("--wandb_offline", action='store_true', help='local wandb')
+ parser.add_argument("--pre_valid", action='store_true', help='Run validation before training')
+ parser.add_argument("--metrics", nargs='+', type=str, default=["sdr"],
+ choices=['sdr', 'l1_freq', 'si_sdr', 'log_wmse', 'aura_stft', 'aura_mrstft', 'bleedless',
+ 'fullness'], help='List of metrics to use.')
+ parser.add_argument("--metric_for_scheduler", default="sdr",
+ choices=['sdr', 'l1_freq', 'si_sdr', 'log_wmse', 'aura_stft', 'aura_mrstft', 'bleedless',
+ 'fullness'], help='Metric which will be used for scheduler.')
+ parser.add_argument("--train_lora", action='store_true', help="Train with LoRA")
+ parser.add_argument("--lora_checkpoint", type=str, default='', help="Initial checkpoint to LoRA weights")
+ parser.add_argument("--each_metrics_in_name", action='store_true',
+ help="All stems in naming checkpoints")
+ parser.add_argument("--use_standard_loss", action='store_true',
+ help="Roformers will use provided loss instead of internal")
+ parser.add_argument("--save_weights_every_epoch", action='store_true',
+ help="Weights will be saved every epoch with all metric values")
+ parser.add_argument("--persistent_workers", action='store_true',
+ help="dataloader persistent_workers")
+ parser.add_argument("--prefetch_factor", type=int, default=None,
+ help="dataloader prefetch_factor")
+ parser.add_argument("--set_per_process_memory_fraction", action='store_true',
+ help="using only VRAM, no RAM")
+
+ if dict_args is not None:
+ args = parser.parse_args([])
+ args_dict = vars(args)
+ args_dict.update(dict_args)
+ args = argparse.Namespace(**args_dict)
+ else:
+ args = parser.parse_args()
+
+ if args.metric_for_scheduler not in args.metrics:
+ args.metrics += [args.metric_for_scheduler]
+
+ get_internal_loss = (args.model_type in ('mel_band_conformer',) or 'roformer' in args.model_type
+ ) and not args.use_standard_loss
+ if get_internal_loss:
+ args.loss = [f'{args.model_type}_loss']
+ return args
+
+
+def parse_args_valid(dict_args: Union[Dict, None]) -> argparse.Namespace:
+ """
+ Parse command-line arguments for validation configuration.
+
+ Builds the CLI for model selection, configuration paths, validation data
+ locations, output/spectrogram saving options, device/runtime settings, and
+ evaluation metrics. If `dict_args` is provided, its key–value pairs override
+ or set the parsed arguments; otherwise arguments are read from `sys.argv`.
+
+ Args:
+ dict_args (Union[Dict, None]): Optional mapping of argument names to values
+ used to override or supply CLI options programmatically.
+
+ Returns:
+ argparse.Namespace: Parsed arguments namespace containing all validation
+ configuration values.
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_type", type=str, default='mdx23c',
+ help="One of mdx23c, htdemucs, segm_models, mel_band_roformer,"
+ " bs_roformer, swin_upernet, bandit")
+ parser.add_argument("--config_path", type=str, help="Path to config file")
+ parser.add_argument("--start_check_point", type=str, default='', help="Initial checkpoint"
+ " to valid weights")
+ parser.add_argument("--valid_path", nargs="+", type=str, help="Validate path")
+ parser.add_argument("--store_dir", type=str, default="", help="Path to store results as wav file")
+ parser.add_argument("--draw_spectro", type=float, default=0,
+ help="If --store_dir is set then code will generate spectrograms for resulted stems as well."
+ " Value defines for how many seconds os track spectrogram will be generated.")
+ parser.add_argument("--device_ids", nargs='+', type=int, default=[0], help='List of gpu ids')
+ parser.add_argument("--num_workers", type=int, default=0, help="Dataloader num_workers")
+ parser.add_argument("--pin_memory", action='store_true', help="Dataloader pin_memory")
+ parser.add_argument("--extension", type=str, default='wav', help="Choose extension for validation")
+ parser.add_argument("--use_tta", action='store_true',
+ help="Flag adds test time augmentation during inference (polarity and channel inverse)."
+ "While this triples the runtime, it reduces noise and slightly improves prediction quality.")
+ parser.add_argument("--metrics", nargs='+', type=str, default=["sdr"],
+ choices=['sdr', 'l1_freq', 'si_sdr', 'neg_log_wmse', 'aura_stft', 'aura_mrstft', 'bleedless',
+ 'fullness'], help='List of metrics to use.')
+ parser.add_argument("--lora_checkpoint", type=str, default='', help="Initial checkpoint to LoRA weights")
+
+ if dict_args is not None:
+ args = parser.parse_args([])
+ args_dict = vars(args)
+ args_dict.update(dict_args)
+ args = argparse.Namespace(**args_dict)
+ else:
+ args = parser.parse_args()
+
+ return args
+
+
+def parse_args_inference(dict_args: Union[Dict, None]) -> argparse.Namespace:
+ """
+ Parse command-line arguments for inference configuration.
+
+ Builds the CLI for model selection, configuration path, input/output handling,
+ device/runtime options, test-time augmentation, and optional LoRA checkpoints.
+ If `dict_args` is provided, its key–value pairs override or supply CLI options
+ programmatically; otherwise, arguments are read from `sys.argv`.
+
+ Args:
+ dict_args (Union[Dict, None]): Optional mapping of argument names to values
+ used to override or supply CLI options programmatically.
+
+ Returns:
+ argparse.Namespace: Parsed arguments namespace containing all inference
+ configuration values.
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_type", type=str, default='mdx23c',
+ help="One of bandit, bandit_v2, bs_roformer, htdemucs, mdx23c, mel_band_roformer,"
+ " scnet, scnet_unofficial, segm_models, swin_upernet, torchseg")
+ parser.add_argument("--config_path", type=str, help="path to config file")
+ parser.add_argument("--start_check_point", type=str, default='', help="Initial checkpoint to valid weights")
+ parser.add_argument("--input_folder", type=str, help="folder with mixtures to process")
+ parser.add_argument("--store_dir", type=str, default="", help="path to store results as wav file")
+ parser.add_argument("--draw_spectro", type=float, default=0,
+ help="Code will generate spectrograms for resulted stems."
+ " Value defines for how many seconds os track spectrogram will be generated.")
+ parser.add_argument("--device_ids", nargs='+', type=int, default=0, help='list of gpu ids')
+ parser.add_argument("--extract_instrumental", action='store_true',
+ help="invert vocals to get instrumental if provided")
+ parser.add_argument("--disable_detailed_pbar", action='store_true', help="disable detailed progress bar")
+ parser.add_argument("--force_cpu", action='store_true', help="Force the use of CPU even if CUDA is available")
+ parser.add_argument("--flac_file", action='store_true', help="Output flac file instead of wav")
+ parser.add_argument("--pcm_type", type=str, choices=['PCM_16', 'PCM_24'], default='PCM_24',
+ help="PCM type for FLAC files (PCM_16 or PCM_24)")
+ parser.add_argument("--use_tta", action='store_true',
+ help="Flag adds test time augmentation during inference (polarity and channel inverse)."
+ "While this triples the runtime, it reduces noise and slightly improves prediction quality.")
+ parser.add_argument("--lora_checkpoint", type=str, default='', help="Initial checkpoint to LoRA weights")
+
+ if dict_args is not None:
+ args = parser.parse_args([])
+ args_dict = vars(args)
+ args_dict.update(dict_args)
+ args = argparse.Namespace(**args_dict)
+ else:
+ args = parser.parse_args()
+
+ return args
+
+
+def load_config(model_type: str, config_path: str) -> Union[ConfigDict, OmegaConf]:
+ """
+ Load a model configuration from a file.
+
+ Based on `model_type`, returns either an OmegaConf (e.g., for 'htdemucs')
+ or a YAML-parsed ConfigDict for other models.
+
+ Args:
+ model_type (str): Model identifier that determines the loader behavior
+ (e.g., 'htdemucs', 'mdx23c', etc.).
+ config_path (str): Path to the configuration file (YAML/OmegaConf).
+
+ Returns:
+ Union[ConfigDict, OmegaConf]: Loaded configuration object.
+
+ Raises:
+ FileNotFoundError: If `config_path` does not point to an existing file.
+ ValueError: If the configuration cannot be parsed or is otherwise invalid.
+ """
+ try:
+ with open(config_path, 'r') as f:
+ if model_type == 'htdemucs':
+ config = OmegaConf.load(config_path)
+ else:
+ config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
+ return config
+ except FileNotFoundError:
+ raise FileNotFoundError(f"Configuration file not found at {config_path}")
+ except Exception as e:
+ raise ValueError(f"Error loading configuration: {e}")
+
+
+def get_model_from_config(model_type: str, config_path: str) -> Tuple[nn.Module, Union[ConfigDict, OmegaConf]]:
+ """
+ Load and instantiate a model using a configuration file.
+
+ Given a `model_type` and a path to a configuration, this function loads the
+ configuration (YAML or OmegaConf) and constructs the corresponding model.
+
+ Args:
+ model_type (str): Identifier of the model family (e.g., 'mdx23c', 'htdemucs',
+ 'scnet', 'mel_band_conformer', etc.).
+ config_path (str): Filesystem path to the configuration file used to
+ initialize the model.
+
+ Returns:
+ Tuple[nn.Module, Union[ConfigDict, OmegaConf]]: A tuple containing the
+ initialized PyTorch model and the loaded configuration object.
+
+ Raises:
+ ValueError: If `model_type` is unknown or model initialization fails.
+ FileNotFoundError: If `config_path` does not exist (may be raised by the
+ underlying config loader).
+ """
+
+ config = load_config(model_type, config_path)
+
+ if model_type == 'mel_band_roformer':
+ from ..modules.bs_roformer import MelBandRoformer
+ model = MelBandRoformer(**dict(config.model))
+ else:
+ raise ValueError(f"Unknown model type: {model_type}")
+
+ return model, config
+
+
+def logging(logs: List[str], text: str, verbose_logging: bool = False) -> None:
+ """
+ Print a log message and optionally append it to an in-memory list.
+
+ In Distributed Data Parallel (DDP) contexts, the message is printed only on
+ rank 0; when DDP is uninitialized, it prints unconditionally. If
+ `verbose_logging` is True, the message is also appended to `logs`.
+
+ Args:
+ logs (List[str]): Mutable list to which the message is appended when
+ `verbose_logging` is True.
+ text (str): The log message to print (rank 0 only under DDP) and
+ optionally store.
+ verbose_logging (bool, optional): If True, append `text` to `logs`.
+ Defaults to False.
+
+ Returns:
+ None: The function prints and may mutate `logs` in place.
+ """
+ if not dist.is_initialized() or dist.get_rank()==0:
+ print(text)
+ if verbose_logging:
+ logs.append(text)
+
+
+def write_results_in_file(store_dir: str, logs: List[str]) -> None:
+ """
+ Write accumulated log messages to a results file.
+
+ Creates (or overwrites) a `results.txt` file inside `store_dir` and writes
+ each entry from `logs` as a separate line. In Distributed Data Parallel (DDP)
+ scenarios, writing is intended to occur only on rank 0.
+
+ Args:
+ store_dir (str): Directory path where `results.txt` will be saved.
+ logs (List[str]): Ordered collection of log lines to write.
+
+ Returns:
+ None
+ """
+ if not dist.is_initialized() or dist.get_rank() == 0:
+ with open(f'{store_dir}/results.txt', 'w') as out:
+ for item in logs:
+ out.write(item + "\n")
+
+
+def manual_seed(seed: int) -> None:
+ """
+ Initialize random seeds for reproducibility.
+
+ Sets the seed across Python's `random`, NumPy, and PyTorch (CPU and CUDA)
+ libraries, and updates the `PYTHONHASHSEED` environment variable. This helps
+ ensure deterministic behavior where possible, though some GPU operations
+ may still introduce nondeterminism.
+
+ Args:
+ seed (int): The seed value to use for all random number generators.
+
+ Returns:
+ None
+ """
+
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed) # if multi-GPU
+ torch.backends.cudnn.deterministic = False
+ os.environ["PYTHONHASHSEED"] = str(seed)
+
+
+def initialize_environment(seed: int, results_path: str) -> None:
+ """
+ Initialize runtime environment settings.
+
+ Sets random seeds for reproducibility, adjusts PyTorch cuDNN behavior,
+ configures multiprocessing with the 'spawn' start method, and ensures
+ the results directory exists.
+
+ Args:
+ seed (int): Random seed value for deterministic initialization.
+ results_path (str): Filesystem path to create for saving results.
+
+ Returns:
+ None
+ """
+
+ manual_seed(seed)
+ torch.backends.cudnn.deterministic = False
+ try:
+ torch.multiprocessing.set_start_method('spawn')
+ except Exception as e:
+ pass
+ os.makedirs(results_path, exist_ok=True)
+
+
+def initialize_environment_ddp(rank: int, world_size: int, seed: int = 0, resuls_path: str = None) -> None:
+ """
+ Initialize environment for Distributed Data Parallel (DDP) training/validation.
+
+ Sets up the DDP process group, seeds random number generators, configures
+ multiprocessing to use the 'spawn' method, and creates a results directory
+ if provided.
+
+ Args:
+ rank (int): Rank of the current process within the DDP group.
+ world_size (int): Total number of processes participating in DDP.
+ seed (int, optional): Random seed for reproducibility. Defaults to 0.
+ resuls_path (str, optional): Directory path to create for storing results.
+ If None, no directory is created. Defaults to None.
+
+ Returns:
+ None
+ """
+
+ setup_ddp(rank, world_size)
+ manual_seed(seed)
+
+ try:
+ torch.multiprocessing.set_start_method('spawn', force=True) # force=True prevent errors
+ except RuntimeError as e:
+ if "context has already been set" not in str(e):
+ raise e
+ if not(resuls_path is None):
+ os.makedirs(resuls_path, exist_ok=True)
+
+
+def gen_wandb_name(args, config) -> str:
+ """
+ Generate a descriptive name for a Weights & Biases (wandb) run.
+
+ Combines the model type, a dash-joined list of training instruments,
+ and the current date into a single string identifier.
+
+ Args:
+ args: Parsed arguments namespace containing at least `model_type`.
+ config: Configuration object/dict with a `training.instruments` field.
+
+ Returns:
+ str: Formatted run name in the form
+ "_[--...]_".
+ """
+
+ instrum = '-'.join(config['training']['instruments'])
+ time_str = time.strftime("%Y-%m-%d")
+ name = '{}_[{}]_{}'.format(args.model_type, instrum, time_str)
+ return name
+
+
+def wandb_init(args: argparse.Namespace, config: Dict, batch_size: int) -> None:
+ """
+ Initialize Weights & Biases (wandb) for experiment tracking.
+
+ Depending on the provided arguments, sets up wandb in one of three modes:
+ - Offline mode when `args.wandb_offline` is True.
+ - Disabled mode when no valid `wandb_key` is provided.
+ - Online mode with authentication using `args.wandb_key`.
+
+ Args:
+ args (argparse.Namespace): Parsed arguments containing wandb options
+ (`wandb_offline`, `wandb_key`, `device_ids`).
+ config (Dict): Experiment configuration dictionary to log.
+ batch_size (int): Training batch size to include in the run configuration.
+
+ Returns:
+ None
+ """
+
+ if not dist.is_initialized() or dist.get_rank() == 0:
+ if args.wandb_offline:
+ wandb.init(mode='offline',
+ project='msst',
+ name=gen_wandb_name(args, config),
+ config={'config': config, 'args': args, 'device_ids': args.device_ids, 'batch_size': batch_size}
+ )
+ elif args.wandb_key is None or args.wandb_key.strip() == '':
+ wandb.init(mode='disabled')
+ else:
+ wandb.login(key=args.wandb_key)
+ wandb.init(
+ project='msst',
+ name=gen_wandb_name(args, config),
+ config={'config': config, 'args': args, 'device_ids': args.device_ids, 'batch_size': batch_size}
+ )
+
+
+def setup_ddp(rank: int, world_size: int) -> None:
+ """
+ Initialize a Distributed Data Parallel (DDP) process group.
+
+ Configures environment variables for the DDP master node, attempts to
+ initialize the process group with the NCCL backend (preferred for GPUs),
+ and falls back to the Gloo backend if NCCL is unavailable. Also sets the
+ current CUDA device to match the process rank.
+
+ Args:
+ rank (int): Rank of the current process in the DDP group.
+ world_size (int): Total number of processes participating in DDP.
+
+ Returns:
+ None
+ """
+
+ os.environ['MASTER_ADDR'] = 'localhost'
+ os.environ['MASTER_PORT'] = '12355' # We can change and use another
+ os.environ["USE_LIBUV"] = "0"
+ try:
+ dist.init_process_group("nccl", rank=rank, world_size=world_size)
+ except:
+ dist.init_process_group("gloo", rank=rank, world_size=world_size)
+ if dist.get_rank()==0:
+ print(f'NCCL are not available. Using "gloo" backend.')
+
+ torch.cuda.set_device(rank)
+
+
+def cleanup_ddp() -> None:
+ """
+ Finalize and clean up a Distributed Data Parallel (DDP) process group.
+
+ Calls `torch.distributed.destroy_process_group()` to release resources
+ associated with the current DDP environment.
+
+ Returns:
+ None
+ """
+ dist.destroy_process_group()
diff --git a/preprocess/utils.py b/preprocess/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..745ace934613e8f58f8a4c0e6bd91b10908e87ba
--- /dev/null
+++ b/preprocess/utils.py
@@ -0,0 +1,224 @@
+import os
+from pathlib import Path
+from typing import List, Optional, Dict
+
+import numpy as np
+import soundfile as sf
+from dataclasses import dataclass
+
+from preprocess.tools.g2p import g2p_transform
+
+
+@dataclass
+class SegmentMetadata:
+ item_name: str
+ wav_fn: str
+ language: str
+ start_time_ms: int
+ end_time_ms: int
+ note_text: List[str]
+ note_dur: List[float]
+ note_pitch: List[int]
+ note_type: List[int]
+ origin_wav_fn: Optional[str] = None
+
+
+def _merge_group(
+ audio: np.ndarray,
+ sample_rate: int,
+ segments: List[SegmentMetadata],
+ output_dir: Path,
+ end_extension_ms: int = 0,
+) -> SegmentMetadata:
+ """
+ Merge a group of consecutive segments into a single segment.
+
+ This function:
+ - Concatenates note-level information
+ - Inserts for silence gaps
+ - Merges consecutive
+ - Cuts and writes merged audio
+ - Determines dominant language
+
+ Args:
+ audio: Full vocal audio waveform (T,)
+ sample_rate: Audio sample rate
+ segments: Consecutive segments to be merged (SegmentMetadata or dict)
+ output_dir: Directory to save merged wav
+ end_extension_ms: Extra silence appended to the end (ms)
+
+ Returns:
+ A merged SegmentMetadata instance
+ """
+ if not segments:
+ raise ValueError("segments must not be empty")
+
+ # Helper function to get attributes from either SegmentMetadata or dict
+ def get_attr(seg, attr_name, default=None):
+ if isinstance(seg, dict):
+ return seg.get(attr_name, default)
+ return getattr(seg, attr_name, default)
+
+ # ---------- concat notes ----------
+ words: List[str] = []
+ durs: List[float] = []
+ pitches: List[int] = []
+ types: List[int] = []
+
+ for i, seg in enumerate(segments):
+ if i > 0:
+ prev_seg = segments[i - 1]
+ gap_ms = (
+ get_attr(seg, "start_time_ms", 0)
+ - get_attr(prev_seg, "end_time_ms", 0)
+ )
+ if gap_ms > 0:
+ words.append("")
+ durs.append(gap_ms / 1000.0)
+ pitches.append(0)
+ types.append(1)
+
+ words.extend(get_attr(seg, "note_text", []))
+ durs.extend(get_attr(seg, "note_dur", []))
+ pitches.extend(get_attr(seg, "note_pitch", []))
+ types.extend(get_attr(seg, "note_type", []))
+
+ if end_extension_ms > 0:
+ words.append("")
+ durs.append(end_extension_ms / 1000.0)
+ pitches.append(0)
+ types.append(1)
+
+ # ---------- merge consecutive ----------
+ merged_words, merged_durs, merged_pitches, merged_types = [], [], [], []
+ for w, d, p, t in zip(words, durs, pitches, types):
+ if merged_words and w == "" and merged_words[-1] == "":
+ merged_durs[-1] += d
+ else:
+ merged_words.append(w)
+ merged_durs.append(d)
+ merged_pitches.append(p)
+ merged_types.append(t)
+
+ # ---------- dominant language ----------
+ languages = [get_attr(s, "language", "Mandarin") for s in segments if get_attr(s, "language")]
+ language = (
+ max(languages, key=languages.count)
+ if languages
+ else "Mandarin"
+ )
+
+ # ---------- time & audio ----------
+ start_ms = get_attr(segments[0], "start_time_ms", 0)
+ end_ms = get_attr(segments[-1], "end_time_ms", 0) + end_extension_ms
+ start_sample = start_ms * sample_rate // 1000
+ end_sample = end_ms * sample_rate // 1000
+
+ # ---------- naming ----------
+ first_item_name = get_attr(segments[0], "item_name", "segment")
+ song_prefix = "_".join(first_item_name.split("_")[:-1])
+ item_name = f"{song_prefix}_{start_ms}_{end_ms}"
+
+ wav_path = output_dir / f"{item_name}.wav"
+ sf.write(
+ wav_path,
+ audio[start_sample:end_sample],
+ sample_rate,
+ )
+
+ return SegmentMetadata(
+ item_name=item_name,
+ wav_fn=str(wav_path),
+ language=language,
+ start_time_ms=start_ms,
+ end_time_ms=end_ms,
+ note_text=merged_words,
+ note_dur=merged_durs,
+ note_pitch=merged_pitches,
+ note_type=merged_types,
+ origin_wav_fn=get_attr(segments[0], "origin_wav_fn", ""),
+ )
+
+
+def convert_metadata(item) -> Dict:
+ """
+ Convert internal SegmentMetadata into final json-serializable format.
+ """
+ f0_path = item.wav_fn.replace(".wav", "_f0.npy")
+ f0 = np.load(f0_path)
+
+ return {
+ "index": item.item_name,
+ "language": item.language,
+ "time": [item.start_time_ms, item.end_time_ms],
+ "duration": " ".join(f"{d:.2f}" for d in item.note_dur),
+ "text": " ".join(item.note_text),
+ "phoneme": " ".join(
+ g2p_transform(item.note_text, item.language)
+ ),
+ "note_pitch": " ".join(map(str, item.note_pitch)),
+ "note_type": " ".join(map(str, item.note_type)),
+ "f0": " ".join(f"{x:.1f}" for x in f0),
+ }
+
+
+def merge_short_segments(
+ audio: np.ndarray,
+ sample_rate: int,
+ segments: List[SegmentMetadata],
+ output_dir: str,
+ max_gap_ms: int = 10000,
+ max_duration_ms: int = 60000,
+ end_extension_ms: int = 0,
+) -> List[SegmentMetadata]:
+ """
+ Merge short segments into longer audio chunks.
+
+ Args:
+ audio: Full vocal audio waveform
+ sample_rate: Audio sample rate
+ segments: List of SegmentMetadata or dict objects
+ output_dir: Directory to save merged segments
+ max_gap_ms: Maximum gap between segments to merge (ms)
+ max_duration_ms: Maximum duration of merged segment (ms)
+ end_extension_ms: Extra silence to append at end (ms)
+
+ Returns:
+ List of merged SegmentMetadata objects
+ """
+ os.makedirs(output_dir, exist_ok=True)
+
+ merged_segments = []
+ current_group = []
+ current_len = 0
+ prev_end = -1
+
+ for seg in segments:
+ if isinstance(seg, dict):
+ start_time = seg.get("start_time_ms", 0)
+ end_time = seg.get("end_time_ms", 0)
+ else:
+ start_time = seg.start_time_ms
+ end_time = seg.end_time_ms
+
+ if (
+ current_group
+ and (start_time - prev_end > max_gap_ms
+ or current_len + end_time - start_time > max_duration_ms)
+ ):
+ merged_segments.append(_merge_group(
+ audio, sample_rate, current_group, output_dir, end_extension_ms
+ ))
+ current_group = []
+ current_len = 0
+
+ current_group.append(seg)
+ current_len += end_time - start_time
+ prev_end = end_time
+
+ if current_group:
+ merged_segments.append(_merge_group(
+ audio, sample_rate, current_group, output_dir, end_extension_ms
+ ))
+
+ return merged_segments
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3e4d6d086f95af0ed82a5feb32966cb3b0383afb
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,13 @@
+accelerate==1.11.0
+gradio==6.3.0
+huggingface_hub>=0.20.0
+librosa==0.11.0
+numpy==2.2.6
+omegaconf==2.3.0
+scipy==1.15.3
+soundfile==0.13.1
+torch==2.10.0
+torchaudio==2.10.0
+torchcodec==0.10.0
+tqdm==4.67.1
+transformers==4.41.2
diff --git a/soulxsinger/__init__.py b/soulxsinger/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/soulxsinger/config/soulxsinger.yaml b/soulxsinger/config/soulxsinger.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b2964b328783451e6239e78e1e2558fcfe404ffb
--- /dev/null
+++ b/soulxsinger/config/soulxsinger.yaml
@@ -0,0 +1,37 @@
+infer:
+ n_steps: 32
+ cfg: 3
+
+audio:
+ hop_size: 480
+ sample_rate: 24000
+ max_length: 36000
+ n_fft: 1920
+ num_mels: 128
+ win_size: 1920
+ fmin: 0
+ fmax: 12000
+ mel_var: 8.14
+ mel_mean: -4.92
+
+model:
+ encoder:
+ vocab_size: 3000
+ text_dim: 512
+ pitch_dim: 512
+ type_dim: 512
+ f0_bin: 361
+ f0_dim: 512
+ num_layers: 4
+
+ flow_matching:
+ mel_dim: 128
+ hidden_size: 1024
+ num_layers: 22
+ num_heads: 16
+ cfg_drop_prob: 0.2
+ use_embedding: False
+ cond_codebook_size: 512
+ cond_scale_factor: 1
+ sigma: 1e-5
+ time_scheduler: cos
\ No newline at end of file
diff --git a/soulxsinger/models/modules/convnext.py b/soulxsinger/models/modules/convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3ff59181ab4cb53b3fad6577de6f4d4f4b25007
--- /dev/null
+++ b/soulxsinger/models/modules/convnext.py
@@ -0,0 +1,46 @@
+import torch.nn as nn
+import torch
+
+
+class GRN(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
+ self.beta = nn.Parameter(torch.zeros(1, 1, dim))
+
+ def forward(self, x):
+ Gx = torch.norm(x, p=2, dim=1, keepdim=True)
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
+ return self.gamma * (x * Nx) + self.beta + x
+
+
+# ref: https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/model/modules.py#L247
+class ConvNeXtV2Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ intermediate_dim: int,
+ dilation: int = 1,
+ ):
+ super().__init__()
+ padding = (dilation * (7 - 1)) // 2
+ self.dwconv = nn.Conv1d(
+ dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
+ ) # depthwise conv
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.grn = GRN(intermediate_dim)
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ residual = x
+ x = x.transpose(1, 2) # b n d -> b d n
+ x = self.dwconv(x)
+ x = x.transpose(1, 2) # b d n -> b n d
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.grn(x)
+ x = self.pwconv2(x)
+ return residual + x
diff --git a/soulxsinger/models/modules/decoder.py b/soulxsinger/models/modules/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab30f0f32638a06a121e714143b77ee52d16af23
--- /dev/null
+++ b/soulxsinger/models/modules/decoder.py
@@ -0,0 +1,29 @@
+import torch
+import torch.nn as nn
+
+from soulxsinger.models.modules.flow_matching import FlowMatchingTransformer
+
+
+class CFMDecoder(nn.Module):
+ def __init__(self, config):
+ super(CFMDecoder, self).__init__()
+ self.model = FlowMatchingTransformer(cfg=config, **config)
+
+ def forward(self, mel, x_mask, decoder_inp, is_prompt):
+ outputs = self.model(mel, x_mask, decoder_inp, is_prompt)
+
+ noise, x, flow_pred, final_mask, prompt_len = outputs["output"]
+ return noise, x, flow_pred, final_mask, prompt_len
+
+ def reverse_diffusion(self, pt_mel, pt_decoder_inp, gt_decoder_inp, n_timesteps=32, cfg=1):
+ diffusion_cond = torch.cat([pt_decoder_inp, gt_decoder_inp], dim=1)
+ diffusion_cond_emb = self.model.cond_emb(diffusion_cond)
+ diffusion_prompt = pt_mel
+
+ generated = self.model.reverse_diffusion(
+ diffusion_cond_emb,
+ diffusion_prompt,
+ n_timesteps=n_timesteps,
+ cfg=cfg
+ )
+ return generated
\ No newline at end of file
diff --git a/soulxsinger/models/modules/flow_matching.py b/soulxsinger/models/modules/flow_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..47841857430fb1d66e5d47e2a84b0692047d53bb
--- /dev/null
+++ b/soulxsinger/models/modules/flow_matching.py
@@ -0,0 +1,445 @@
+# https://github.com/open-mmlab/Amphion/blob/main/models/svc/flow_matching_transformer/fmt_model.py
+
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import numpy as np
+import torch.nn as nn
+import math
+from .llama import DiffLlama
+import torch.nn.functional as F
+
+
+class FlowMatchingTransformer(nn.Module):
+ def __init__(
+ self,
+ mel_dim=100,
+ hidden_size=1024,
+ num_layers=12,
+ num_heads=16,
+ cfg_drop_prob=0.2,
+ use_embedding=True,
+ cond_codebook_size=1024,
+ cond_scale_factor=1,
+ sigma=1e-5,
+ time_scheduler="linear",
+ cfg=None,
+ ):
+ super().__init__()
+ self.cfg = cfg
+
+ if cfg is not None:
+ mel_dim = getattr(cfg, "mel_dim", mel_dim)
+ hidden_size = getattr(cfg, "hidden_size", hidden_size)
+ num_layers = getattr(cfg, "num_layers", num_layers)
+ num_heads = getattr(cfg, "num_heads", num_heads)
+ cfg_drop_prob = getattr(cfg, "cfg_drop_prob", cfg_drop_prob)
+ cond_codebook_size = getattr(cfg, "cond_codebook_size", cond_codebook_size)
+ time_scheduler = getattr(cfg, "time_scheduler", time_scheduler)
+ sigma = getattr(cfg, "sigma", sigma)
+ cond_scale_factor = getattr(cfg, "cond_scale_factor", cond_scale_factor)
+
+ self.mel_dim = mel_dim
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.num_heads = num_heads
+ self.cfg_drop_prob = cfg_drop_prob
+ self.cond_codebook_size = cond_codebook_size
+ self.time_scheduler = time_scheduler
+ self.sigma = sigma
+ self.cond_scale_factor = cond_scale_factor
+
+ if use_embedding:
+ self.cond_emb = nn.Embedding(cond_codebook_size, self.hidden_size)
+ else:
+ self.cond_emb = nn.Linear(cond_codebook_size, self.hidden_size)
+
+ if cond_scale_factor != 1:
+ self.do_resampling = True
+ assert np.log2(cond_scale_factor).is_integer()
+
+ up_layers = []
+ for _ in range(int(np.log2(cond_scale_factor))):
+ up_layers.extend(
+ [
+ nn.ConvTranspose1d(
+ hidden_size, hidden_size, kernel_size=4, stride=2, padding=1
+ ),
+ nn.GELU(),
+ ]
+ )
+ self.resampling_layers = nn.Sequential(*up_layers)
+ else:
+ self.do_resampling = False
+
+ ### REPA: Use the Wav2Vec2Bert features to align. ###
+ self.use_repa = "repa" in cfg
+ self.repa_layer_index = None
+ if self.use_repa:
+ self.repa_layer_index = cfg.repa.layer_index
+
+ self.repa_mlp_layer = nn.Sequential(
+ nn.Linear(hidden_size, hidden_size * 4),
+ nn.SiLU(),
+ nn.Linear(hidden_size * 4, cfg.repa.output_dim),
+ )
+
+ ### CTC: Use the ASR loss ###
+ self.use_ctc = "ctc" in cfg
+ self.ctc_layer_index = None
+ if self.use_ctc:
+ self.ctc_layer_index = cfg.ctc.layer_index
+
+ self.ctc_mlp_layer = nn.Sequential(
+ nn.Linear(hidden_size, hidden_size * 4),
+ nn.SiLU(),
+ nn.Linear(hidden_size * 4, cfg.ctc.output_dim),
+ )
+
+ self.reset_parameters()
+
+ self.diff_estimator = DiffLlama(
+ mel_dim=mel_dim,
+ hidden_size=hidden_size,
+ num_heads=num_heads,
+ num_layers=num_layers,
+ )
+
+ self.sigma = sigma
+
+ @torch.no_grad()
+ def forward_diffusion(self, x, t, is_prompt=None):
+ """
+ x: (B, T, mel_dim)
+ t: (B,)
+ """
+ new_t = t
+ t = t.unsqueeze(-1).unsqueeze(-1)
+ z = torch.randn(
+ x.shape, dtype=x.dtype, device=x.device, requires_grad=False
+ ) # (B, T, mel_dim)
+
+ # get prompt len
+ if torch.rand(1) <= self.cfg_drop_prob:
+ prompt_len = torch.zeros(x.shape[0]).to(x)
+ is_prompt = torch.zeros_like(x[:, :, 0])
+ else:
+ if is_prompt is None:
+ prompt_len = torch.randint(
+ min(x.shape[1] // 4, 5), int(x.shape[1] * 0.4), (x.shape[0],)
+ ).to(
+ x.device
+ ) # (B,)
+
+ # get is_prompt
+ is_prompt = torch.zeros_like(x[:, :, 0]) # (B, T)
+ col_indices = (
+ torch.arange(is_prompt.shape[1])
+ .repeat(is_prompt.shape[0], 1)
+ .to(prompt_len)
+ ) # (B, T)
+ is_prompt[col_indices < prompt_len.unsqueeze(1)] = 1 # (B, T) 1 if prompt
+ else:
+ prompt_len = is_prompt.sum(dim=1) # (B,)
+
+ mask = torch.ones_like(x[:, :, 0]) # mask if 1, not mask if 0
+ mask[is_prompt.bool()] = 0
+ mask = mask[:, :, None]
+
+ # flow matching: xt = (1 - (1 - sigma) * t) * x0 + t * x; where x0 ~ N(0, 1), x is a sample
+ # flow gt: x - (1 - sigma) * x0 = x - (1 - sigma) * noise
+ xt = ((1 - (1 - self.sigma) * t) * z + t * x) * mask + x * (1 - mask)
+
+ return xt, z, new_t, prompt_len, mask
+
+ def loss_t(
+ self,
+ x,
+ x_mask,
+ t,
+ cond=None,
+ is_prompt=None
+ ):
+ xt, z, new_t, prompt_len, mask = self.forward_diffusion(x, t, is_prompt)
+
+ noise = z
+
+ # drop all condition for cfg, so if prompt_len is 0, we also drop cond
+ if cond is not None:
+ cond = cond * torch.where(
+ prompt_len > 0,
+ torch.ones_like(prompt_len),
+ torch.zeros_like(prompt_len),
+ ).to(cond.device).unsqueeze(-1).unsqueeze(-1)
+
+ dit_output = self.diff_estimator(xt, new_t, cond, x_mask, return_dict=True)
+ flow_pred = dit_output["output"] # (B, T, mel_dim)
+
+ # final mask used for loss calculation
+ final_mask = mask * x_mask[..., None] # (B, T, 1)
+
+ results = {"output": (noise, x, flow_pred, final_mask, prompt_len)}
+
+ if self.use_repa:
+ repa_hidden_states = dit_output["hidden_states"][
+ self.repa_layer_index
+ ] # (B, T, hidden_size)
+
+ repa_pred = self.repa_mlp_layer(repa_hidden_states) # (B, T, repa_dim)
+ results["repa"] = repa_pred
+
+ if self.use_ctc:
+ ctc_hidden_states = dit_output["hidden_states"][
+ self.ctc_layer_index
+ ] # (B, T, hidden_size)
+ ctc_pred = self.ctc_mlp_layer(ctc_hidden_states) # (B, T, ctc_dim)
+ results["ctc"] = ctc_pred
+
+ return results
+
+ def compute_loss(self, x, x_mask, cond=None, is_prompt=None):
+ # x0: (B, T, num_quantizer)
+ # x_mask: (B, T) mask is 0 for padding
+ t = torch.rand(x.shape[0], device=x.device, requires_grad=False)
+ t = torch.clamp(t, 1e-5, 1.0)
+ # from CosyVoice: considering the generation process at the beginning is harder than follows, we involve a cosine scheduler for the timestep t
+ if self.time_scheduler == "cos":
+ t = 1 - torch.cos(t * math.pi * 0.5)
+ else:
+ pass
+ return self.loss_t(x, x_mask, t, cond, is_prompt)
+
+ def reset_parameters(self):
+ def _reset_parameters(m):
+ if isinstance(m, nn.MultiheadAttention):
+ if m._qkv_same_embed_dim:
+ nn.init.normal_(m.in_proj_weight, std=0.02)
+ else:
+ nn.init.normal_(m.q_proj_weight, std=0.02)
+ nn.init.normal_(m.k_proj_weight, std=0.02)
+ nn.init.normal_(m.v_proj_weight, std=0.02)
+
+ if m.in_proj_bias is not None:
+ nn.init.constant_(m.in_proj_bias, 0.0)
+ nn.init.constant_(m.out_proj.bias, 0.0)
+ if m.bias_k is not None:
+ nn.init.xavier_normal_(m.bias_k)
+ if m.bias_v is not None:
+ nn.init.xavier_normal_(m.bias_v)
+
+ elif (
+ isinstance(m, nn.Conv1d)
+ or isinstance(m, nn.ConvTranspose1d)
+ or isinstance(m, nn.Conv2d)
+ or isinstance(m, nn.ConvTranspose2d)
+ ):
+ m.weight.data.normal_(0.0, 0.02)
+
+ elif isinstance(m, nn.Linear):
+ m.weight.data.normal_(mean=0.0, std=0.02)
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ elif isinstance(m, nn.Embedding):
+ m.weight.data.normal_(mean=0.0, std=0.02)
+ if m.padding_idx is not None:
+ m.weight.data[m.padding_idx].zero_()
+
+ self.apply(_reset_parameters)
+
+ @torch.no_grad()
+ def reverse_diffusion(
+ self,
+ cond,
+ prompt,
+ x_mask=None,
+ prompt_mask=None,
+ n_timesteps=10,
+ cfg=1.0,
+ rescale_cfg=0.75,
+ ):
+ h = 1.0 / n_timesteps
+ prompt_len = prompt.shape[1]
+ target_len = cond.shape[1] - prompt_len
+
+ if x_mask == None:
+ x_mask = torch.ones(cond.shape[0], target_len).to(cond.device) # (B, T)
+ if prompt_mask == None:
+ prompt_mask = torch.ones(cond.shape[0], prompt_len).to(
+ cond.device
+ ) # (B, prompt_len)
+ xt_mask = torch.cat([prompt_mask, x_mask], dim=1)
+ z = torch.randn(
+ (cond.shape[0], target_len, self.mel_dim),
+ dtype=cond.dtype,
+ device=cond.device,
+ requires_grad=False,
+ )
+ xt = z
+
+ # t from 0 to 1: x0 = z ~ N(0, 1)
+ for i in range(n_timesteps):
+ xt_input = torch.cat([prompt, xt], dim=1)
+ t = (0 + (i + 0.5) * h) * torch.ones(
+ z.shape[0], dtype=z.dtype, device=z.device
+ )
+ flow_pred = self.diff_estimator(xt_input, t, cond, xt_mask)
+ flow_pred = flow_pred[:, prompt_len:, :]
+
+ # cfg
+ if cfg > 0:
+ uncond_flow_pred = self.diff_estimator(
+ xt, t, torch.zeros_like(cond)[:, : xt.shape[1], :], x_mask
+ )
+ pos_flow_pred_std = flow_pred.std()
+ flow_pred_cfg = flow_pred + cfg * (flow_pred - uncond_flow_pred)
+ rescale_flow_pred = (
+ flow_pred_cfg * pos_flow_pred_std / flow_pred_cfg.std()
+ )
+ flow_pred = (
+ rescale_cfg * rescale_flow_pred + (1 - rescale_cfg) * flow_pred_cfg
+ )
+
+ dxt = flow_pred * h
+ xt = xt + dxt
+
+ return xt
+
+ @torch.no_grad()
+ def reverse_diffusion_v2(
+ self,
+ cond,
+ prompt,
+ x_mask=None,
+ prompt_mask=None,
+ n_timesteps=10,
+ cfg=1.0,
+ rescale_cfg=0.75,
+ ):
+ h = 1.0 / n_timesteps
+ prompt_len = prompt.shape[1]
+ target_len = cond.shape[1] - prompt_len * 2
+
+ if x_mask == None:
+ x_mask = torch.ones(cond.shape[0], target_len).to(cond.device) # (B, T)
+ if prompt_mask == None:
+ prompt_mask = torch.ones(cond.shape[0], prompt_len).to(
+ cond.device
+ ) # (B, prompt_len)
+ xt_mask = torch.cat([prompt_mask, x_mask, prompt_mask], dim=1)
+ z = torch.randn(
+ (cond.shape[0], target_len, self.mel_dim),
+ dtype=cond.dtype,
+ device=cond.device,
+ requires_grad=False,
+ )
+ xt = z
+
+ # t from 0 to 1: x0 = z ~ N(0, 1)
+ for i in range(n_timesteps):
+ xt_input = torch.cat([prompt, xt, prompt], dim=1)
+ t = (0 + (i + 0.5) * h) * torch.ones(
+ z.shape[0], dtype=z.dtype, device=z.device
+ )
+ flow_pred = self.diff_estimator(xt_input, t, cond, xt_mask)
+ flow_pred = flow_pred[:, prompt_len:-prompt_len, :]
+
+ # cfg
+ if cfg > 0:
+ uncond_flow_pred = self.diff_estimator(
+ xt, t, torch.zeros_like(cond)[:, : xt.shape[1], :], x_mask
+ )
+ pos_flow_pred_std = flow_pred.std()
+ flow_pred_cfg = flow_pred + cfg * (flow_pred - uncond_flow_pred)
+ rescale_flow_pred = (
+ flow_pred_cfg * pos_flow_pred_std / flow_pred_cfg.std()
+ )
+ flow_pred = (
+ rescale_cfg * rescale_flow_pred + (1 - rescale_cfg) * flow_pred_cfg
+ )
+
+ dxt = flow_pred * h
+ xt = xt + dxt
+
+ return xt
+
+ def forward(self, x, x_mask, cond_code, is_prompt=None):
+ """
+ Args:
+ x: (B, T, mel_dim)
+ x_mask: (B, T)
+ cond_code: (B, T), Note that cond_code might be not at 50Hz!
+ """
+ T = x.shape[1]
+
+ cond = self.cond_emb(cond_code) # (B, T, hidden_size)
+ if self.do_resampling:
+ # Align to the frame rate of Mels
+ cond = self.resampling_layers(cond.transpose(1, 2)).transpose(1, 2)
+
+ # print("cond_code: {}, after resampling: {}".format(cond_code.shape, cond.shape))
+
+ if cond.shape[1] >= T: # Check time dimension
+ cond = cond[:, :T, :]
+ else:
+ padding_frames = T - cond.shape[1]
+ last_frame = cond[:, -1:, :]
+ padding = last_frame.repeat(1, padding_frames, 1)
+ cond = torch.cat([cond, padding], dim=1)
+
+ return self.compute_loss(x, x_mask, cond, is_prompt)
+
+
+if __name__ == "__main__":
+
+ model_cfg = {
+ "mel_dim": 128,
+ "hidden_size": 256,
+ "num_layers": 8,
+ "num_heads": 8,
+ "cfg_drop_prob": 0.2,
+ "use_embedding": False,
+ "cond_codebook_size": 256,
+ "cond_scale_factor": 1,
+ "sigma": 1e-5,
+ "time_scheduler": "cos",
+ }
+
+ device = "cuda"
+ x = torch.randn(2, 100, 128).to(device)
+ x_mask = torch.ones(2, 100).to(device)
+ # cond_code = torch.randint(0, 16384, (2, 25)).to(device)
+ cond_code = torch.randn(2, 100, 256).to(device)
+
+ model = FlowMatchingTransformer(cfg=model_cfg, **model_cfg).to(device)
+ outputs = model(x, x_mask, cond_code)
+ print(outputs)
+
+ noise, x, flow_pred, final_mask, prompt_len = outputs["output"]
+ final_mask = final_mask.squeeze(-1)
+
+ flow_gt = x - (1 - 1e-5) * noise
+
+ # [B, n_frames, D]
+ diff_loss = F.l1_loss(
+ flow_pred, flow_gt, reduction="none"
+ ).float() * final_mask.unsqueeze(-1)
+ diff_loss = torch.mean(diff_loss, dim=2).sum() / final_mask.sum()
+
+ print("diff_loss:", diff_loss.item())
+
+
+ diffusion_cond = torch.randn(2, 150, 256).to(device)
+ diffusion_cond_emb = model.cond_emb(diffusion_cond)
+ diffusion_prompt = torch.randn(2, 50, 128).to(device)
+ n_timesteps = 32
+
+ generated = model.reverse_diffusion(
+ diffusion_cond_emb,
+ diffusion_prompt,
+ n_timesteps=n_timesteps
+ )
+ print("generated:", generated.shape)
\ No newline at end of file
diff --git a/soulxsinger/models/modules/llama.py b/soulxsinger/models/modules/llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9fee9c17f6cef5c020808c81d9a67d99cfc0287
--- /dev/null
+++ b/soulxsinger/models/modules/llama.py
@@ -0,0 +1,392 @@
+from transformers import LlamaConfig, LlamaModel
+import torch
+import torch.nn as nn
+from typing import List, Optional, Tuple, Union
+import math
+
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.llama.modeling_llama import BaseModelOutputWithPast
+
+
+# sinusoidal positional encoding
+class SinusoidalPosEmb(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x):
+ device = x.device
+ half_dim = self.dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+ emb = x[:, None] * emb[None, :] * 1.0
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+ return emb
+
+
+class LlamaAdaptiveRMSNorm(nn.Module):
+ def __init__(self, hidden_size=1024, eps=1e-6, dim_cond=1024):
+ super().__init__()
+ self.to_weight = nn.Linear(dim_cond, hidden_size)
+ nn.init.zeros_(self.to_weight.weight)
+ nn.init.ones_(self.to_weight.bias)
+ self.variance_epsilon = eps
+ self._is_hf_initialized = True # disable automatic init
+
+ def forward(self, hidden_states, cond_embedding):
+ input_dtype = hidden_states.dtype
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ weight = self.to_weight(cond_embedding)
+ if len(weight.shape) == 2:
+ weight = weight.unsqueeze(1)
+
+ return (weight * hidden_states).to(input_dtype)
+
+
+class LlamaNARDecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: LlamaConfig, layer_idx: int):
+ """Override to adaptive layer norm"""
+ super().__init__(config, layer_idx) # init attention, mlp, etc.
+ self.input_layernorm = LlamaAdaptiveRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
+ )
+ self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
+ )
+
+ # add `cond` in forward function
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cond_embedding: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ ) -> Tuple[
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+ ]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(
+ hidden_states, cond_embedding=cond_embedding
+ )
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(
+ hidden_states, cond_embedding=cond_embedding
+ )
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class DiffLlama(LlamaModel):
+ def __init__(
+ self,
+ mel_dim=100,
+ hidden_size=1024,
+ num_heads=16,
+ num_layers=16,
+ dropout=0.1,
+ ffn_dropout=0.1,
+ attention_dropout=0.0,
+ config=LlamaConfig(0, 256, 1024, 1, 1),
+ ):
+ super().__init__(config)
+
+ self.layers = nn.ModuleList(
+ [
+ LlamaNARDecoderLayer(
+ LlamaConfig(
+ hidden_size=hidden_size,
+ num_attention_heads=num_heads,
+ max_position_embeddings=4096,
+ intermediate_size=hidden_size * 4,
+ ),
+ layer_idx=i,
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)
+
+ self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
+ self.diff_step_mlp = nn.Sequential(
+ nn.Linear(hidden_size, hidden_size * 4),
+ nn.SiLU(),
+ nn.Linear(hidden_size * 4, hidden_size),
+ )
+
+ self.cond_mlp = nn.Sequential(
+ nn.Linear(hidden_size, hidden_size * 4),
+ nn.SiLU(),
+ nn.Linear(hidden_size * 4, hidden_size),
+ )
+
+ self.mel_mlp = nn.Sequential(
+ nn.Linear(mel_dim, hidden_size * 4),
+ nn.SiLU(),
+ nn.Linear(hidden_size * 4, hidden_size),
+ )
+
+ self.mel_out_mlp = nn.Sequential(
+ nn.Linear(hidden_size, hidden_size * 4),
+ nn.SiLU(),
+ nn.Linear(hidden_size * 4, mel_dim),
+ )
+
+ for layer in self.layers:
+ layer.input_layernorm = LlamaAdaptiveRMSNorm(
+ hidden_size, dim_cond=hidden_size
+ )
+ layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
+ hidden_size, dim_cond=hidden_size
+ )
+
+ self.embed_tokens = None
+
+ self.post_init()
+
+ # self.reset_parameters()
+
+ def _prepare_decoder_attention_mask(
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+ ):
+ # create noncausal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+
+ def _expand_mask(
+ mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
+ ):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = (
+ mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+ )
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
+ )
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ ).to(inputs_embeds.device)
+ combined_attention_mask = (
+ expanded_attn_mask
+ if combined_attention_mask is None
+ else expanded_attn_mask + combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ def forward(
+ self,
+ x,
+ diffusion_step,
+ cond,
+ x_mask,
+ input_ids: torch.LongTensor = None, # [num_quant, B, T]
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = False,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+ # retrieve some shape info
+ batch_size, seq_length, _ = x.shape
+
+ # condtion mlp
+ cond_embedding = self.cond_mlp(cond) # (B, T, C)
+
+ # condition mel
+ x = self.mel_mlp(x)
+
+ # diffusion step embedding
+ diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device)
+ diffusion_step = self.diff_step_mlp(diffusion_step) # (B, C)
+ x = x + cond_embedding
+
+ inputs_embeds = x
+ attention_mask = x_mask
+
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device,
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past),
+ dtype=torch.bool,
+ device=inputs_embeds.device,
+ )
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ )
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ all_layer_hidden_states = []
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = (
+ past_key_values[idx] if past_key_values is not None else None
+ )
+
+ if self.gradient_checkpointing and self.training:
+ raise NotImplementedError
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, None)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ None,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cond_embedding=diffusion_step,
+ )
+
+ hidden_states = layer_outputs[0]
+ all_layer_hidden_states.append(hidden_states.clone())
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+
+ hidden_states = self.mel_out_mlp(hidden_states)
+
+ # if not return_dict:
+ # return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ # return BaseModelOutputWithPast(
+ # last_hidden_state=hidden_states,
+ # past_key_values=next_cache,
+ # hidden_states=all_hidden_states,
+ # attentions=all_self_attns,
+ # )
+ if return_dict:
+ return {
+ "output": hidden_states,
+ "hidden_states": all_layer_hidden_states,
+ }
+
+ return hidden_states
diff --git a/soulxsinger/models/modules/mel_transform.py b/soulxsinger/models/modules/mel_transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..babcf59752c129039e9969ef8382a8ca1c700f13
--- /dev/null
+++ b/soulxsinger/models/modules/mel_transform.py
@@ -0,0 +1,151 @@
+import torch
+import math
+import numpy as np
+from librosa.filters import mel as librosa_mel_fn
+import torch.nn as nn
+from typing import Any, Dict, Optional
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+ return np.exp(x) / C
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+ return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+ return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+ output = dynamic_range_compression_torch(magnitudes)
+ return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+ output = dynamic_range_decompression_torch(magnitudes)
+ return output
+
+
+class MelSpectrogram(nn.Module):
+ def __init__(
+ self,
+ n_fft,
+ num_mels,
+ sampling_rate,
+ hop_size,
+ win_size,
+ fmin,
+ fmax,
+ center=False,
+ ):
+ super(MelSpectrogram, self).__init__()
+ self.n_fft = n_fft
+ self.hop_size = hop_size
+ self.win_size = win_size
+ self.sampling_rate = sampling_rate
+ self.num_mels = num_mels
+ self.fmin = fmin
+ self.fmax = fmax
+ self.center = center
+
+ mel_basis = {}
+ hann_window = {}
+
+ mel = librosa_mel_fn(
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
+ )
+ mel_basis = torch.from_numpy(mel).float()
+ hann_window = torch.hann_window(win_size)
+
+ self.register_buffer("mel_basis", mel_basis)
+ self.register_buffer("hann_window", hann_window)
+
+ def forward(self, y):
+ y = torch.nn.functional.pad(
+ y.unsqueeze(1),
+ (
+ int((self.n_fft - self.hop_size) / 2),
+ int((self.n_fft - self.hop_size) / 2),
+ ),
+ mode="reflect",
+ )
+ y = y.squeeze(1)
+ spec = torch.stft(
+ y,
+ self.n_fft,
+ hop_length=self.hop_size,
+ win_length=self.win_size,
+ window=self.hann_window,
+ center=self.center,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=True,
+ )
+ spec = torch.view_as_real(spec)
+
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+ spec = torch.matmul(self.mel_basis, spec)
+ spec = spectral_normalize_torch(spec)
+
+ return spec
+
+
+def load_mel_spectrogram():
+ return load_mel_spectrogram_from_cfg(None)
+
+
+def _get_from_mapping(cfg: Any, key: str, default: Any = None) -> Any:
+ """Safely read a field from a dict/OmegaConf-like object."""
+ if cfg is None:
+ return default
+ if isinstance(cfg, dict):
+ return cfg.get(key, default)
+ return getattr(cfg, key, default)
+
+
+def load_mel_spectrogram_from_cfg(audio_cfg: Optional[Any] = None) -> MelSpectrogram:
+ """Build MelSpectrogram from `audio_config`-like config.
+
+ Expected keys (either in dict or Hydra/OmegaConf object):
+ - hop_size, sample_rate (or sampling_rate), n_fft, num_mels, win_size, fmin, fmax
+ """
+ # Defaults keep current behavior.
+ mel_cfg: Dict[str, Any] = {
+ "hop_size": _get_from_mapping(audio_cfg, "hop_size", 480),
+ "sampling_rate": _get_from_mapping(
+ audio_cfg,
+ "sampling_rate",
+ _get_from_mapping(audio_cfg, "sample_rate", 24000),
+ ),
+ "n_fft": _get_from_mapping(audio_cfg, "n_fft", 1920),
+ "num_mels": _get_from_mapping(audio_cfg, "num_mels", 128),
+ "win_size": _get_from_mapping(audio_cfg, "win_size", 1920),
+ "fmin": _get_from_mapping(audio_cfg, "fmin", 0),
+ "fmax": _get_from_mapping(audio_cfg, "fmax", 12000),
+ }
+
+ mel_model = MelSpectrogram(**mel_cfg)
+ mel_model.eval()
+ return mel_model
+
+
+class MelSpectrogramEncoder(nn.Module):
+ def __init__(self, audio_config: dict | None = None):
+ super(MelSpectrogramEncoder, self).__init__()
+ self.model = load_mel_spectrogram_from_cfg(audio_config)
+ audio_config = audio_config or {}
+ self.mel_mean = audio_config.get("mel_mean", -4.92)
+ self.mel_var = audio_config.get("mel_var", 8.14)
+
+ def forward(self, x):
+ x = self.model(x).transpose(1, 2)
+ x = (x - self.mel_mean) / math.sqrt(self.mel_var)
+ return x
\ No newline at end of file
diff --git a/soulxsinger/models/modules/vocoder.py b/soulxsinger/models/modules/vocoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e739098c5232c5805f3a3f4f077a6840cd302e74
--- /dev/null
+++ b/soulxsinger/models/modules/vocoder.py
@@ -0,0 +1,1037 @@
+from typing import Optional, Tuple
+
+import numpy as np
+import scipy
+import torch
+from torch import nn, view_as_real, view_as_complex
+from torch import nn
+from torch.nn.utils import weight_norm, remove_weight_norm
+from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
+import accelerate
+
+from omegaconf import DictConfig
+
+
+def _deep_update_dict(base: dict, override: Optional[dict]) -> dict:
+ """Recursively merge `override` into `base` (dict only)."""
+ if not override:
+ return base
+ for k, v in override.items():
+ if isinstance(v, dict) and isinstance(base.get(k), dict):
+ base[k] = _deep_update_dict(base[k], v)
+ else:
+ base[k] = v
+ return base
+
+
+def _get_vocoder_default_cfg() -> dict:
+ # Defaults keep current behavior.
+ return {
+ "preprocess": {
+ "hop_size": 480,
+ "sample_rate": 24000,
+ "max_length": 36000,
+ "n_fft": 1920,
+ "num_mels": 128,
+ "win_size": 1920,
+ "fmin": 0,
+ "fmax": 12000,
+ "mel_var": 8.14,
+ "mel_mean": -4.92,
+ "load_phone": False,
+ "load_chromagram": False,
+ },
+ "model": {
+ "vocos": {
+ "input_channels": 128,
+ "dim": 1024,
+ "intermediate_dim": 4096,
+ "num_layers": 30,
+ "n_fft": 1920,
+ "hop_size": 480,
+ "padding": "same",
+ },
+ "period_gan": {
+ "max_downsample_channels": 1024,
+ "channels": 64,
+ "channel_increasing_factor": 2,
+ },
+ "spec_gan": {
+ "stft_params": {
+ "fft_sizes": [128, 256, 512, 1024, 2048],
+ "hop_sizes": [32, 64, 128, 256, 512],
+ "win_lengths": [128, 256, 512, 1024, 2048],
+ "window": "hann_window",
+ },
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": 64,
+ "kernel_sizes": [5, 3],
+ "max_downsample_channels": 1024,
+ "down_scales": [2, 2, 2],
+ "use_weight_norm": True,
+ "use_complex": False,
+ },
+ },
+ "loss": {
+ "mel_loss": {"sample_rate": 24000},
+ "disc_loss_weight": 1.0,
+ "mel_loss_weight": 10.0,
+ "adv_loss_weight": 2.0,
+ "fm_loss_weight": 2.0,
+ "spec_fm_loss_weight": 1.0,
+ },
+ }
+
+
+def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
+ """
+ Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
+
+ Args:
+ x (Tensor): Input tensor.
+ clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
+
+ Returns:
+ Tensor: Element-wise logarithm of the input tensor with clipping applied.
+ """
+ return torch.log(torch.clip(x, min=clip_val))
+
+
+def symlog(x: torch.Tensor) -> torch.Tensor:
+ return torch.sign(x) * torch.log1p(x.abs())
+
+
+def symexp(x: torch.Tensor) -> torch.Tensor:
+ return torch.sign(x) * (torch.exp(x.abs()) - 1)
+
+
+class STFT(nn.Module):
+ def __init__(
+ self,
+ n_fft: int,
+ hop_length: int,
+ win_length: int,
+ center=True,
+ ):
+ super().__init__()
+ self.center = center
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.win_length = win_length
+ window = torch.hann_window(win_length)
+ self.register_buffer("window", window)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # x: (B, T * hop_length)
+
+ if not self.center:
+ pad = self.win_length - self.hop_length
+ x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")
+
+ stft_spec = torch.stft(
+ x,
+ self.n_fft,
+ hop_length=self.hop_length,
+ win_length=self.win_length,
+ window=self.window,
+ center=self.center,
+ return_complex=False,
+ ) # (B, n_fft // 2 + 1, T, 2)
+
+ rea = stft_spec[:, :, :, 0] # (B, n_fft // 2 + 1, T, 2)
+ imag = stft_spec[:, :, :, 1] # (B, n_fft // 2 + 1, T, 2)
+
+ log_mag = torch.log(
+ torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
+ ) # (B, n_fft // 2 + 1, T)
+ phase = torch.atan2(imag, rea) # (B, n_fft // 2 + 1, T)
+
+ return log_mag, phase
+
+
+class ISTFT(nn.Module):
+ """
+ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
+ windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
+ See issue: https://github.com/pytorch/pytorch/issues/62323
+ Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
+ The NOLA constraint is met as we trim padded samples anyway.
+
+ Args:
+ n_fft (int): Size of Fourier transform.
+ hop_length (int): The distance between neighboring sliding window frames.
+ win_length (int): The size of window frame and STFT filter.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(
+ self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
+ ):
+ super().__init__()
+ if padding not in ["center", "same"]:
+ raise ValueError("Padding must be 'center' or 'same'.")
+ self.padding = padding
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.win_length = win_length
+ window = torch.hann_window(win_length)
+ self.register_buffer("window", window)
+
+ def forward(self, spec: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
+
+ Args:
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
+ N is the number of frequency bins, and T is the number of time frames.
+
+ Returns:
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
+ """
+ if self.padding == "center":
+ # Fallback to pytorch native implementation
+ return torch.istft(
+ spec,
+ self.n_fft,
+ self.hop_length,
+ self.win_length,
+ self.window,
+ center=True,
+ )
+ elif self.padding == "same":
+ pad = (self.win_length - self.hop_length) // 2
+ else:
+ raise ValueError("Padding must be 'center' or 'same'.")
+
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
+ B, N, T = spec.shape
+
+ # Inverse FFT
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
+ ifft = ifft * self.window[None, :, None]
+
+ # Overlap and Add
+ output_size = (T - 1) * self.hop_length + self.win_length
+ y = torch.nn.functional.fold(
+ ifft,
+ output_size=(1, output_size),
+ kernel_size=(1, self.win_length),
+ stride=(1, self.hop_length),
+ )[:, 0, 0, pad:-pad]
+
+ # Window envelope
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
+ window_envelope = torch.nn.functional.fold(
+ window_sq,
+ output_size=(1, output_size),
+ kernel_size=(1, self.win_length),
+ stride=(1, self.hop_length),
+ ).squeeze()[pad:-pad]
+
+ # Normalize
+ assert (window_envelope > 1e-11).all()
+ y = y / window_envelope
+
+ return y
+
+
+class MDCT(nn.Module):
+ """
+ Modified Discrete Cosine Transform (MDCT) module.
+
+ Args:
+ frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(self, frame_len: int, padding: str = "same"):
+ super().__init__()
+ if padding not in ["center", "same"]:
+ raise ValueError("Padding must be 'center' or 'same'.")
+ self.padding = padding
+ self.frame_len = frame_len
+ N = frame_len // 2
+ n0 = (N + 1) / 2
+ window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
+ self.register_buffer("window", window)
+
+ pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
+ post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
+ # view_as_real: NCCL Backend does not support ComplexFloat data type
+ # https://github.com/pytorch/pytorch/issues/71613
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
+
+ def forward(self, audio: torch.Tensor) -> torch.Tensor:
+ """
+ Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
+
+ Args:
+ audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
+ and T is the length of the audio.
+
+ Returns:
+ Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
+ and N is the number of frequency bins.
+ """
+ if self.padding == "center":
+ audio = torch.nn.functional.pad(
+ audio, (self.frame_len // 2, self.frame_len // 2)
+ )
+ elif self.padding == "same":
+ # hop_length is 1/2 frame_len
+ audio = torch.nn.functional.pad(
+ audio, (self.frame_len // 4, self.frame_len // 4)
+ )
+ else:
+ raise ValueError("Padding must be 'center' or 'same'.")
+
+ x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
+ N = self.frame_len // 2
+ x = x * self.window.expand(x.shape)
+ X = torch.fft.fft(
+ x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
+ )[..., :N]
+ res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
+ return torch.real(res) * np.sqrt(2)
+
+
+class IMDCT(nn.Module):
+ """
+ Inverse Modified Discrete Cosine Transform (IMDCT) module.
+
+ Args:
+ frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(self, frame_len: int, padding: str = "same"):
+ super().__init__()
+ if padding not in ["center", "same"]:
+ raise ValueError("Padding must be 'center' or 'same'.")
+ self.padding = padding
+ self.frame_len = frame_len
+ N = frame_len // 2
+ n0 = (N + 1) / 2
+ window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
+ self.register_buffer("window", window)
+
+ pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
+ post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
+
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
+ """
+ Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
+
+ Args:
+ X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
+ L is the number of frames, and N is the number of frequency bins.
+
+ Returns:
+ Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
+ """
+ B, L, N = X.shape
+ Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
+ Y[..., :N] = X
+ Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
+ y = torch.fft.ifft(
+ Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
+ )
+ y = (
+ torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
+ * np.sqrt(N)
+ * np.sqrt(2)
+ )
+ result = y * self.window.expand(y.shape)
+ output_size = (1, (L + 1) * N)
+ audio = torch.nn.functional.fold(
+ result.transpose(1, 2),
+ output_size=output_size,
+ kernel_size=(1, self.frame_len),
+ stride=(1, self.frame_len // 2),
+ )[:, 0, 0, :]
+
+ if self.padding == "center":
+ pad = self.frame_len // 2
+ elif self.padding == "same":
+ pad = self.frame_len // 4
+ else:
+ raise ValueError("Padding must be 'center' or 'same'.")
+
+ audio = audio[:, pad:-pad]
+ return audio
+
+
+class FourierHead(nn.Module):
+ """Base class for inverse fourier modules."""
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+class ISTFTHead(FourierHead):
+ """
+ ISTFT Head module for predicting STFT complex coefficients.
+
+ Args:
+ dim (int): Hidden dimension of the model.
+ n_fft (int): Size of Fourier transform.
+ hop_length (int): The distance between neighboring sliding window frames, which should align with
+ the resolution of the input features.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
+ super().__init__()
+ out_dim = n_fft + 2
+ self.out = torch.nn.Linear(dim, out_dim)
+ self.istft = ISTFT(
+ n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the ISTFTHead module.
+
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ x = self.out(x).transpose(1, 2)
+ mag, p = x.chunk(2, dim=1)
+ mag = torch.exp(mag)
+ mag = torch.clip(
+ mag, max=1e2
+ ) # safeguard to prevent excessively large magnitudes
+ # wrapping happens here. These two lines produce real and imaginary value
+ x = torch.cos(p)
+ y = torch.sin(p)
+ # recalculating phase here does not produce anything new
+ # only costs time
+ # phase = torch.atan2(y, x)
+ # S = mag * torch.exp(phase * 1j)
+ # better directly produce the complex value
+ S = mag * (x + 1j * y)
+ audio = self.istft(S)
+ return audio
+
+
+class IMDCTSymExpHead(FourierHead):
+ """
+ IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
+
+ Args:
+ dim (int): Hidden dimension of the model.
+ mdct_frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
+ based on perceptual scaling. Defaults to None.
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ mdct_frame_len: int,
+ padding: str = "same",
+ sample_rate: Optional[int] = None,
+ clip_audio: bool = False,
+ ):
+ super().__init__()
+ out_dim = mdct_frame_len // 2
+ self.out = nn.Linear(dim, out_dim)
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
+ self.clip_audio = clip_audio
+
+ if sample_rate is not None:
+ # optionally init the last layer following mel-scale
+ m_max = _hz_to_mel(sample_rate // 2)
+ m_pts = torch.linspace(0, m_max, out_dim)
+ f_pts = _mel_to_hz(m_pts)
+ scale = 1 - (f_pts / f_pts.max())
+
+ with torch.no_grad():
+ self.out.weight.mul_(scale.view(-1, 1))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the IMDCTSymExpHead module.
+
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ x = self.out(x)
+ x = symexp(x)
+ x = torch.clip(
+ x, min=-1e2, max=1e2
+ ) # safeguard to prevent excessively large magnitudes
+ audio = self.imdct(x)
+ if self.clip_audio:
+ audio = torch.clip(x, min=-1.0, max=1.0)
+
+ return audio
+
+
+class IMDCTCosHead(FourierHead):
+ """
+ IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
+
+ Args:
+ dim (int): Hidden dimension of the model.
+ mdct_frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ mdct_frame_len: int,
+ padding: str = "same",
+ clip_audio: bool = False,
+ ):
+ super().__init__()
+ self.clip_audio = clip_audio
+ self.out = nn.Linear(dim, mdct_frame_len)
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the IMDCTCosHead module.
+
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ x = self.out(x)
+ m, p = x.chunk(2, dim=2)
+ m = torch.exp(m).clip(
+ max=1e2
+ ) # safeguard to prevent excessively large magnitudes
+ audio = self.imdct(m * torch.cos(p))
+ if self.clip_audio:
+ audio = torch.clip(x, min=-1.0, max=1.0)
+ return audio
+
+
+class ConvNeXtBlock(nn.Module):
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
+
+ Args:
+ dim (int): Number of input channels.
+ intermediate_dim (int): Dimensionality of the intermediate layer.
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+ Defaults to None.
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+ None means non-conditional LayerNorm. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ intermediate_dim: int,
+ layer_scale_init_value: float,
+ adanorm_num_embeddings: Optional[int] = None,
+ ):
+ super().__init__()
+ self.dwconv = nn.Conv1d(
+ dim, dim, kernel_size=7, padding=3, groups=dim
+ ) # depthwise conv
+ self.adanorm = adanorm_num_embeddings is not None
+ if adanorm_num_embeddings:
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+ else:
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(
+ dim, intermediate_dim
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+ if layer_scale_init_value > 0
+ else None
+ )
+
+ def forward(
+ self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ residual = x
+ x = self.dwconv(x)
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
+ if self.adanorm:
+ assert cond_embedding_id is not None
+ x = self.norm(x, cond_embedding_id)
+ else:
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
+
+ x = residual + x
+ return x
+
+
+class AdaLayerNorm(nn.Module):
+ """
+ Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
+
+ Args:
+ num_embeddings (int): Number of embeddings.
+ embedding_dim (int): Dimension of the embeddings.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+ self.dim = embedding_dim
+ self.scale = nn.Embedding(
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim
+ )
+ self.shift = nn.Embedding(
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim
+ )
+ torch.nn.init.ones_(self.scale.weight)
+ torch.nn.init.zeros_(self.shift.weight)
+
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
+ scale = self.scale(cond_embedding_id)
+ shift = self.shift(cond_embedding_id)
+ x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
+ x = x * scale + shift
+ return x
+
+
+class ResBlock1(nn.Module):
+ """
+ ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
+ but without upsampling layers.
+
+ Args:
+ dim (int): Number of input channels.
+ kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
+ dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
+ Defaults to (1, 3, 5).
+ lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
+ Defaults to 0.1.
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+ Defaults to None.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ kernel_size: int = 3,
+ dilation: Tuple[int, int, int] = (1, 3, 5),
+ lrelu_slope: float = 0.1,
+ layer_scale_init_value: Optional[float] = None,
+ ):
+ super().__init__()
+ self.lrelu_slope = lrelu_slope
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=self.get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=self.get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=self.get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=self.get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=self.get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=self.get_padding(kernel_size, 1),
+ )
+ ),
+ ]
+ )
+
+ self.gamma = nn.ParameterList(
+ [
+ (
+ nn.Parameter(
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
+ )
+ if layer_scale_init_value is not None
+ else None
+ ),
+ (
+ nn.Parameter(
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
+ )
+ if layer_scale_init_value is not None
+ else None
+ ),
+ (
+ nn.Parameter(
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
+ )
+ if layer_scale_init_value is not None
+ else None
+ ),
+ ]
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
+ xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
+ xt = c1(xt)
+ xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
+ xt = c2(xt)
+ if gamma is not None:
+ xt = gamma * xt
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+ @staticmethod
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+class Backbone(nn.Module):
+ """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
+
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+ """
+ Args:
+ x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
+ C denotes output features, and L is the sequence length.
+
+ Returns:
+ Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
+ and H denotes the model dimension.
+ """
+ raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+class VocosBackbone(Backbone):
+ """
+ Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
+
+ Args:
+ input_channels (int): Number of input features channels.
+ dim (int): Hidden dimension of the model.
+ intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
+ num_layers (int): Number of ConvNeXtBlock layers.
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+ None means non-conditional model. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ input_channels: int,
+ dim: int,
+ intermediate_dim: int,
+ num_layers: int,
+ layer_scale_init_value: Optional[float] = None,
+ adanorm_num_embeddings: Optional[int] = None,
+ ):
+ super().__init__()
+ self.input_channels = input_channels
+ self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
+ self.adanorm = adanorm_num_embeddings is not None
+ if adanorm_num_embeddings:
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+ else:
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
+ layer_scale_init_value = layer_scale_init_value or 1 / num_layers
+ self.convnext = nn.ModuleList(
+ [
+ ConvNeXtBlock(
+ dim=dim,
+ intermediate_dim=intermediate_dim,
+ layer_scale_init_value=layer_scale_init_value,
+ adanorm_num_embeddings=adanorm_num_embeddings,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+ bandwidth_id = kwargs.get("bandwidth_id", None)
+ x = self.embed(x)
+ if self.adanorm:
+ assert bandwidth_id is not None
+ x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
+ else:
+ x = self.norm(x.transpose(1, 2))
+ x = x.transpose(1, 2)
+ for conv_block in self.convnext:
+ x = conv_block(x, cond_embedding_id=bandwidth_id)
+ x = self.final_layer_norm(x.transpose(1, 2))
+ return x
+
+
+class VocosResNetBackbone(Backbone):
+ """
+ Vocos backbone module built with ResBlocks.
+
+ Args:
+ input_channels (int): Number of input features channels.
+ dim (int): Hidden dimension of the model.
+ num_blocks (int): Number of ResBlock1 blocks.
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ input_channels,
+ dim,
+ num_blocks,
+ layer_scale_init_value=None,
+ ):
+ super().__init__()
+ self.input_channels = input_channels
+ self.embed = weight_norm(
+ nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
+ )
+ layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
+ self.resnet = nn.Sequential(
+ *[
+ ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
+ for _ in range(num_blocks)
+ ]
+ )
+
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+ x = self.embed(x)
+ x = self.resnet(x)
+ x = x.transpose(1, 2)
+ return x
+
+
+class Vocos(nn.Module):
+ def __init__(
+ self,
+ input_channels: int = 256,
+ dim: int = 384,
+ intermediate_dim: int = 1152,
+ num_layers: int = 8,
+ n_fft: int = 800,
+ hop_size: int = 200,
+ padding: str = "same",
+ adanorm_num_embeddings=None,
+ cfg=None,
+ ):
+ super().__init__()
+
+ input_channels = (
+ cfg.input_channels
+ if cfg is not None and hasattr(cfg, "input_channels")
+ else input_channels
+ )
+ dim = cfg.dim if cfg is not None and hasattr(cfg, "dim") else dim
+ intermediate_dim = (
+ cfg.intermediate_dim
+ if cfg is not None and hasattr(cfg, "intermediate_dim")
+ else intermediate_dim
+ )
+ num_layers = (
+ cfg.num_layers
+ if cfg is not None and hasattr(cfg, "num_layers")
+ else num_layers
+ )
+ adanorm_num_embeddings = (
+ cfg.adanorm_num_embeddings
+ if cfg is not None and hasattr(cfg, "adanorm_num_embeddings")
+ else adanorm_num_embeddings
+ )
+ n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
+ hop_size = (
+ cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
+ )
+ padding = (
+ cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
+ )
+
+ self.backbone = VocosBackbone(
+ input_channels=input_channels,
+ dim=dim,
+ intermediate_dim=intermediate_dim,
+ num_layers=num_layers,
+ adanorm_num_embeddings=adanorm_num_embeddings,
+ )
+ self.head = ISTFTHead(dim, n_fft, hop_size, padding)
+
+ def forward(self, x):
+ x = self.backbone(x)
+ x = self.head(x)
+
+ return x[:, None, :]
+
+
+class JsonHParams:
+ def __init__(self, **kwargs):
+ for k, v in kwargs.items():
+ if type(v) == dict:
+ v = JsonHParams(**v)
+ self[k] = v
+
+ def keys(self):
+ return self.__dict__.keys()
+
+ def items(self):
+ return self.__dict__.items()
+
+ def values(self):
+ return self.__dict__.values()
+
+ def __len__(self):
+ return len(self.__dict__)
+
+ def __getitem__(self, key):
+ return getattr(self, key)
+
+ def __setitem__(self, key, value):
+ return setattr(self, key, value)
+
+ def __contains__(self, key):
+ return key in self.__dict__
+
+ def __repr__(self):
+ return self.__dict__.__repr__()
+
+def build_vocoder_model(cfg):
+ vocoder_model = Vocos(cfg=cfg.model.vocos)
+ vocoder_model.eval()
+ return vocoder_model
+
+def load_checkpoint(build_model_func, cfg, ckpt_path):
+ model = build_model_func(cfg)
+
+ if ckpt_path is not None:
+ accelerate.load_checkpoint_and_dispatch(model, ckpt_path)
+ return model
+
+def load_vocos_model(
+ ckpt_path: str | None = None,
+ config: DictConfig = None,
+ device: str = "cuda",
+):
+ """Load Vocos vocoder.
+
+ Args:
+ config: DictConfig, config for vocoder.
+ ckpt_path: str | None = None, path to checkpoint.
+ device: str, device to load model. Note: accelerate dispatch handles device placement.
+ """
+
+ merged_cfg = _deep_update_dict(_get_vocoder_default_cfg(), config)
+ vocoder_cfg_obj = JsonHParams(**merged_cfg)
+
+ vocoder_model = load_checkpoint(
+ build_vocoder_model, vocoder_cfg_obj, ckpt_path
+ )
+ vocoder_model.eval()
+
+ for param in vocoder_model.parameters():
+ param.requires_grad = False
+
+ return vocoder_model
+
+
+class Vocoder(nn.Module):
+ def __init__(self, vocoder: dict | None = None, ckpt_path: str | None = None):
+ super(Vocoder, self).__init__()
+ vocoder = vocoder or {}
+
+ model_cfg = vocoder.get("model_cfg") or vocoder.get("cfg")
+
+ self.model = load_vocos_model(
+ ckpt_path=ckpt_path,
+ config=model_cfg,
+ )
+
+ def forward(self, x):
+ x = self.model(x)
+ return x
\ No newline at end of file
diff --git a/soulxsinger/models/soulxsinger.py b/soulxsinger/models/soulxsinger.py
new file mode 100644
index 0000000000000000000000000000000000000000..18aff6cfdeda2b686847b437f771a93dcd7b3ed6
--- /dev/null
+++ b/soulxsinger/models/soulxsinger.py
@@ -0,0 +1,187 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+import numpy as np
+from typing import Optional, Dict, Any, List
+
+from soulxsinger.models.modules.vocoder import Vocoder
+from soulxsinger.models.modules.decoder import CFMDecoder
+from soulxsinger.models.modules.convnext import ConvNeXtV2Block
+from soulxsinger.models.modules.mel_transform import MelSpectrogramEncoder
+
+
+class SoulXSinger(nn.Module):
+ """
+ SoulXSinger model.
+ """
+ def __init__(self, config: Dict):
+ super(SoulXSinger, self).__init__()
+ audio_cfg = config.audio
+ enc_cfg = config.model.encoder
+ cfm_cfg = config.model.flow_matching
+
+ self.note_text_encoder = nn.Embedding(enc_cfg["vocab_size"], enc_cfg["text_dim"])
+ self.note_pitch_encoder = nn.Embedding(256, enc_cfg["pitch_dim"])
+ self.note_type_encoder = nn.Embedding(256, enc_cfg["type_dim"])
+ self.f0_encoder = nn.Embedding(enc_cfg["f0_bin"], enc_cfg["f0_dim"])
+
+ self.preflow = nn.Sequential(
+ *[ConvNeXtV2Block(enc_cfg["text_dim"], enc_cfg["text_dim"] * 2) for _ in range(enc_cfg["num_layers"])]
+ )
+ self.cfm_decoder = CFMDecoder(cfm_cfg)
+
+ if audio_cfg is None and isinstance(enc_cfg, dict):
+ audio_cfg = enc_cfg.get("audio_config")
+ self.mel = MelSpectrogramEncoder(audio_cfg)
+ self.vocoder = Vocoder()
+
+ @staticmethod
+ def expand_states(h, mel2token):
+ """
+ Expand the states to the mel-scale.
+ args:
+ h: states, shape: [B, T, H]
+ mel2token: mel2token, shape: [B, F]
+ returns:
+ h: expanded states, shape: [B, F, H]
+ """
+ try:
+ assert mel2token.max() <= h.size(1) - 1
+ except:
+ print(f"Warning: mel2token.max() ({mel2token.max()}) is greater than h.size(1) - 1 ({h.size(1) - 1})")
+ mel2token = torch.clamp(mel2token, 0, h.size(1)-1)
+ mel2token_ = mel2token[..., None].repeat([1, 1, h.shape[-1]])
+ h = torch.gather(h, 1, mel2token_) # [B, T, H]
+ return h
+
+ @staticmethod
+ def f0_to_coarse(f0, f0_bin=361, f0_min=32.7031956625, f0_shift=0):
+ """
+ Convert continuous F0 values to discrete F0 bins (SIL and C1 - B6, 361 bins).
+ args:
+ f0: continuous F0 values
+ f0_bin: number of F0 bins
+ f0_min: minimum F0 value
+ f0_shift: shift value for F0 bins
+ returns:
+ f0_coarse: discrete F0 bins
+ """
+ is_torch = isinstance(f0, torch.Tensor)
+ uv_mask = f0 <= 0
+
+ if is_torch:
+ f0_safe = torch.maximum(f0, torch.tensor(f0_min))
+ f0_cents = 1200 * torch.log2(f0_safe / f0_min)
+ else:
+ f0_safe = np.maximum(f0, f0_min)
+ f0_cents = 1200 * np.log2(f0_safe / f0_min)
+
+ f0_coarse = (f0_cents / 20) + 1
+
+ if is_torch:
+ f0_coarse = torch.round(f0_coarse).long()
+ f0_coarse = torch.clamp(f0_coarse, min=1, max=f0_bin - 1)
+ else:
+ f0_coarse = np.rint(f0_coarse).astype(int)
+ f0_coarse = np.clip(f0_coarse, 1, f0_bin - 1)
+
+ f0_coarse[uv_mask] = 0
+
+ if f0_shift != 0:
+ if is_torch:
+ voiced = f0_coarse > 0
+ if voiced.any():
+ shifted = f0_coarse[voiced] + f0_shift
+ f0_coarse[voiced] = torch.clamp(shifted, 1, f0_bin - 1)
+ else:
+ voiced = f0_coarse > 0
+ if np.any(voiced):
+ shifted = f0_coarse[voiced] + f0_shift
+ f0_coarse[voiced] = np.clip(shifted, 1, f0_bin - 1)
+
+ return f0_coarse
+
+ def infer(self, meta: dict, auto_shift=False, pitch_shift=0, n_steps=32, cfg=3, control="melody"):
+
+ gt_note_text = meta['target']['phoneme']
+ gt_mel2note = meta['target']['mel2note']
+ gt_note_type = meta['target']['note_type']
+
+ pt_wav = meta['prompt']['waveform']
+ pt_note_text = meta['prompt']['phoneme']
+ pt_mel2note = meta['prompt']['mel2note']
+ pt_note_type = meta['prompt']['note_type']
+
+ if control == "score":
+ gt_note_pitch = meta['target']['note_pitch']
+ pt_note_pitch = meta['prompt']['note_pitch']
+ gt_f0 = None
+ pt_f0 = None
+ elif control == "melody":
+ gt_f0 = meta['target']['f0']
+ pt_f0 = meta['prompt']['f0']
+ gt_note_pitch = None
+ pt_note_pitch = None
+ else:
+ raise ValueError(f"Unknown control mode: {control}")
+
+ # calculate auto pitch shift
+ if auto_shift and pitch_shift == 0:
+ if gt_note_pitch != None and pt_note_pitch != None:
+ gt_median = torch.median(gt_note_pitch[gt_note_pitch >= 1])
+ pt_median = torch.median(pt_note_pitch[pt_note_pitch >= 1])
+ f0_shift = torch.round(pt_median - gt_median).int().item()
+ elif gt_f0 != None and pt_f0 != None:
+ gt_f0_median = torch.median(gt_f0[gt_f0 > 0])
+ pt_f0_median = torch.median(pt_f0[pt_f0 > 0])
+ f0_shift = torch.round(torch.log2(pt_f0_median / gt_f0_median) * 1200 / 100).int().item()
+ else:
+ print("Warning: pitch_shift is True but note_pitch or f0 is None. Set f0_shift to 0.")
+ f0_shift = 0
+ else:
+ f0_shift = pitch_shift
+
+ if gt_f0 is None or pt_f0 is None:
+ gt_f0, pt_f0 = torch.zeros_like(gt_mel2note).float().to(gt_mel2note.device), torch.zeros_like(pt_mel2note).float().to(pt_mel2note.device)
+ if gt_note_pitch is None or pt_note_pitch is None:
+ gt_note_pitch, pt_note_pitch = torch.zeros_like(gt_note_type).int().to(gt_note_type.device), torch.zeros_like(pt_note_type).int().to(pt_note_type.device)
+
+ # convert prompt waveform to mel spectrogram
+ pt_mel = self.mel(pt_wav)
+
+ len_prompt = pt_note_pitch.shape[1]
+ len_prompt_mel = pt_f0.shape[1]
+
+ note_pitch = torch.cat([pt_note_pitch, gt_note_pitch], 1)
+ note_text = torch.cat([pt_note_text, gt_note_text], 1)
+ note_type = torch.cat([pt_note_type, gt_note_type], 1)
+ mel2note = torch.cat([pt_mel2note, gt_mel2note + len_prompt], 1)
+
+ f0_course_pt = self.f0_to_coarse(pt_f0)
+ f0_course_gt = self.f0_to_coarse(gt_f0, f0_shift=f0_shift * 5)
+ f0_course = torch.cat([f0_course_pt, f0_course_gt], 1)
+
+ note_pitch[note_pitch > 0] = note_pitch[note_pitch > 0] + f0_shift
+ note_pitch = torch.clamp(note_pitch, 0, 255)
+
+ features = self.note_pitch_encoder(note_pitch) + self.note_type_encoder(note_type) + self.note_text_encoder(note_text)
+
+ features = self.preflow(features)
+ features = self.expand_states(features, mel2note)
+ features = features + self.f0_encoder(f0_course)
+
+ gt_decoder_inp = features[:, len_prompt_mel:, :]
+ pt_decoder_inp = features[:, :len_prompt_mel, :]
+
+ generated_mel = self.cfm_decoder.reverse_diffusion(
+ pt_mel,
+ pt_decoder_inp,
+ gt_decoder_inp,
+ n_timesteps=n_steps,
+ cfg=cfg
+ )
+
+ generated_audio = self.vocoder(generated_mel.transpose(1, 2)[0:1, ...])
+
+ return generated_audio
diff --git a/soulxsinger/utils/audio_utils.py b/soulxsinger/utils/audio_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bca06ac478ab48cb70f262fcf7d706e7121fb040
--- /dev/null
+++ b/soulxsinger/utils/audio_utils.py
@@ -0,0 +1,25 @@
+import torch
+import torchaudio
+
+
+def load_wav(wav_path: str, sample_rate: int):
+ """Load wav file and resample to target sample rate.
+
+ Args:
+ wav_path (str): Path to wav file.
+ sample_rate (int): Target sample rate.
+
+ Returns:
+ torch.Tensor: Waveform tensor with shape (1, T).
+ """
+ waveform, sr = torchaudio.load(wav_path)
+
+ if sr != sample_rate:
+ waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
+
+ if len(waveform.shape) > 1 and waveform.shape[0] > 1:
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+ return waveform
+
+
diff --git a/soulxsinger/utils/data_processor.py b/soulxsinger/utils/data_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..50c366e554de293fa8ba03482175691b0f12917a
--- /dev/null
+++ b/soulxsinger/utils/data_processor.py
@@ -0,0 +1,174 @@
+import json
+import torch
+import numpy as np
+import torchaudio
+from typing import List
+
+from soulxsinger.utils.audio_utils import load_wav
+
+
+class DataProcessor:
+ """Data processor for SoulX-Singer
+ """
+ def __init__(
+ self,
+ hop_size: int,
+ sample_rate: int,
+ phoneset_path: str = 'soulxsinger/utils/phoneme/phone_set.json',
+ device: str = 'cuda',
+ prompt_append_duration: float = 0.5):
+ """Initialize data processor.
+
+ Args:
+ hop_size (int): Hop size in samples.
+ sample_rate (int): Sample rate in Hz.
+ phoneset_path (str): Path to phoneme set JSON file.
+ device (str): Device to use for tensor operations.
+ prompt_append_duration (float): Duration to append to prompt in seconds.
+ """
+ self.hop_size = hop_size
+ self.sample_rate = sample_rate
+ self.device = device
+ self.prompt_append_duration = prompt_append_duration
+ self.prompt_append_length = int(prompt_append_duration * sample_rate / hop_size)
+ self.load_phoneme_id_map(phoneset_path)
+
+ def load_phoneme_id_map(self, phoneset_path: str):
+ with open(phoneset_path, "r", encoding='utf-8') as f:
+ phoneset = json.load(f)
+ self.phone2idx = {ph: idx for idx, ph in enumerate(phoneset)}
+
+ def merge_phoneme(self, meta):
+ merged_items = []
+
+ duration = [float(x) for x in meta["duration"].split()]
+ phoneme = [str(x).replace("", "") for i, x in enumerate(meta["phoneme"].split())]
+ note_pitch = [int(x) for x in meta["note_pitch"].split()]
+ note_type = [int(x) if phoneme[i] != "" else 1 for i, x in enumerate(meta["note_type"].split())]
+
+ for i in range(len(phoneme)):
+ if i > 0 and phoneme[i] == phoneme[i - 1] == "" and note_type[i] == note_type[i - 1] and note_pitch[i] == note_pitch[i - 1]:
+ merged_items[-1][1] += duration[i]
+ else:
+ merged_items.append([phoneme[i], duration[i], note_pitch[i], note_type[i]])
+
+ single_frame_duration = self.hop_size / self.sample_rate
+ meta['phoneme'] = [x[0] for x in merged_items]
+ meta['duration'] = [x[1] for x in merged_items]
+ meta['note_pitch'] = [x[2] for x in merged_items]
+ meta['note_type'] = [x[3] for x in merged_items]
+
+ return meta
+
+ def preprocess(
+ self,
+ note_duration: List[float],
+ phonemes: List[str],
+ note_pitch: List[int],
+ note_type: List[int],
+ ):
+ """
+ Insert and for each note.
+ Get aligned indices for each frame.
+
+ Args:
+ note_duration: Duration of each note in seconds
+ phonemes: Phoneme sequence for each note
+ note_pitch: Pitch value for each note
+ note_type: Type value for each note
+
+ """
+ sample_rate = self.sample_rate
+ hop_size = self.hop_size
+ duration = sum(note_duration) * sample_rate / hop_size
+ mel2note = torch.zeros(int(duration), dtype=torch.long)
+
+ ph_locations = [] # idx at mel scale and length
+ new_phonemes = []
+ dur_sum = 0
+
+ note2origin = []
+
+ for ph_idx in range(len(phonemes)):
+ dur = int(np.round(dur_sum * sample_rate / hop_size))
+ dur = min(dur, len(mel2note) - 1)
+ new_phonemes.append("")
+ note2origin.append(ph_idx)
+ if phonemes[ph_idx][:3] == "en_":
+ en_phs = ['en_' + x for x in phonemes[ph_idx][3:].split('-')] + [''] # between en words in one note
+ ph_locations.append([dur, max(1, len(en_phs))])
+ new_phonemes.extend(en_phs)
+ note2origin.extend([ph_idx] * len(en_phs))
+ else:
+ ph_locations.append([dur, 1])
+ new_phonemes.append(phonemes[ph_idx])
+ note2origin.append(ph_idx)
+ new_phonemes.append("")
+ note2origin.append(ph_idx)
+ dur_sum += note_duration[ph_idx]
+
+ ph_idx = 1
+ for idx, (i, j) in enumerate(ph_locations):
+ next_phoneme_start = ph_locations[idx + 1][0] if idx < len(ph_locations) - 1 else len(mel2note)
+ if i >= len(mel2note) or i + j > len(mel2note):
+ break
+ if i < len(mel2note) and mel2note[i] > 0:
+ # print(f"warning: overlap of {idx}: {mel2note[i]}")
+ while i < len(mel2note) and mel2note[i] > 0:
+ i += 1
+ mel2note[i] = ph_idx
+ k = i + 1
+ while k + j < next_phoneme_start:
+ mel2note[k : k + j] = torch.arange(ph_idx, ph_idx + j) + 1
+ k += j
+ mel2note[next_phoneme_start - 1] = ph_idx + j + 1
+ ph_idx += j + 2 # + ph repeats +
+
+ new_phonemes = [""] + new_phonemes
+ new_note_pitch = [0] + [note_pitch[k] for k in note2origin]
+ new_note_type = [1] + [note_type[k] for k in note2origin]
+
+ return {
+ "phoneme": torch.tensor([self.phone2idx[x] for x in new_phonemes], device=self.device).unsqueeze(0),
+ "note_pitch": torch.tensor(new_note_pitch, device=self.device).unsqueeze(0),
+ "note_type": torch.tensor(new_note_type, device=self.device).unsqueeze(0),
+ "mel2note": mel2note.clone().detach().to(self.device).unsqueeze(0),
+ }
+
+ def process(
+ self,
+ meta: dict,
+ wav_path: str = None
+ ):
+
+ meta = self.merge_phoneme(meta)
+
+ item = self.preprocess(
+ meta["duration"],
+ meta["phoneme"],
+ meta["note_pitch"],
+ meta["note_type"],
+ )
+
+ f0 = torch.tensor([float(x) for x in meta["f0"].split()])
+ min_frame = min(item["mel2note"].shape[1], f0.shape[0])
+ item['f0'] = f0[:min_frame].unsqueeze(0).float().to(self.device)
+ item["mel2note"] = item["mel2note"][:, :min_frame]
+
+ if wav_path is not None:
+ waveform = load_wav(wav_path, self.sample_rate)
+ item["waveform"] = waveform.to(self.device)[:, :min_frame * self.hop_size]
+
+ return item
+
+
+# test
+if __name__ == "__main__":
+ import json
+ with open("example/metadata/zh_prompt.json", "r", encoding="utf-8") as f:
+ meta = json.load(f)
+ if isinstance(meta, list):
+ meta = meta[0]
+ processor = DataProcessor(hop_size=480, sample_rate=24000)
+ item = processor.process(meta, "example/audio/zh_prompt.wav")
+ print(item.keys())
\ No newline at end of file
diff --git a/soulxsinger/utils/file_utils.py b/soulxsinger/utils/file_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b8693ec338245b1fd159fc674f6c73edb2e3beb
--- /dev/null
+++ b/soulxsinger/utils/file_utils.py
@@ -0,0 +1,77 @@
+
+"""
+Description:
+ This script contains a collection of functions designed to handle various
+ file reading and writing operations. It provides utilities to read from files,
+ write data to files, and perform file manipulation tasks.
+"""
+
+import os
+import json
+
+from tqdm import tqdm
+from typing import List, Dict
+from pathlib import Path
+from omegaconf import OmegaConf, DictConfig
+
+
+def write_jsonl(metadata: List[dict], file_path: Path):
+ """Writes a list of dictionaries to a JSONL file.
+
+ Args:
+ metadata : List[dict]
+ A list of dictionaries, each representing a piece of meta.
+ file_path : Path
+ The file path to save the JSONL file
+
+ This function writes each dictionary in the list to a new line in the specified file.
+ """
+ with open(file_path, "w", encoding="utf-8") as f:
+ for meta in tqdm(metadata, desc="writing jsonl"):
+ # Convert dictionary to JSON string and write it to the file with a newline
+ json_str = json.dumps(meta, ensure_ascii=False) + "\n"
+ f.write(json_str)
+ print(f"jsonl saved to {file_path}")
+
+
+def read_jsonl(file_path: Path) -> List[dict]:
+ """
+ Reads a JSONL file and returns a list of dictionaries.
+
+ Args:
+ file_path : Path
+ The path to the JSONL file to be read.
+
+ Returns:
+ List[dict]
+ A list of dictionaries parsed from each line of the JSONL file.
+ """
+ metadata = []
+ # Open the file for reading
+ with open(file_path, "r", encoding="utf-8") as f:
+ # Split the file into lines
+ lines = f.read().splitlines()
+ # Process each line
+ for line in lines:
+ # Convert JSON string back to dictionary and append to list
+ meta = json.loads(line)
+ metadata.append(meta)
+ # Return the list of metadata
+ return metadata
+
+
+def load_config(config_path: Path) -> DictConfig:
+ """Loads a configuration file and optionally merges it with a base configuration.
+
+ Args:
+ config_path (Path): Path to the configuration file.
+ """
+ # Load the initial configuration from the given path
+ config = OmegaConf.load(config_path)
+
+ # Check if there is a base configuration specified and merge if necessary
+ if config.get("base_config", None) is not None:
+ base_config = OmegaConf.load(config["base_config"])
+ config = OmegaConf.merge(base_config, config)
+
+ return config
\ No newline at end of file
diff --git a/soulxsinger/utils/phoneme/phone_set.json b/soulxsinger/utils/phoneme/phone_set.json
new file mode 100644
index 0000000000000000000000000000000000000000..5ba48c138961bd5b360e72892adc98e03c943852
--- /dev/null
+++ b/soulxsinger/utils/phoneme/phone_set.json
@@ -0,0 +1,2822 @@
+[
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "en_AA0",
+ "en_AA1",
+ "en_AA2",
+ "en_AE0",
+ "en_AE1",
+ "en_AE2",
+ "en_AH0",
+ "en_AH1",
+ "en_AH2",
+ "en_AO0",
+ "en_AO1",
+ "en_AO2",
+ "en_AW0",
+ "en_AW1",
+ "en_AW2",
+ "en_AY0",
+ "en_AY1",
+ "en_AY2",
+ "en_B",
+ "en_CH",
+ "en_D",
+ "en_DH",
+ "en_EH0",
+ "en_EH1",
+ "en_EH2",
+ "en_ER0",
+ "en_ER1",
+ "en_ER2",
+ "en_EY0",
+ "en_EY1",
+ "en_EY2",
+ "en_F",
+ "en_G",
+ "en_HH",
+ "en_IH0",
+ "en_IH1",
+ "en_IH2",
+ "en_IY0",
+ "en_IY1",
+ "en_IY2",
+ "en_JH",
+ "en_K",
+ "en_L",
+ "en_M",
+ "en_N",
+ "en_NG",
+ "en_OW0",
+ "en_OW1",
+ "en_OW2",
+ "en_OY0",
+ "en_OY1",
+ "en_OY2",
+ "en_P",
+ "en_R",
+ "en_S",
+ "en_SH",
+ "en_T",
+ "en_TH",
+ "en_UH0",
+ "en_UH1",
+ "en_UH2",
+ "en_UW",
+ "en_UW0",
+ "en_UW1",
+ "en_UW2",
+ "en_V",
+ "en_W",
+ "en_Y",
+ "en_Z",
+ "en_ZH",
+ "yue_aa1",
+ "yue_aa2",
+ "yue_aa3",
+ "yue_aai1",
+ "yue_aai3",
+ "yue_aak1",
+ "yue_aap3",
+ "yue_aat3",
+ "yue_aau2",
+ "yue_aau3",
+ "yue_ai1",
+ "yue_ai2",
+ "yue_ai3",
+ "yue_ak1",
+ "yue_am1",
+ "yue_am2",
+ "yue_am3",
+ "yue_ang1",
+ "yue_au1",
+ "yue_au2",
+ "yue_au3",
+ "yue_baa1",
+ "yue_baa2",
+ "yue_baa3",
+ "yue_baa4",
+ "yue_baa6",
+ "yue_baai1",
+ "yue_baai2",
+ "yue_baai3",
+ "yue_baai6",
+ "yue_baak1",
+ "yue_baak2",
+ "yue_baak3",
+ "yue_baak6",
+ "yue_baan1",
+ "yue_baan2",
+ "yue_baan6",
+ "yue_baat3",
+ "yue_baau1",
+ "yue_baau2",
+ "yue_baau3",
+ "yue_bai3",
+ "yue_bai6",
+ "yue_bak1",
+ "yue_bam1",
+ "yue_ban1",
+ "yue_ban2",
+ "yue_ban3",
+ "yue_ban6",
+ "yue_bang1",
+ "yue_bat1",
+ "yue_bat6",
+ "yue_be1",
+ "yue_bei1",
+ "yue_bei2",
+ "yue_bei3",
+ "yue_bei6",
+ "yue_beng2",
+ "yue_beng3",
+ "yue_beng6",
+ "yue_bik1",
+ "yue_bin1",
+ "yue_bin2",
+ "yue_bin3",
+ "yue_bin6",
+ "yue_bing1",
+ "yue_bing2",
+ "yue_bing6",
+ "yue_bit1",
+ "yue_bit3",
+ "yue_bit6",
+ "yue_biu1",
+ "yue_biu2",
+ "yue_bo1",
+ "yue_bo3",
+ "yue_bok1",
+ "yue_bok3",
+ "yue_bok6",
+ "yue_bong1",
+ "yue_bong2",
+ "yue_bong6",
+ "yue_bou1",
+ "yue_bou2",
+ "yue_bou3",
+ "yue_bou6",
+ "yue_bui1",
+ "yue_bui3",
+ "yue_bui6",
+ "yue_buk1",
+ "yue_buk6",
+ "yue_bun1",
+ "yue_bun2",
+ "yue_bun3",
+ "yue_bun6",
+ "yue_but6",
+ "yue_caa1",
+ "yue_caa3",
+ "yue_caa4",
+ "yue_caai1",
+ "yue_caai2",
+ "yue_caai4",
+ "yue_caak1",
+ "yue_caak3",
+ "yue_caam1",
+ "yue_caam2",
+ "yue_caam3",
+ "yue_caam4",
+ "yue_caan1",
+ "yue_caan2",
+ "yue_caan3",
+ "yue_caan4",
+ "yue_caang2",
+ "yue_caang3",
+ "yue_caap3",
+ "yue_caat3",
+ "yue_caat4",
+ "yue_caau1",
+ "yue_caau2",
+ "yue_cai1",
+ "yue_cai3",
+ "yue_cai4",
+ "yue_cak1",
+ "yue_cam1",
+ "yue_cam2",
+ "yue_cam4",
+ "yue_can1",
+ "yue_can2",
+ "yue_can3",
+ "yue_can4",
+ "yue_cang4",
+ "yue_cap1",
+ "yue_cat1",
+ "yue_cau1",
+ "yue_cau2",
+ "yue_cau3",
+ "yue_cau4",
+ "yue_ce1",
+ "yue_ce2",
+ "yue_ce3",
+ "yue_ce4",
+ "yue_cek3",
+ "yue_ceng1",
+ "yue_ceng2",
+ "yue_ceoi1",
+ "yue_ceoi2",
+ "yue_ceoi3",
+ "yue_ceoi4",
+ "yue_ceon1",
+ "yue_ceon2",
+ "yue_ceon4",
+ "yue_ceot1",
+ "yue_ci1",
+ "yue_ci2",
+ "yue_ci3",
+ "yue_ci4",
+ "yue_ci5",
+ "yue_cik1",
+ "yue_cik4",
+ "yue_cim1",
+ "yue_cim4",
+ "yue_cin1",
+ "yue_cin2",
+ "yue_cin4",
+ "yue_cin5",
+ "yue_cing1",
+ "yue_cing2",
+ "yue_cing3",
+ "yue_cing4",
+ "yue_cit3",
+ "yue_ciu1",
+ "yue_ciu2",
+ "yue_ciu3",
+ "yue_ciu4",
+ "yue_co1",
+ "yue_co2",
+ "yue_co3",
+ "yue_co4",
+ "yue_co5",
+ "yue_coek3",
+ "yue_coeng1",
+ "yue_coeng2",
+ "yue_coeng3",
+ "yue_coeng4",
+ "yue_coi1",
+ "yue_coi2",
+ "yue_coi3",
+ "yue_coi4",
+ "yue_cong1",
+ "yue_cong2",
+ "yue_cong3",
+ "yue_cong4",
+ "yue_cou1",
+ "yue_cou2",
+ "yue_cou3",
+ "yue_cou4",
+ "yue_cou5",
+ "yue_cuk1",
+ "yue_cung1",
+ "yue_cung2",
+ "yue_cung4",
+ "yue_cung5",
+ "yue_cyu2",
+ "yue_cyu3",
+ "yue_cyu4",
+ "yue_cyu5",
+ "yue_cyun1",
+ "yue_cyun2",
+ "yue_cyun3",
+ "yue_cyun4",
+ "yue_cyut3",
+ "yue_daa1",
+ "yue_daa2",
+ "yue_daai2",
+ "yue_daai3",
+ "yue_daai6",
+ "yue_daam1",
+ "yue_daam2",
+ "yue_daam3",
+ "yue_daam6",
+ "yue_daan1",
+ "yue_daan2",
+ "yue_daan3",
+ "yue_daan6",
+ "yue_daap3",
+ "yue_daap6",
+ "yue_daat6",
+ "yue_dai1",
+ "yue_dai2",
+ "yue_dai3",
+ "yue_dai4",
+ "yue_dai6",
+ "yue_dak1",
+ "yue_dak6",
+ "yue_dan6",
+ "yue_dang1",
+ "yue_dang2",
+ "yue_dang3",
+ "yue_dat6",
+ "yue_dau1",
+ "yue_dau2",
+ "yue_dau3",
+ "yue_dau6",
+ "yue_de1",
+ "yue_de2",
+ "yue_dei2",
+ "yue_dei6",
+ "yue_dek6",
+ "yue_deng2",
+ "yue_deng6",
+ "yue_deoi1",
+ "yue_deoi2",
+ "yue_deoi3",
+ "yue_deoi6",
+ "yue_deon1",
+ "yue_deon6",
+ "yue_deu6",
+ "yue_di1",
+ "yue_dik1",
+ "yue_dik6",
+ "yue_dim2",
+ "yue_dim3",
+ "yue_dim6",
+ "yue_din1",
+ "yue_din2",
+ "yue_din6",
+ "yue_ding1",
+ "yue_ding2",
+ "yue_ding3",
+ "yue_ding6",
+ "yue_dip2",
+ "yue_dip6",
+ "yue_dit3",
+ "yue_dit6",
+ "yue_diu1",
+ "yue_diu3",
+ "yue_diu6",
+ "yue_do1",
+ "yue_do2",
+ "yue_do6",
+ "yue_doek3",
+ "yue_doi2",
+ "yue_doi6",
+ "yue_dok6",
+ "yue_dong1",
+ "yue_dong2",
+ "yue_dong3",
+ "yue_dong6",
+ "yue_dou1",
+ "yue_dou2",
+ "yue_dou3",
+ "yue_dou6",
+ "yue_duk1",
+ "yue_duk6",
+ "yue_dung1",
+ "yue_dung2",
+ "yue_dung3",
+ "yue_dung6",
+ "yue_dyun1",
+ "yue_dyun2",
+ "yue_dyun3",
+ "yue_dyun6",
+ "yue_dyut3",
+ "yue_dyut6",
+ "yue_ei6",
+ "yue_faa1",
+ "yue_faa2",
+ "yue_faa3",
+ "yue_faai3",
+ "yue_faan1",
+ "yue_faan2",
+ "yue_faan3",
+ "yue_faan4",
+ "yue_faan6",
+ "yue_faat3",
+ "yue_fai1",
+ "yue_fai3",
+ "yue_fan1",
+ "yue_fan2",
+ "yue_fan3",
+ "yue_fan4",
+ "yue_fan5",
+ "yue_fan6",
+ "yue_fat1",
+ "yue_fat6",
+ "yue_fau2",
+ "yue_fau4",
+ "yue_fau6",
+ "yue_fe1",
+ "yue_fei1",
+ "yue_fei2",
+ "yue_fei4",
+ "yue_fo1",
+ "yue_fo2",
+ "yue_fo3",
+ "yue_fok3",
+ "yue_fong1",
+ "yue_fong2",
+ "yue_fong3",
+ "yue_fong4",
+ "yue_fu1",
+ "yue_fu2",
+ "yue_fu3",
+ "yue_fu4",
+ "yue_fu5",
+ "yue_fu6",
+ "yue_fui1",
+ "yue_fui3",
+ "yue_fuk1",
+ "yue_fuk6",
+ "yue_fun1",
+ "yue_fun2",
+ "yue_fung1",
+ "yue_fung3",
+ "yue_fung4",
+ "yue_fung6",
+ "yue_fut3",
+ "yue_gaa1",
+ "yue_gaa2",
+ "yue_gaa3",
+ "yue_gaai1",
+ "yue_gaai2",
+ "yue_gaai3",
+ "yue_gaak3",
+ "yue_gaam1",
+ "yue_gaam2",
+ "yue_gaam3",
+ "yue_gaan1",
+ "yue_gaan2",
+ "yue_gaan3",
+ "yue_gaang1",
+ "yue_gaap3",
+ "yue_gaau1",
+ "yue_gaau2",
+ "yue_gaau3",
+ "yue_gai1",
+ "yue_gai2",
+ "yue_gai3",
+ "yue_gam1",
+ "yue_gam2",
+ "yue_gam3",
+ "yue_gan1",
+ "yue_gan2",
+ "yue_gan6",
+ "yue_gang1",
+ "yue_gang2",
+ "yue_gang3",
+ "yue_gap1",
+ "yue_gat1",
+ "yue_gau1",
+ "yue_gau2",
+ "yue_gau3",
+ "yue_gau6",
+ "yue_ge3",
+ "yue_gei1",
+ "yue_gei2",
+ "yue_gei3",
+ "yue_gei6",
+ "yue_geng2",
+ "yue_geng3",
+ "yue_geoi1",
+ "yue_geoi2",
+ "yue_geoi3",
+ "yue_geoi6",
+ "yue_gik1",
+ "yue_gik6",
+ "yue_gim1",
+ "yue_gim2",
+ "yue_gim3",
+ "yue_gim6",
+ "yue_gin1",
+ "yue_gin2",
+ "yue_gin3",
+ "yue_gin6",
+ "yue_ging1",
+ "yue_ging2",
+ "yue_ging3",
+ "yue_ging6",
+ "yue_gip3",
+ "yue_git3",
+ "yue_git6",
+ "yue_giu1",
+ "yue_giu2",
+ "yue_giu3",
+ "yue_giu6",
+ "yue_go1",
+ "yue_go2",
+ "yue_go3",
+ "yue_go4",
+ "yue_goek3",
+ "yue_goeng1",
+ "yue_goi1",
+ "yue_goi2",
+ "yue_goi3",
+ "yue_gok1",
+ "yue_gok2",
+ "yue_gok3",
+ "yue_gok4",
+ "yue_gon1",
+ "yue_gon2",
+ "yue_gon3",
+ "yue_gong1",
+ "yue_gong2",
+ "yue_gong3",
+ "yue_got3",
+ "yue_gou1",
+ "yue_gou2",
+ "yue_gou3",
+ "yue_gu1",
+ "yue_gu2",
+ "yue_gu3",
+ "yue_gu4",
+ "yue_gui6",
+ "yue_guk1",
+ "yue_guk2",
+ "yue_guk6",
+ "yue_gun1",
+ "yue_gun2",
+ "yue_gun3",
+ "yue_gung1",
+ "yue_gung2",
+ "yue_gung3",
+ "yue_gung4",
+ "yue_gung6",
+ "yue_gwaa1",
+ "yue_gwaa2",
+ "yue_gwaa3",
+ "yue_gwaai1",
+ "yue_gwaai2",
+ "yue_gwaai3",
+ "yue_gwaak3",
+ "yue_gwaan1",
+ "yue_gwaan3",
+ "yue_gwaat3",
+ "yue_gwai1",
+ "yue_gwai2",
+ "yue_gwai3",
+ "yue_gwai6",
+ "yue_gwan1",
+ "yue_gwan2",
+ "yue_gwan3",
+ "yue_gwan6",
+ "yue_gwang1",
+ "yue_gwat1",
+ "yue_gwat6",
+ "yue_gwek4",
+ "yue_gwo1",
+ "yue_gwo2",
+ "yue_gwo3",
+ "yue_gwok3",
+ "yue_gwong1",
+ "yue_gwong2",
+ "yue_gyun1",
+ "yue_gyun2",
+ "yue_gyun3",
+ "yue_gyun6",
+ "yue_haa1",
+ "yue_haa2",
+ "yue_haa4",
+ "yue_haa5",
+ "yue_haa6",
+ "yue_haai1",
+ "yue_haai4",
+ "yue_haai5",
+ "yue_haai6",
+ "yue_haak1",
+ "yue_haak3",
+ "yue_haam3",
+ "yue_haam4",
+ "yue_haam6",
+ "yue_haan4",
+ "yue_haan6",
+ "yue_haang4",
+ "yue_haap3",
+ "yue_haap6",
+ "yue_haau1",
+ "yue_haau2",
+ "yue_haau3",
+ "yue_haau6",
+ "yue_hai2",
+ "yue_hai4",
+ "yue_hai6",
+ "yue_hak1",
+ "yue_ham1",
+ "yue_ham2",
+ "yue_ham3",
+ "yue_ham4",
+ "yue_ham6",
+ "yue_han2",
+ "yue_han4",
+ "yue_han6",
+ "yue_hang1",
+ "yue_hang2",
+ "yue_hang4",
+ "yue_hang6",
+ "yue_hap1",
+ "yue_hap2",
+ "yue_hap6",
+ "yue_hat1",
+ "yue_hat6",
+ "yue_hau2",
+ "yue_hau4",
+ "yue_hau5",
+ "yue_hau6",
+ "yue_he3",
+ "yue_hei1",
+ "yue_hei2",
+ "yue_hei3",
+ "yue_hek3",
+ "yue_heng1",
+ "yue_heng6",
+ "yue_heoi1",
+ "yue_heoi2",
+ "yue_heoi3",
+ "yue_hi1",
+ "yue_him1",
+ "yue_him2",
+ "yue_him3",
+ "yue_hin1",
+ "yue_hin2",
+ "yue_hin3",
+ "yue_hing1",
+ "yue_hing3",
+ "yue_hip3",
+ "yue_hit3",
+ "yue_hiu1",
+ "yue_hiu2",
+ "yue_ho1",
+ "yue_ho2",
+ "yue_ho4",
+ "yue_ho6",
+ "yue_hoeng1",
+ "yue_hoeng2",
+ "yue_hoeng3",
+ "yue_hoi1",
+ "yue_hoi2",
+ "yue_hoi6",
+ "yue_hok3",
+ "yue_hok6",
+ "yue_hon1",
+ "yue_hon2",
+ "yue_hon3",
+ "yue_hon4",
+ "yue_hon5",
+ "yue_hon6",
+ "yue_hong1",
+ "yue_hong2",
+ "yue_hong4",
+ "yue_hong6",
+ "yue_hot3",
+ "yue_hou2",
+ "yue_hou3",
+ "yue_hou4",
+ "yue_hou6",
+ "yue_huk1",
+ "yue_huk6",
+ "yue_hung1",
+ "yue_hung2",
+ "yue_hung3",
+ "yue_hung4",
+ "yue_hyun1",
+ "yue_hyun3",
+ "yue_hyut3",
+ "yue_jaa1",
+ "yue_jaa3",
+ "yue_jaa5",
+ "yue_jaa6",
+ "yue_jaak3",
+ "yue_jai5",
+ "yue_jai6",
+ "yue_jam1",
+ "yue_jam2",
+ "yue_jam3",
+ "yue_jam4",
+ "yue_jam6",
+ "yue_jan1",
+ "yue_jan2",
+ "yue_jan3",
+ "yue_jan4",
+ "yue_jan5",
+ "yue_jan6",
+ "yue_jap1",
+ "yue_jap6",
+ "yue_jat1",
+ "yue_jat2",
+ "yue_jat6",
+ "yue_jau1",
+ "yue_jau2",
+ "yue_jau3",
+ "yue_jau4",
+ "yue_jau5",
+ "yue_jau6",
+ "yue_je2",
+ "yue_je4",
+ "yue_je5",
+ "yue_je6",
+ "yue_jeoi5",
+ "yue_jeoi6",
+ "yue_jeon6",
+ "yue_ji1",
+ "yue_ji2",
+ "yue_ji3",
+ "yue_ji4",
+ "yue_ji5",
+ "yue_ji6",
+ "yue_jik1",
+ "yue_jik6",
+ "yue_jim1",
+ "yue_jim2",
+ "yue_jim3",
+ "yue_jim4",
+ "yue_jim5",
+ "yue_jim6",
+ "yue_jin1",
+ "yue_jin2",
+ "yue_jin3",
+ "yue_jin4",
+ "yue_jin6",
+ "yue_jing1",
+ "yue_jing2",
+ "yue_jing3",
+ "yue_jing4",
+ "yue_jing6",
+ "yue_jip3",
+ "yue_jip6",
+ "yue_jit3",
+ "yue_jit6",
+ "yue_jiu1",
+ "yue_jiu2",
+ "yue_jiu3",
+ "yue_jiu4",
+ "yue_jiu6",
+ "yue_jo1",
+ "yue_joek3",
+ "yue_joek6",
+ "yue_joeng1",
+ "yue_joeng2",
+ "yue_joeng4",
+ "yue_joeng5",
+ "yue_joeng6",
+ "yue_juk1",
+ "yue_juk2",
+ "yue_juk6",
+ "yue_jung1",
+ "yue_jung2",
+ "yue_jung4",
+ "yue_jung5",
+ "yue_jung6",
+ "yue_jyu1",
+ "yue_jyu2",
+ "yue_jyu4",
+ "yue_jyu5",
+ "yue_jyu6",
+ "yue_jyun1",
+ "yue_jyun2",
+ "yue_jyun3",
+ "yue_jyun4",
+ "yue_jyun5",
+ "yue_jyun6",
+ "yue_jyut2",
+ "yue_jyut6",
+ "yue_kaa1",
+ "yue_kaa3",
+ "yue_kaat1",
+ "yue_kaau3",
+ "yue_kai1",
+ "yue_kai2",
+ "yue_kai3",
+ "yue_kam1",
+ "yue_kam4",
+ "yue_kan4",
+ "yue_kan5",
+ "yue_kap1",
+ "yue_kap6",
+ "yue_kat1",
+ "yue_kau1",
+ "yue_kau3",
+ "yue_kau4",
+ "yue_ke4",
+ "yue_kei1",
+ "yue_kei2",
+ "yue_kei3",
+ "yue_kei4",
+ "yue_kei5",
+ "yue_kek6",
+ "yue_keoi1",
+ "yue_keoi4",
+ "yue_keoi5",
+ "yue_kin4",
+ "yue_king1",
+ "yue_king2",
+ "yue_king4",
+ "yue_king5",
+ "yue_kit3",
+ "yue_kiu1",
+ "yue_kiu2",
+ "yue_kiu3",
+ "yue_kiu4",
+ "yue_kiu5",
+ "yue_koek3",
+ "yue_koek6",
+ "yue_koeng4",
+ "yue_koeng5",
+ "yue_koi3",
+ "yue_kok3",
+ "yue_kong3",
+ "yue_ku1",
+ "yue_kui2",
+ "yue_kuk1",
+ "yue_kung4",
+ "yue_kut3",
+ "yue_kwaa1",
+ "yue_kwaa3",
+ "yue_kwaang1",
+ "yue_kwaang3",
+ "yue_kwai1",
+ "yue_kwai4",
+ "yue_kwai5",
+ "yue_kwan1",
+ "yue_kwan2",
+ "yue_kwan3",
+ "yue_kwan4",
+ "yue_kwik1",
+ "yue_kwok3",
+ "yue_kwong3",
+ "yue_kwong4",
+ "yue_kyun4",
+ "yue_kyut3",
+ "yue_laa1",
+ "yue_laa3",
+ "yue_laa4",
+ "yue_laai1",
+ "yue_laai2",
+ "yue_laai6",
+ "yue_laak6",
+ "yue_laam2",
+ "yue_laam4",
+ "yue_laam5",
+ "yue_laam6",
+ "yue_laan2",
+ "yue_laan4",
+ "yue_laan5",
+ "yue_laan6",
+ "yue_laang1",
+ "yue_laang5",
+ "yue_laap6",
+ "yue_laat3",
+ "yue_laat6",
+ "yue_laau4",
+ "yue_lai4",
+ "yue_lai5",
+ "yue_lai6",
+ "yue_lak6",
+ "yue_lam1",
+ "yue_lam2",
+ "yue_lam4",
+ "yue_lam5",
+ "yue_lat1",
+ "yue_lau1",
+ "yue_lau2",
+ "yue_lau4",
+ "yue_lau5",
+ "yue_lau6",
+ "yue_le2",
+ "yue_le4",
+ "yue_lei1",
+ "yue_lei2",
+ "yue_lei4",
+ "yue_lei5",
+ "yue_lei6",
+ "yue_lek1",
+ "yue_leng3",
+ "yue_leng5",
+ "yue_leoi4",
+ "yue_leoi5",
+ "yue_leoi6",
+ "yue_leon2",
+ "yue_leon4",
+ "yue_leon6",
+ "yue_leot2",
+ "yue_leot6",
+ "yue_leu5",
+ "yue_li1",
+ "yue_lik1",
+ "yue_lik6",
+ "yue_lim2",
+ "yue_lim4",
+ "yue_lim5",
+ "yue_lim6",
+ "yue_lin2",
+ "yue_lin4",
+ "yue_lin6",
+ "yue_ling1",
+ "yue_ling2",
+ "yue_ling4",
+ "yue_ling5",
+ "yue_ling6",
+ "yue_lip6",
+ "yue_lit6",
+ "yue_liu2",
+ "yue_liu4",
+ "yue_liu5",
+ "yue_liu6",
+ "yue_lo1",
+ "yue_lo2",
+ "yue_lo4",
+ "yue_loek2",
+ "yue_loek6",
+ "yue_loeng2",
+ "yue_loeng4",
+ "yue_loeng5",
+ "yue_loeng6",
+ "yue_loi4",
+ "yue_lok1",
+ "yue_lok3",
+ "yue_lok6",
+ "yue_long1",
+ "yue_long2",
+ "yue_long4",
+ "yue_long5",
+ "yue_long6",
+ "yue_lou2",
+ "yue_lou4",
+ "yue_lou5",
+ "yue_lou6",
+ "yue_luk1",
+ "yue_luk2",
+ "yue_luk6",
+ "yue_lung4",
+ "yue_lung5",
+ "yue_lung6",
+ "yue_lyun2",
+ "yue_lyun4",
+ "yue_lyun6",
+ "yue_lyut3",
+ "yue_lyut6",
+ "yue_m4",
+ "yue_maa1",
+ "yue_maa2",
+ "yue_maa3",
+ "yue_maa4",
+ "yue_maa5",
+ "yue_maa6",
+ "yue_maai4",
+ "yue_maai5",
+ "yue_maai6",
+ "yue_maak3",
+ "yue_maan1",
+ "yue_maan2",
+ "yue_maan4",
+ "yue_maan5",
+ "yue_maan6",
+ "yue_maang4",
+ "yue_maang5",
+ "yue_maang6",
+ "yue_maau1",
+ "yue_maau4",
+ "yue_maau5",
+ "yue_maau6",
+ "yue_mai1",
+ "yue_mai4",
+ "yue_mai5",
+ "yue_mai6",
+ "yue_mak6",
+ "yue_man1",
+ "yue_man2",
+ "yue_man4",
+ "yue_man5",
+ "yue_man6",
+ "yue_mang2",
+ "yue_mang4",
+ "yue_mat1",
+ "yue_mat2",
+ "yue_mat6",
+ "yue_mau4",
+ "yue_mau5",
+ "yue_mau6",
+ "yue_me1",
+ "yue_me2",
+ "yue_mei1",
+ "yue_mei2",
+ "yue_mei4",
+ "yue_mei5",
+ "yue_mei6",
+ "yue_meng2",
+ "yue_meng6",
+ "yue_mik6",
+ "yue_min2",
+ "yue_min4",
+ "yue_min5",
+ "yue_min6",
+ "yue_ming4",
+ "yue_ming5",
+ "yue_ming6",
+ "yue_mit6",
+ "yue_miu1",
+ "yue_miu4",
+ "yue_miu5",
+ "yue_miu6",
+ "yue_mo1",
+ "yue_mo2",
+ "yue_mo4",
+ "yue_mo6",
+ "yue_mok1",
+ "yue_mok2",
+ "yue_mok6",
+ "yue_mong4",
+ "yue_mong5",
+ "yue_mong6",
+ "yue_mou1",
+ "yue_mou2",
+ "yue_mou4",
+ "yue_mou5",
+ "yue_mou6",
+ "yue_mui2",
+ "yue_mui4",
+ "yue_mui5",
+ "yue_mui6",
+ "yue_muk6",
+ "yue_mun2",
+ "yue_mun4",
+ "yue_mun5",
+ "yue_mun6",
+ "yue_mung1",
+ "yue_mung2",
+ "yue_mung4",
+ "yue_mung6",
+ "yue_mut3",
+ "yue_mut6",
+ "yue_naa4",
+ "yue_naa5",
+ "yue_naai2",
+ "yue_naai4",
+ "yue_naai5",
+ "yue_naam4",
+ "yue_naam5",
+ "yue_naan4",
+ "yue_naan6",
+ "yue_naap6",
+ "yue_naat6",
+ "yue_naau4",
+ "yue_naau5",
+ "yue_naau6",
+ "yue_nai4",
+ "yue_nam2",
+ "yue_nang4",
+ "yue_nap1",
+ "yue_nau1",
+ "yue_nau2",
+ "yue_nau5",
+ "yue_ne1",
+ "yue_nei4",
+ "yue_nei5",
+ "yue_nei6",
+ "yue_neoi1",
+ "yue_neoi2",
+ "yue_neoi4",
+ "yue_neoi5",
+ "yue_neot6",
+ "yue_ng4",
+ "yue_ng5",
+ "yue_ng6",
+ "yue_ngaa4",
+ "yue_ngaa5",
+ "yue_ngaa6",
+ "yue_ngaai4",
+ "yue_ngaai6",
+ "yue_ngaak6",
+ "yue_ngaam4",
+ "yue_ngaan4",
+ "yue_ngaan5",
+ "yue_ngaan6",
+ "yue_ngaang6",
+ "yue_ngaau4",
+ "yue_ngaau5",
+ "yue_ngai4",
+ "yue_ngai5",
+ "yue_ngai6",
+ "yue_ngam4",
+ "yue_ngan4",
+ "yue_ngau1",
+ "yue_ngau4",
+ "yue_ngau5",
+ "yue_ngo2",
+ "yue_ngo4",
+ "yue_ngo5",
+ "yue_ngo6",
+ "yue_ngoi2",
+ "yue_ngoi4",
+ "yue_ngoi6",
+ "yue_ngok3",
+ "yue_ngok6",
+ "yue_ngon6",
+ "yue_ngong4",
+ "yue_ngou4",
+ "yue_ngou6",
+ "yue_ni1",
+ "yue_nik1",
+ "yue_nik6",
+ "yue_nim1",
+ "yue_nim6",
+ "yue_nin2",
+ "yue_nin4",
+ "yue_ning4",
+ "yue_nip6",
+ "yue_niu5",
+ "yue_no4",
+ "yue_no6",
+ "yue_noeng2",
+ "yue_noeng4",
+ "yue_noi6",
+ "yue_nok6",
+ "yue_nong4",
+ "yue_nou4",
+ "yue_nou5",
+ "yue_nou6",
+ "yue_nung4",
+ "yue_nyun5",
+ "yue_nyun6",
+ "yue_o1",
+ "yue_o2",
+ "yue_oi1",
+ "yue_oi2",
+ "yue_oi3",
+ "yue_ok3",
+ "yue_on1",
+ "yue_on3",
+ "yue_ong3",
+ "yue_ou3",
+ "yue_paa3",
+ "yue_paa4",
+ "yue_paai2",
+ "yue_paai3",
+ "yue_paai4",
+ "yue_paak2",
+ "yue_paak3",
+ "yue_paan1",
+ "yue_paan3",
+ "yue_paang4",
+ "yue_paang5",
+ "yue_paau1",
+ "yue_paau2",
+ "yue_paau3",
+ "yue_paau4",
+ "yue_pai1",
+ "yue_pan3",
+ "yue_pan4",
+ "yue_pang4",
+ "yue_pat1",
+ "yue_pau2",
+ "yue_pei1",
+ "yue_pei2",
+ "yue_pei3",
+ "yue_pei4",
+ "yue_pei5",
+ "yue_pek3",
+ "yue_peng1",
+ "yue_peng4",
+ "yue_pik1",
+ "yue_pin1",
+ "yue_pin2",
+ "yue_pin3",
+ "yue_pin4",
+ "yue_ping1",
+ "yue_ping3",
+ "yue_ping4",
+ "yue_pit3",
+ "yue_piu1",
+ "yue_piu3",
+ "yue_po2",
+ "yue_po3",
+ "yue_po4",
+ "yue_pok1",
+ "yue_pok3",
+ "yue_pong4",
+ "yue_pou1",
+ "yue_pou2",
+ "yue_pou4",
+ "yue_pou5",
+ "yue_pui1",
+ "yue_pui3",
+ "yue_pui4",
+ "yue_pui5",
+ "yue_puk1",
+ "yue_pun1",
+ "yue_pun2",
+ "yue_pun3",
+ "yue_pun4",
+ "yue_pung2",
+ "yue_pung3",
+ "yue_pung4",
+ "yue_put3",
+ "yue_saa1",
+ "yue_saa2",
+ "yue_saai1",
+ "yue_saai2",
+ "yue_saai3",
+ "yue_saam1",
+ "yue_saan1",
+ "yue_saan2",
+ "yue_saan3",
+ "yue_saan4",
+ "yue_saang1",
+ "yue_saang2",
+ "yue_saang4",
+ "yue_saap3",
+ "yue_saap6",
+ "yue_saat3",
+ "yue_saau1",
+ "yue_saau2",
+ "yue_saau3",
+ "yue_sai1",
+ "yue_sai2",
+ "yue_sai3",
+ "yue_sai6",
+ "yue_sak1",
+ "yue_sam1",
+ "yue_sam2",
+ "yue_sam3",
+ "yue_sam6",
+ "yue_san1",
+ "yue_san4",
+ "yue_san5",
+ "yue_san6",
+ "yue_sang1",
+ "yue_sang3",
+ "yue_sap1",
+ "yue_sap6",
+ "yue_sat1",
+ "yue_sat6",
+ "yue_sau1",
+ "yue_sau2",
+ "yue_sau3",
+ "yue_sau4",
+ "yue_sau6",
+ "yue_se1",
+ "yue_se2",
+ "yue_se3",
+ "yue_se4",
+ "yue_se5",
+ "yue_se6",
+ "yue_sei2",
+ "yue_sei3",
+ "yue_sek3",
+ "yue_sek6",
+ "yue_seng1",
+ "yue_seng2",
+ "yue_seng4",
+ "yue_seoi1",
+ "yue_seoi2",
+ "yue_seoi3",
+ "yue_seoi4",
+ "yue_seoi5",
+ "yue_seoi6",
+ "yue_seon1",
+ "yue_seon3",
+ "yue_seon4",
+ "yue_seon6",
+ "yue_seot1",
+ "yue_seot6",
+ "yue_si1",
+ "yue_si2",
+ "yue_si3",
+ "yue_si4",
+ "yue_si5",
+ "yue_si6",
+ "yue_sik1",
+ "yue_sik6",
+ "yue_sim2",
+ "yue_sim4",
+ "yue_sim6",
+ "yue_sin1",
+ "yue_sin3",
+ "yue_sin4",
+ "yue_sin6",
+ "yue_sing1",
+ "yue_sing2",
+ "yue_sing3",
+ "yue_sing4",
+ "yue_sing6",
+ "yue_sip3",
+ "yue_sit3",
+ "yue_sit6",
+ "yue_siu1",
+ "yue_siu2",
+ "yue_siu3",
+ "yue_siu4",
+ "yue_siu6",
+ "yue_so1",
+ "yue_so2",
+ "yue_so4",
+ "yue_soek3",
+ "yue_soeng1",
+ "yue_soeng2",
+ "yue_soeng3",
+ "yue_soeng4",
+ "yue_soeng5",
+ "yue_soeng6",
+ "yue_soi1",
+ "yue_sok3",
+ "yue_song1",
+ "yue_song2",
+ "yue_song3",
+ "yue_sou1",
+ "yue_sou2",
+ "yue_sou3",
+ "yue_suk1",
+ "yue_suk6",
+ "yue_sung1",
+ "yue_sung2",
+ "yue_sung3",
+ "yue_sung4",
+ "yue_syu1",
+ "yue_syu2",
+ "yue_syu3",
+ "yue_syu4",
+ "yue_syu6",
+ "yue_syun1",
+ "yue_syun2",
+ "yue_syun3",
+ "yue_syun4",
+ "yue_syun5",
+ "yue_syut3",
+ "yue_taa1",
+ "yue_taai1",
+ "yue_taai2",
+ "yue_taai3",
+ "yue_taai5",
+ "yue_taam1",
+ "yue_taam2",
+ "yue_taam3",
+ "yue_taam4",
+ "yue_taam5",
+ "yue_taan1",
+ "yue_taan2",
+ "yue_taan3",
+ "yue_taan4",
+ "yue_taan6",
+ "yue_taap3",
+ "yue_taat3",
+ "yue_tai1",
+ "yue_tai2",
+ "yue_tai3",
+ "yue_tai4",
+ "yue_tam4",
+ "yue_tam5",
+ "yue_tan1",
+ "yue_tang4",
+ "yue_tau1",
+ "yue_tau2",
+ "yue_tau3",
+ "yue_tau4",
+ "yue_tek3",
+ "yue_teng1",
+ "yue_teng5",
+ "yue_teoi1",
+ "yue_teoi2",
+ "yue_teoi3",
+ "yue_teoi4",
+ "yue_teon1",
+ "yue_teon5",
+ "yue_tik1",
+ "yue_tim1",
+ "yue_tim2",
+ "yue_tim4",
+ "yue_tim5",
+ "yue_tin1",
+ "yue_tin2",
+ "yue_tin4",
+ "yue_ting1",
+ "yue_ting2",
+ "yue_ting3",
+ "yue_ting4",
+ "yue_ting5",
+ "yue_tip1",
+ "yue_tip2",
+ "yue_tip3",
+ "yue_tit3",
+ "yue_tiu1",
+ "yue_tiu3",
+ "yue_tiu4",
+ "yue_tiu5",
+ "yue_to1",
+ "yue_to4",
+ "yue_to5",
+ "yue_toe3",
+ "yue_toe5",
+ "yue_toi1",
+ "yue_toi2",
+ "yue_toi4",
+ "yue_toi5",
+ "yue_tok3",
+ "yue_tong1",
+ "yue_tong2",
+ "yue_tong3",
+ "yue_tong4",
+ "yue_tou1",
+ "yue_tou2",
+ "yue_tou3",
+ "yue_tou4",
+ "yue_tou5",
+ "yue_tuk1",
+ "yue_tung1",
+ "yue_tung2",
+ "yue_tung3",
+ "yue_tung4",
+ "yue_tyun4",
+ "yue_tyut3",
+ "yue_uk1",
+ "yue_waa1",
+ "yue_waa2",
+ "yue_waa4",
+ "yue_waa5",
+ "yue_waa6",
+ "yue_waai1",
+ "yue_waai4",
+ "yue_waai6",
+ "yue_waak6",
+ "yue_waan1",
+ "yue_waan2",
+ "yue_waan4",
+ "yue_waan5",
+ "yue_waan6",
+ "yue_waang4",
+ "yue_waat3",
+ "yue_waat6",
+ "yue_wai1",
+ "yue_wai2",
+ "yue_wai3",
+ "yue_wai4",
+ "yue_wai5",
+ "yue_wai6",
+ "yue_wan1",
+ "yue_wan2",
+ "yue_wan3",
+ "yue_wan4",
+ "yue_wan5",
+ "yue_wan6",
+ "yue_wang4",
+ "yue_wat1",
+ "yue_wik6",
+ "yue_wing4",
+ "yue_wing5",
+ "yue_wing6",
+ "yue_wo1",
+ "yue_wo4",
+ "yue_wo6",
+ "yue_wok6",
+ "yue_wong1",
+ "yue_wong2",
+ "yue_wong4",
+ "yue_wong5",
+ "yue_wong6",
+ "yue_wu1",
+ "yue_wu2",
+ "yue_wu3",
+ "yue_wu4",
+ "yue_wu6",
+ "yue_wui1",
+ "yue_wui2",
+ "yue_wui4",
+ "yue_wui6",
+ "yue_wun2",
+ "yue_wun4",
+ "yue_wun5",
+ "yue_wun6",
+ "yue_wut6",
+ "yue_zaa1",
+ "yue_zaa3",
+ "yue_zaai1",
+ "yue_zaai3",
+ "yue_zaak3",
+ "yue_zaak6",
+ "yue_zaam2",
+ "yue_zaam3",
+ "yue_zaam6",
+ "yue_zaan2",
+ "yue_zaan3",
+ "yue_zaan6",
+ "yue_zaang1",
+ "yue_zaap3",
+ "yue_zaap6",
+ "yue_zaat3",
+ "yue_zaau1",
+ "yue_zaau2",
+ "yue_zaau3",
+ "yue_zaau6",
+ "yue_zai1",
+ "yue_zai2",
+ "yue_zai3",
+ "yue_zai6",
+ "yue_zak1",
+ "yue_zam1",
+ "yue_zam2",
+ "yue_zam3",
+ "yue_zan1",
+ "yue_zan3",
+ "yue_zan6",
+ "yue_zang1",
+ "yue_zang2",
+ "yue_zang6",
+ "yue_zap1",
+ "yue_zat1",
+ "yue_zat6",
+ "yue_zau1",
+ "yue_zau2",
+ "yue_zau3",
+ "yue_zau6",
+ "yue_ze1",
+ "yue_ze2",
+ "yue_ze3",
+ "yue_ze4",
+ "yue_ze5",
+ "yue_ze6",
+ "yue_zek3",
+ "yue_zeng1",
+ "yue_zeng2",
+ "yue_zeng3",
+ "yue_zeng6",
+ "yue_zeoi1",
+ "yue_zeoi2",
+ "yue_zeoi3",
+ "yue_zeoi6",
+ "yue_zeon1",
+ "yue_zeon2",
+ "yue_zeon3",
+ "yue_zeon6",
+ "yue_zi1",
+ "yue_zi2",
+ "yue_zi3",
+ "yue_zi6",
+ "yue_zik1",
+ "yue_zik6",
+ "yue_zim1",
+ "yue_zim2",
+ "yue_zim6",
+ "yue_zin1",
+ "yue_zin2",
+ "yue_zin3",
+ "yue_zin6",
+ "yue_zing1",
+ "yue_zing2",
+ "yue_zing3",
+ "yue_zing6",
+ "yue_zip3",
+ "yue_zit1",
+ "yue_zit3",
+ "yue_zit6",
+ "yue_ziu1",
+ "yue_ziu2",
+ "yue_ziu3",
+ "yue_ziu6",
+ "yue_zo2",
+ "yue_zo3",
+ "yue_zo6",
+ "yue_zoek2",
+ "yue_zoek3",
+ "yue_zoek6",
+ "yue_zoeng1",
+ "yue_zoeng2",
+ "yue_zoeng3",
+ "yue_zoeng6",
+ "yue_zoi1",
+ "yue_zoi2",
+ "yue_zoi3",
+ "yue_zoi6",
+ "yue_zok3",
+ "yue_zok6",
+ "yue_zong1",
+ "yue_zong3",
+ "yue_zong6",
+ "yue_zou1",
+ "yue_zou2",
+ "yue_zou6",
+ "yue_zuk1",
+ "yue_zuk6",
+ "yue_zung1",
+ "yue_zung2",
+ "yue_zung3",
+ "yue_zung6",
+ "yue_zyu1",
+ "yue_zyu2",
+ "yue_zyu3",
+ "yue_zyu6",
+ "yue_zyun1",
+ "yue_zyun2",
+ "yue_zyun3",
+ "yue_zyun6",
+ "yue_zyut3",
+ "yue_zyut6",
+ "zh_a1",
+ "zh_a4",
+ "zh_a5",
+ "zh_ai1",
+ "zh_ai2",
+ "zh_ai3",
+ "zh_ai4",
+ "zh_an1",
+ "zh_an3",
+ "zh_an4",
+ "zh_ang1",
+ "zh_ang2",
+ "zh_ang4",
+ "zh_ao1",
+ "zh_ao2",
+ "zh_ao3",
+ "zh_ao4",
+ "zh_ba1",
+ "zh_ba2",
+ "zh_ba3",
+ "zh_ba4",
+ "zh_ba5",
+ "zh_bai1",
+ "zh_bai2",
+ "zh_bai3",
+ "zh_bai4",
+ "zh_ban1",
+ "zh_ban3",
+ "zh_ban4",
+ "zh_bang1",
+ "zh_bang3",
+ "zh_bang4",
+ "zh_bao1",
+ "zh_bao2",
+ "zh_bao3",
+ "zh_bao4",
+ "zh_bei1",
+ "zh_bei3",
+ "zh_bei4",
+ "zh_bei5",
+ "zh_ben1",
+ "zh_ben3",
+ "zh_ben4",
+ "zh_beng1",
+ "zh_beng3",
+ "zh_beng4",
+ "zh_bi1",
+ "zh_bi2",
+ "zh_bi3",
+ "zh_bi4",
+ "zh_bian1",
+ "zh_bian3",
+ "zh_bian4",
+ "zh_bian5",
+ "zh_biao1",
+ "zh_biao3",
+ "zh_bie1",
+ "zh_bie2",
+ "zh_bie3",
+ "zh_bin1",
+ "zh_bin4",
+ "zh_bing1",
+ "zh_bing3",
+ "zh_bing4",
+ "zh_bo1",
+ "zh_bo2",
+ "zh_bo3",
+ "zh_bo4",
+ "zh_bo5",
+ "zh_bu3",
+ "zh_bu4",
+ "zh_ca1",
+ "zh_cai1",
+ "zh_cai2",
+ "zh_cai3",
+ "zh_cai4",
+ "zh_can1",
+ "zh_can2",
+ "zh_can3",
+ "zh_can4",
+ "zh_cang1",
+ "zh_cang2",
+ "zh_cao1",
+ "zh_cao2",
+ "zh_cao3",
+ "zh_ce4",
+ "zh_cen2",
+ "zh_ceng2",
+ "zh_ceng4",
+ "zh_cha1",
+ "zh_cha2",
+ "zh_cha4",
+ "zh_chai1",
+ "zh_chai2",
+ "zh_chan1",
+ "zh_chan2",
+ "zh_chan3",
+ "zh_chan4",
+ "zh_chang1",
+ "zh_chang2",
+ "zh_chang3",
+ "zh_chang4",
+ "zh_chao1",
+ "zh_chao2",
+ "zh_chao3",
+ "zh_che1",
+ "zh_che3",
+ "zh_che4",
+ "zh_chen1",
+ "zh_chen2",
+ "zh_chen3",
+ "zh_chen4",
+ "zh_cheng1",
+ "zh_cheng2",
+ "zh_cheng3",
+ "zh_cheng4",
+ "zh_chi1",
+ "zh_chi2",
+ "zh_chi3",
+ "zh_chi4",
+ "zh_chong1",
+ "zh_chong2",
+ "zh_chong3",
+ "zh_chou1",
+ "zh_chou2",
+ "zh_chou3",
+ "zh_chou4",
+ "zh_chu1",
+ "zh_chu2",
+ "zh_chu3",
+ "zh_chu4",
+ "zh_chuai1",
+ "zh_chuai3",
+ "zh_chuai4",
+ "zh_chuan1",
+ "zh_chuan2",
+ "zh_chuan3",
+ "zh_chuan4",
+ "zh_chuang1",
+ "zh_chuang2",
+ "zh_chuang3",
+ "zh_chuang4",
+ "zh_chui1",
+ "zh_chui2",
+ "zh_chun1",
+ "zh_chun2",
+ "zh_chun3",
+ "zh_chuo1",
+ "zh_chuo4",
+ "zh_ci1",
+ "zh_ci2",
+ "zh_ci3",
+ "zh_ci4",
+ "zh_cong1",
+ "zh_cong2",
+ "zh_cou4",
+ "zh_cu1",
+ "zh_cu4",
+ "zh_cuan1",
+ "zh_cuan2",
+ "zh_cuan4",
+ "zh_cui1",
+ "zh_cui3",
+ "zh_cui4",
+ "zh_cun1",
+ "zh_cun2",
+ "zh_cun4",
+ "zh_cuo1",
+ "zh_cuo2",
+ "zh_cuo4",
+ "zh_da1",
+ "zh_da2",
+ "zh_da3",
+ "zh_da4",
+ "zh_dai1",
+ "zh_dai3",
+ "zh_dai4",
+ "zh_dan1",
+ "zh_dan3",
+ "zh_dan4",
+ "zh_dang1",
+ "zh_dang3",
+ "zh_dang4",
+ "zh_dao1",
+ "zh_dao3",
+ "zh_dao4",
+ "zh_de1",
+ "zh_de2",
+ "zh_de5",
+ "zh_deng1",
+ "zh_deng3",
+ "zh_deng4",
+ "zh_di1",
+ "zh_di2",
+ "zh_di3",
+ "zh_di4",
+ "zh_dia3",
+ "zh_dian1",
+ "zh_dian3",
+ "zh_dian4",
+ "zh_diao1",
+ "zh_diao3",
+ "zh_diao4",
+ "zh_die1",
+ "zh_die2",
+ "zh_ding1",
+ "zh_ding3",
+ "zh_ding4",
+ "zh_diu1",
+ "zh_dong1",
+ "zh_dong3",
+ "zh_dong4",
+ "zh_dou1",
+ "zh_dou3",
+ "zh_dou4",
+ "zh_du1",
+ "zh_du2",
+ "zh_du3",
+ "zh_du4",
+ "zh_duan1",
+ "zh_duan3",
+ "zh_duan4",
+ "zh_dui1",
+ "zh_dui4",
+ "zh_dun1",
+ "zh_dun3",
+ "zh_dun4",
+ "zh_duo1",
+ "zh_duo2",
+ "zh_duo3",
+ "zh_duo4",
+ "zh_e1",
+ "zh_e2",
+ "zh_e3",
+ "zh_e4",
+ "zh_en1",
+ "zh_en4",
+ "zh_eng1",
+ "zh_er2",
+ "zh_er3",
+ "zh_er4",
+ "zh_fa1",
+ "zh_fa2",
+ "zh_fa3",
+ "zh_fan1",
+ "zh_fan2",
+ "zh_fan3",
+ "zh_fan4",
+ "zh_fang1",
+ "zh_fang2",
+ "zh_fang3",
+ "zh_fang4",
+ "zh_fei1",
+ "zh_fei2",
+ "zh_fei3",
+ "zh_fei4",
+ "zh_fen1",
+ "zh_fen2",
+ "zh_fen3",
+ "zh_fen4",
+ "zh_feng1",
+ "zh_feng2",
+ "zh_feng3",
+ "zh_feng4",
+ "zh_fo2",
+ "zh_fou3",
+ "zh_fu1",
+ "zh_fu2",
+ "zh_fu3",
+ "zh_fu4",
+ "zh_ga1",
+ "zh_ga2",
+ "zh_ga3",
+ "zh_ga4",
+ "zh_gai1",
+ "zh_gai3",
+ "zh_gai4",
+ "zh_gan1",
+ "zh_gan3",
+ "zh_gan4",
+ "zh_gang1",
+ "zh_gang3",
+ "zh_gang4",
+ "zh_gao1",
+ "zh_gao3",
+ "zh_gao4",
+ "zh_ge1",
+ "zh_ge2",
+ "zh_ge3",
+ "zh_ge4",
+ "zh_gei3",
+ "zh_gen1",
+ "zh_gen2",
+ "zh_gen4",
+ "zh_geng1",
+ "zh_geng3",
+ "zh_geng4",
+ "zh_gong1",
+ "zh_gong3",
+ "zh_gong4",
+ "zh_gou1",
+ "zh_gou3",
+ "zh_gou4",
+ "zh_gu1",
+ "zh_gu3",
+ "zh_gu4",
+ "zh_gua1",
+ "zh_gua3",
+ "zh_gua4",
+ "zh_guai1",
+ "zh_guai3",
+ "zh_guai4",
+ "zh_guan1",
+ "zh_guan3",
+ "zh_guan4",
+ "zh_guang1",
+ "zh_guang3",
+ "zh_guang4",
+ "zh_gui1",
+ "zh_gui3",
+ "zh_gui4",
+ "zh_gun3",
+ "zh_gun4",
+ "zh_guo1",
+ "zh_guo2",
+ "zh_guo3",
+ "zh_guo4",
+ "zh_guo5",
+ "zh_ha1",
+ "zh_hai1",
+ "zh_hai2",
+ "zh_hai3",
+ "zh_hai4",
+ "zh_han1",
+ "zh_han2",
+ "zh_han3",
+ "zh_han4",
+ "zh_hang1",
+ "zh_hang2",
+ "zh_hao1",
+ "zh_hao2",
+ "zh_hao3",
+ "zh_hao4",
+ "zh_he1",
+ "zh_he2",
+ "zh_he4",
+ "zh_hei1",
+ "zh_hen2",
+ "zh_hen3",
+ "zh_hen4",
+ "zh_heng1",
+ "zh_heng2",
+ "zh_heng4",
+ "zh_hong1",
+ "zh_hong2",
+ "zh_hong3",
+ "zh_hong4",
+ "zh_hou1",
+ "zh_hou2",
+ "zh_hou3",
+ "zh_hou4",
+ "zh_hu1",
+ "zh_hu2",
+ "zh_hu3",
+ "zh_hu4",
+ "zh_hua1",
+ "zh_hua2",
+ "zh_hua4",
+ "zh_huai2",
+ "zh_huai4",
+ "zh_huan1",
+ "zh_huan2",
+ "zh_huan3",
+ "zh_huan4",
+ "zh_huang1",
+ "zh_huang2",
+ "zh_huang3",
+ "zh_huang4",
+ "zh_hui1",
+ "zh_hui2",
+ "zh_hui3",
+ "zh_hui4",
+ "zh_hun1",
+ "zh_hun2",
+ "zh_hun4",
+ "zh_huo1",
+ "zh_huo2",
+ "zh_huo3",
+ "zh_huo4",
+ "zh_ji1",
+ "zh_ji2",
+ "zh_ji3",
+ "zh_ji4",
+ "zh_jia1",
+ "zh_jia2",
+ "zh_jia3",
+ "zh_jia4",
+ "zh_jian1",
+ "zh_jian3",
+ "zh_jian4",
+ "zh_jiang1",
+ "zh_jiang3",
+ "zh_jiang4",
+ "zh_jiao1",
+ "zh_jiao2",
+ "zh_jiao3",
+ "zh_jiao4",
+ "zh_jie1",
+ "zh_jie2",
+ "zh_jie3",
+ "zh_jie4",
+ "zh_jin1",
+ "zh_jin3",
+ "zh_jin4",
+ "zh_jing1",
+ "zh_jing3",
+ "zh_jing4",
+ "zh_jiong3",
+ "zh_jiu1",
+ "zh_jiu3",
+ "zh_jiu4",
+ "zh_ju1",
+ "zh_ju2",
+ "zh_ju3",
+ "zh_ju4",
+ "zh_juan1",
+ "zh_juan3",
+ "zh_juan4",
+ "zh_jue1",
+ "zh_jue2",
+ "zh_jun1",
+ "zh_jun4",
+ "zh_ka1",
+ "zh_ka3",
+ "zh_kai1",
+ "zh_kai3",
+ "zh_kai4",
+ "zh_kan1",
+ "zh_kan3",
+ "zh_kan4",
+ "zh_kang1",
+ "zh_kang2",
+ "zh_kang4",
+ "zh_kao3",
+ "zh_kao4",
+ "zh_ke1",
+ "zh_ke2",
+ "zh_ke3",
+ "zh_ke4",
+ "zh_ken3",
+ "zh_keng1",
+ "zh_kong1",
+ "zh_kong3",
+ "zh_kong4",
+ "zh_kou1",
+ "zh_kou3",
+ "zh_kou4",
+ "zh_ku1",
+ "zh_ku3",
+ "zh_ku4",
+ "zh_kua1",
+ "zh_kua3",
+ "zh_kua4",
+ "zh_kuai3",
+ "zh_kuai4",
+ "zh_kuan1",
+ "zh_kuan3",
+ "zh_kuang1",
+ "zh_kuang2",
+ "zh_kuang3",
+ "zh_kuang4",
+ "zh_kui1",
+ "zh_kui2",
+ "zh_kui3",
+ "zh_kui4",
+ "zh_kun1",
+ "zh_kun3",
+ "zh_kun4",
+ "zh_kuo4",
+ "zh_la1",
+ "zh_la2",
+ "zh_la3",
+ "zh_la4",
+ "zh_la5",
+ "zh_lai2",
+ "zh_lai4",
+ "zh_lan2",
+ "zh_lan3",
+ "zh_lan4",
+ "zh_lang1",
+ "zh_lang2",
+ "zh_lang3",
+ "zh_lang4",
+ "zh_lao1",
+ "zh_lao2",
+ "zh_lao3",
+ "zh_lao4",
+ "zh_le4",
+ "zh_le5",
+ "zh_lei2",
+ "zh_lei3",
+ "zh_lei4",
+ "zh_lei5",
+ "zh_leng2",
+ "zh_leng3",
+ "zh_leng4",
+ "zh_li2",
+ "zh_li3",
+ "zh_li4",
+ "zh_lia3",
+ "zh_lian2",
+ "zh_lian3",
+ "zh_lian4",
+ "zh_liang2",
+ "zh_liang3",
+ "zh_liang4",
+ "zh_liao1",
+ "zh_liao2",
+ "zh_liao3",
+ "zh_liao4",
+ "zh_lie1",
+ "zh_lie3",
+ "zh_lie4",
+ "zh_lin1",
+ "zh_lin2",
+ "zh_lin3",
+ "zh_lin4",
+ "zh_ling2",
+ "zh_ling3",
+ "zh_ling4",
+ "zh_liu1",
+ "zh_liu2",
+ "zh_liu3",
+ "zh_liu4",
+ "zh_long2",
+ "zh_long3",
+ "zh_long4",
+ "zh_lou2",
+ "zh_lou3",
+ "zh_lou4",
+ "zh_lu1",
+ "zh_lu2",
+ "zh_lu3",
+ "zh_lu4",
+ "zh_lu:2",
+ "zh_lu:3",
+ "zh_lu:4",
+ "zh_lu:e4",
+ "zh_luan2",
+ "zh_luan3",
+ "zh_luan4",
+ "zh_lun2",
+ "zh_lun4",
+ "zh_luo1",
+ "zh_luo2",
+ "zh_luo3",
+ "zh_luo4",
+ "zh_ma1",
+ "zh_ma2",
+ "zh_ma3",
+ "zh_ma4",
+ "zh_ma5",
+ "zh_mai2",
+ "zh_mai3",
+ "zh_mai4",
+ "zh_man2",
+ "zh_man3",
+ "zh_man4",
+ "zh_mang2",
+ "zh_mang3",
+ "zh_mao1",
+ "zh_mao2",
+ "zh_mao3",
+ "zh_mao4",
+ "zh_me5",
+ "zh_mei2",
+ "zh_mei3",
+ "zh_mei4",
+ "zh_men1",
+ "zh_men2",
+ "zh_men4",
+ "zh_men5",
+ "zh_meng2",
+ "zh_meng3",
+ "zh_meng4",
+ "zh_mi1",
+ "zh_mi2",
+ "zh_mi3",
+ "zh_mi4",
+ "zh_mian2",
+ "zh_mian3",
+ "zh_mian4",
+ "zh_miao1",
+ "zh_miao2",
+ "zh_miao3",
+ "zh_miao4",
+ "zh_mie1",
+ "zh_mie4",
+ "zh_min2",
+ "zh_min3",
+ "zh_ming2",
+ "zh_ming3",
+ "zh_ming4",
+ "zh_miu4",
+ "zh_mo1",
+ "zh_mo2",
+ "zh_mo3",
+ "zh_mo4",
+ "zh_mou2",
+ "zh_mou3",
+ "zh_mu2",
+ "zh_mu3",
+ "zh_mu4",
+ "zh_na2",
+ "zh_na3",
+ "zh_na4",
+ "zh_na5",
+ "zh_nai3",
+ "zh_nai4",
+ "zh_nan1",
+ "zh_nan2",
+ "zh_nan3",
+ "zh_nan4",
+ "zh_nang1",
+ "zh_nang2",
+ "zh_nao1",
+ "zh_nao2",
+ "zh_nao3",
+ "zh_nao4",
+ "zh_ne4",
+ "zh_ne5",
+ "zh_nei3",
+ "zh_nei4",
+ "zh_nen4",
+ "zh_neng2",
+ "zh_ni1",
+ "zh_ni2",
+ "zh_ni3",
+ "zh_ni4",
+ "zh_nian1",
+ "zh_nian2",
+ "zh_nian3",
+ "zh_nian4",
+ "zh_niang2",
+ "zh_niang4",
+ "zh_niao3",
+ "zh_niao4",
+ "zh_nie1",
+ "zh_nie4",
+ "zh_nin2",
+ "zh_ning2",
+ "zh_ning3",
+ "zh_ning4",
+ "zh_niu1",
+ "zh_niu2",
+ "zh_niu3",
+ "zh_niu4",
+ "zh_nong2",
+ "zh_nong4",
+ "zh_nu2",
+ "zh_nu3",
+ "zh_nu4",
+ "zh_nu:3",
+ "zh_nu:e4",
+ "zh_nuan3",
+ "zh_nuo2",
+ "zh_nuo4",
+ "zh_o1",
+ "zh_o4",
+ "zh_o5",
+ "zh_ou1",
+ "zh_ou3",
+ "zh_ou4",
+ "zh_pa1",
+ "zh_pa2",
+ "zh_pa4",
+ "zh_pai1",
+ "zh_pai2",
+ "zh_pai4",
+ "zh_pan1",
+ "zh_pan2",
+ "zh_pan4",
+ "zh_pang1",
+ "zh_pang2",
+ "zh_pang4",
+ "zh_pao1",
+ "zh_pao2",
+ "zh_pao3",
+ "zh_pao4",
+ "zh_pei1",
+ "zh_pei2",
+ "zh_pei4",
+ "zh_pen1",
+ "zh_pen2",
+ "zh_peng1",
+ "zh_peng2",
+ "zh_peng3",
+ "zh_peng4",
+ "zh_pi1",
+ "zh_pi2",
+ "zh_pi3",
+ "zh_pi4",
+ "zh_pian1",
+ "zh_pian2",
+ "zh_pian3",
+ "zh_pian4",
+ "zh_piao1",
+ "zh_piao2",
+ "zh_piao3",
+ "zh_piao4",
+ "zh_pie1",
+ "zh_pie3",
+ "zh_pin1",
+ "zh_pin2",
+ "zh_pin3",
+ "zh_pin4",
+ "zh_ping1",
+ "zh_ping2",
+ "zh_po1",
+ "zh_po2",
+ "zh_po3",
+ "zh_po4",
+ "zh_pou1",
+ "zh_pou2",
+ "zh_pu1",
+ "zh_pu2",
+ "zh_pu3",
+ "zh_pu4",
+ "zh_qi1",
+ "zh_qi2",
+ "zh_qi3",
+ "zh_qi4",
+ "zh_qia1",
+ "zh_qia3",
+ "zh_qia4",
+ "zh_qian1",
+ "zh_qian2",
+ "zh_qian3",
+ "zh_qian4",
+ "zh_qiang1",
+ "zh_qiang2",
+ "zh_qiang3",
+ "zh_qiang4",
+ "zh_qiao1",
+ "zh_qiao2",
+ "zh_qiao3",
+ "zh_qiao4",
+ "zh_qie1",
+ "zh_qie2",
+ "zh_qie3",
+ "zh_qie4",
+ "zh_qin1",
+ "zh_qin2",
+ "zh_qin3",
+ "zh_qin4",
+ "zh_qing1",
+ "zh_qing2",
+ "zh_qing3",
+ "zh_qing4",
+ "zh_qiong2",
+ "zh_qiu1",
+ "zh_qiu2",
+ "zh_qiu3",
+ "zh_qu1",
+ "zh_qu2",
+ "zh_qu3",
+ "zh_qu4",
+ "zh_quan1",
+ "zh_quan2",
+ "zh_quan3",
+ "zh_quan4",
+ "zh_que1",
+ "zh_que2",
+ "zh_que4",
+ "zh_qun1",
+ "zh_qun2",
+ "zh_r5",
+ "zh_ran2",
+ "zh_ran3",
+ "zh_rang2",
+ "zh_rang3",
+ "zh_rang4",
+ "zh_rao2",
+ "zh_rao3",
+ "zh_rao4",
+ "zh_re3",
+ "zh_re4",
+ "zh_ren2",
+ "zh_ren3",
+ "zh_ren4",
+ "zh_reng1",
+ "zh_reng2",
+ "zh_ri4",
+ "zh_rong2",
+ "zh_rong3",
+ "zh_rou2",
+ "zh_rou4",
+ "zh_ru2",
+ "zh_ru3",
+ "zh_ru4",
+ "zh_ruan3",
+ "zh_rui3",
+ "zh_rui4",
+ "zh_run4",
+ "zh_ruo2",
+ "zh_ruo4",
+ "zh_sa1",
+ "zh_sa3",
+ "zh_sa4",
+ "zh_sai1",
+ "zh_sai4",
+ "zh_san1",
+ "zh_san3",
+ "zh_san4",
+ "zh_sang1",
+ "zh_sang3",
+ "zh_sang4",
+ "zh_sao1",
+ "zh_sao3",
+ "zh_sao4",
+ "zh_se4",
+ "zh_sen1",
+ "zh_seng1",
+ "zh_sha1",
+ "zh_sha2",
+ "zh_sha3",
+ "zh_sha4",
+ "zh_shai1",
+ "zh_shai4",
+ "zh_shan1",
+ "zh_shan3",
+ "zh_shan4",
+ "zh_shang1",
+ "zh_shang3",
+ "zh_shang4",
+ "zh_shao1",
+ "zh_shao2",
+ "zh_shao3",
+ "zh_shao4",
+ "zh_she1",
+ "zh_she2",
+ "zh_she3",
+ "zh_she4",
+ "zh_shei2",
+ "zh_shen1",
+ "zh_shen2",
+ "zh_shen3",
+ "zh_shen4",
+ "zh_sheng1",
+ "zh_sheng2",
+ "zh_sheng3",
+ "zh_sheng4",
+ "zh_shi1",
+ "zh_shi2",
+ "zh_shi3",
+ "zh_shi4",
+ "zh_shou1",
+ "zh_shou3",
+ "zh_shou4",
+ "zh_shu1",
+ "zh_shu2",
+ "zh_shu3",
+ "zh_shu4",
+ "zh_shua1",
+ "zh_shua3",
+ "zh_shuai1",
+ "zh_shuai3",
+ "zh_shuai4",
+ "zh_shuan1",
+ "zh_shuan4",
+ "zh_shuang1",
+ "zh_shuang3",
+ "zh_shui3",
+ "zh_shui4",
+ "zh_shun3",
+ "zh_shun4",
+ "zh_shuo1",
+ "zh_shuo4",
+ "zh_si1",
+ "zh_si3",
+ "zh_si4",
+ "zh_song1",
+ "zh_song3",
+ "zh_song4",
+ "zh_sou1",
+ "zh_sou3",
+ "zh_sou4",
+ "zh_su1",
+ "zh_su2",
+ "zh_su4",
+ "zh_suan1",
+ "zh_suan4",
+ "zh_sui1",
+ "zh_sui2",
+ "zh_sui3",
+ "zh_sui4",
+ "zh_sun1",
+ "zh_sun3",
+ "zh_suo1",
+ "zh_suo3",
+ "zh_ta1",
+ "zh_ta3",
+ "zh_ta4",
+ "zh_tai1",
+ "zh_tai2",
+ "zh_tai4",
+ "zh_tan1",
+ "zh_tan2",
+ "zh_tan3",
+ "zh_tan4",
+ "zh_tang1",
+ "zh_tang2",
+ "zh_tang3",
+ "zh_tang4",
+ "zh_tao1",
+ "zh_tao2",
+ "zh_tao3",
+ "zh_tao4",
+ "zh_te4",
+ "zh_teng2",
+ "zh_ti1",
+ "zh_ti2",
+ "zh_ti3",
+ "zh_ti4",
+ "zh_tian1",
+ "zh_tian2",
+ "zh_tian3",
+ "zh_tiao1",
+ "zh_tiao2",
+ "zh_tiao3",
+ "zh_tiao4",
+ "zh_tie1",
+ "zh_tie3",
+ "zh_tie4",
+ "zh_ting1",
+ "zh_ting2",
+ "zh_ting3",
+ "zh_tong1",
+ "zh_tong2",
+ "zh_tong3",
+ "zh_tong4",
+ "zh_tou1",
+ "zh_tou2",
+ "zh_tou4",
+ "zh_tu1",
+ "zh_tu2",
+ "zh_tu3",
+ "zh_tu4",
+ "zh_tuan1",
+ "zh_tuan2",
+ "zh_tui1",
+ "zh_tui2",
+ "zh_tui3",
+ "zh_tui4",
+ "zh_tun1",
+ "zh_tun2",
+ "zh_tun4",
+ "zh_tuo1",
+ "zh_tuo2",
+ "zh_tuo3",
+ "zh_tuo4",
+ "zh_wa1",
+ "zh_wa2",
+ "zh_wa3",
+ "zh_wa4",
+ "zh_wai1",
+ "zh_wai4",
+ "zh_wan1",
+ "zh_wan2",
+ "zh_wan3",
+ "zh_wan4",
+ "zh_wang1",
+ "zh_wang2",
+ "zh_wang3",
+ "zh_wang4",
+ "zh_wei1",
+ "zh_wei2",
+ "zh_wei3",
+ "zh_wei4",
+ "zh_wen1",
+ "zh_wen2",
+ "zh_wen3",
+ "zh_wen4",
+ "zh_weng1",
+ "zh_weng4",
+ "zh_wo1",
+ "zh_wo3",
+ "zh_wo4",
+ "zh_wo5",
+ "zh_wu1",
+ "zh_wu2",
+ "zh_wu3",
+ "zh_wu4",
+ "zh_xi1",
+ "zh_xi2",
+ "zh_xi3",
+ "zh_xi4",
+ "zh_xia1",
+ "zh_xia2",
+ "zh_xia4",
+ "zh_xian1",
+ "zh_xian2",
+ "zh_xian3",
+ "zh_xian4",
+ "zh_xiang1",
+ "zh_xiang2",
+ "zh_xiang3",
+ "zh_xiang4",
+ "zh_xiao1",
+ "zh_xiao2",
+ "zh_xiao3",
+ "zh_xiao4",
+ "zh_xie1",
+ "zh_xie2",
+ "zh_xie3",
+ "zh_xie4",
+ "zh_xin1",
+ "zh_xin4",
+ "zh_xing1",
+ "zh_xing2",
+ "zh_xing3",
+ "zh_xing4",
+ "zh_xiong1",
+ "zh_xiong2",
+ "zh_xiu1",
+ "zh_xiu3",
+ "zh_xiu4",
+ "zh_xu1",
+ "zh_xu2",
+ "zh_xu3",
+ "zh_xu4",
+ "zh_xuan1",
+ "zh_xuan2",
+ "zh_xuan3",
+ "zh_xuan4",
+ "zh_xue1",
+ "zh_xue2",
+ "zh_xue3",
+ "zh_xue4",
+ "zh_xun1",
+ "zh_xun2",
+ "zh_xun4",
+ "zh_ya1",
+ "zh_ya2",
+ "zh_ya3",
+ "zh_ya4",
+ "zh_ya5",
+ "zh_yan1",
+ "zh_yan2",
+ "zh_yan3",
+ "zh_yan4",
+ "zh_yang1",
+ "zh_yang2",
+ "zh_yang3",
+ "zh_yang4",
+ "zh_yao1",
+ "zh_yao2",
+ "zh_yao3",
+ "zh_yao4",
+ "zh_ye1",
+ "zh_ye2",
+ "zh_ye3",
+ "zh_ye4",
+ "zh_yi1",
+ "zh_yi2",
+ "zh_yi3",
+ "zh_yi4",
+ "zh_yin1",
+ "zh_yin2",
+ "zh_yin3",
+ "zh_yin4",
+ "zh_ying1",
+ "zh_ying2",
+ "zh_ying3",
+ "zh_ying4",
+ "zh_yo1",
+ "zh_yo5",
+ "zh_yong1",
+ "zh_yong3",
+ "zh_yong4",
+ "zh_you1",
+ "zh_you2",
+ "zh_you3",
+ "zh_you4",
+ "zh_yu1",
+ "zh_yu2",
+ "zh_yu3",
+ "zh_yu4",
+ "zh_yuan1",
+ "zh_yuan2",
+ "zh_yuan3",
+ "zh_yuan4",
+ "zh_yue1",
+ "zh_yue4",
+ "zh_yun1",
+ "zh_yun2",
+ "zh_yun3",
+ "zh_yun4",
+ "zh_za1",
+ "zh_za2",
+ "zh_za3",
+ "zh_zai1",
+ "zh_zai3",
+ "zh_zai4",
+ "zh_zan1",
+ "zh_zan2",
+ "zh_zan3",
+ "zh_zan4",
+ "zh_zang1",
+ "zh_zang4",
+ "zh_zao1",
+ "zh_zao2",
+ "zh_zao3",
+ "zh_zao4",
+ "zh_ze2",
+ "zh_ze4",
+ "zh_zei2",
+ "zh_zen3",
+ "zh_zeng1",
+ "zh_zeng4",
+ "zh_zha1",
+ "zh_zha2",
+ "zh_zha3",
+ "zh_zha4",
+ "zh_zhai1",
+ "zh_zhai2",
+ "zh_zhai3",
+ "zh_zhai4",
+ "zh_zhan1",
+ "zh_zhan3",
+ "zh_zhan4",
+ "zh_zhang1",
+ "zh_zhang3",
+ "zh_zhang4",
+ "zh_zhao1",
+ "zh_zhao2",
+ "zh_zhao3",
+ "zh_zhao4",
+ "zh_zhe1",
+ "zh_zhe2",
+ "zh_zhe3",
+ "zh_zhe4",
+ "zh_zhe5",
+ "zh_zhen1",
+ "zh_zhen3",
+ "zh_zhen4",
+ "zh_zheng1",
+ "zh_zheng3",
+ "zh_zheng4",
+ "zh_zhi1",
+ "zh_zhi2",
+ "zh_zhi3",
+ "zh_zhi4",
+ "zh_zhong1",
+ "zh_zhong3",
+ "zh_zhong4",
+ "zh_zhou1",
+ "zh_zhou2",
+ "zh_zhou3",
+ "zh_zhou4",
+ "zh_zhu1",
+ "zh_zhu2",
+ "zh_zhu3",
+ "zh_zhu4",
+ "zh_zhua1",
+ "zh_zhua3",
+ "zh_zhuai3",
+ "zh_zhuai4",
+ "zh_zhuan1",
+ "zh_zhuan3",
+ "zh_zhuan4",
+ "zh_zhuang1",
+ "zh_zhuang4",
+ "zh_zhui1",
+ "zh_zhui4",
+ "zh_zhun1",
+ "zh_zhun3",
+ "zh_zhuo1",
+ "zh_zhuo2",
+ "zh_zi1",
+ "zh_zi3",
+ "zh_zi4",
+ "zh_zi5",
+ "zh_zong1",
+ "zh_zong3",
+ "zh_zong4",
+ "zh_zou1",
+ "zh_zou3",
+ "zh_zou4",
+ "zh_zu1",
+ "zh_zu2",
+ "zh_zu3",
+ "zh_zuan1",
+ "zh_zuan3",
+ "zh_zuan4",
+ "zh_zui3",
+ "zh_zui4",
+ "zh_zun1",
+ "zh_zuo2",
+ "zh_zuo3",
+ "zh_zuo4",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ ""
+]
\ No newline at end of file
diff --git a/soulxsinger/utils/pitch_utils.py b/soulxsinger/utils/pitch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee4d772d1d542a11d96c59d763f6cdda3d4c2672
--- /dev/null
+++ b/soulxsinger/utils/pitch_utils.py
@@ -0,0 +1,142 @@
+# https://github.com/gwx314/TechSinger/blob/main/utils/audio/pitch/utils.py
+
+import numpy as np
+import torch
+
+
+def to_lf0(f0):
+ f0[f0 < 1.0e-5] = 1.0e-6
+ lf0 = f0.log() if isinstance(f0, torch.Tensor) else np.log(f0)
+ lf0[f0 < 1.0e-5] = - 1.0E+10
+ return lf0
+
+
+def to_f0(lf0):
+ f0 = np.where(lf0 <= 0, 0.0, np.exp(lf0))
+ return f0.flatten()
+
+
+def f0_to_coarse_mel(f0, f0_bin=256, f0_max=900.0, f0_min=50.0, f0_shift=0):
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
+ is_torch = isinstance(f0, torch.Tensor)
+ f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
+ f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
+
+ f0_mel[f0_mel <= 1] = 1
+ f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
+ f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(int)
+
+ if f0_shift != 0:
+ if f0_shift > 0:
+ f0_shift = min(f0_shift, f0_bin - 1 - f0_coarse[f0_coarse > 1].max().item())
+ else:
+ f0_shift = max(f0_shift, 1 - f0_coarse[f0_coarse > 1].min().item())
+
+ f0_coarse[f0_coarse > 1] = f0_coarse[f0_coarse > 1] + f0_shift
+
+ assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min(), f0.min(), f0.max())
+ return f0_coarse
+
+
+def coarse_to_f0_mel(f0_coarse, f0_bin=256, f0_max=900.0, f0_min=50.0):
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
+ uv = f0_coarse == 1
+ f0 = f0_mel_min + (f0_coarse - 1) * (f0_mel_max - f0_mel_min) / (f0_bin - 2)
+ f0 = ((f0 / 1127).exp() - 1) * 700
+ f0[uv] = 0
+ return f0
+
+CONST_C1_FREQ = 32.7031956625 # C1 frequency in Hz
+CONST_B6_FREQ = 1975.53320502 # B6 frequency in Hz
+
+def f0_to_coarse_midi(f0, f0_bin=361, f0_max=CONST_B6_FREQ, f0_min=CONST_C1_FREQ, f0_shift=0):
+ is_torch = isinstance(f0, torch.Tensor)
+ uv_mask = f0 <= 0
+
+ if is_torch:
+ f0_safe = torch.maximum(f0, torch.tensor(f0_min))
+ f0_cents = 1200 * torch.log2(f0_safe / f0_min)
+ else:
+ f0_safe = np.maximum(f0, f0_min)
+ f0_cents = 1200 * np.log2(f0_safe / f0_min)
+
+ f0_coarse = (f0_cents / 20) + 1
+
+ if is_torch:
+ f0_coarse = torch.round(f0_coarse).long()
+ f0_coarse = torch.clamp(f0_coarse, min=1, max=f0_bin - 1)
+ else:
+ f0_coarse = np.rint(f0_coarse).astype(int)
+ f0_coarse = np.clip(f0_coarse, 1, f0_bin - 1)
+
+ f0_coarse[uv_mask] = 0
+
+ if f0_shift != 0:
+ if is_torch:
+ voiced = f0_coarse > 0
+ if voiced.any():
+ shifted = f0_coarse[voiced] + f0_shift
+ f0_coarse[voiced] = torch.clamp(shifted, 1, f0_bin - 1)
+ else:
+ voiced = f0_coarse > 0
+ if np.any(voiced):
+ shifted = f0_coarse[voiced] + f0_shift
+ f0_coarse[voiced] = np.clip(shifted, 1, f0_bin - 1)
+
+ return f0_coarse
+
+
+def coarse_to_f0_midi(f0_coarse, f0_bin=361, f0_max=CONST_B6_FREQ, f0_min=CONST_C1_FREQ):
+
+ uv_mask = f0_coarse == 0
+ cents = (f0_coarse - 1) * 20
+ f0 = f0_min * (2 ** (cents / 1200))
+ f0[uv_mask] = 0
+
+ return f0
+
+
+def norm_f0(f0, uv, pitch_norm='log', f0_mean=400, f0_std=100):
+ is_torch = isinstance(f0, torch.Tensor)
+ if pitch_norm == 'standard':
+ f0 = (f0 - f0_mean) / f0_std
+ if pitch_norm == 'log':
+ f0 = torch.log2(f0 + 1e-8) if is_torch else np.log2(f0 + 1e-8)
+ if uv is not None:
+ f0[uv > 0] = 0
+ return f0
+
+
+def norm_interp_f0(f0, pitch_norm='log', f0_mean=None, f0_std=None):
+ is_torch = isinstance(f0, torch.Tensor)
+ if is_torch:
+ device = f0.device
+ f0 = f0.data.cpu().numpy()
+ uv = f0 == 0
+ f0 = norm_f0(f0, uv, pitch_norm, f0_mean, f0_std)
+ if sum(uv) == len(f0):
+ f0[uv] = 0
+ elif sum(uv) > 0:
+ f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
+ if is_torch:
+ uv = torch.FloatTensor(uv)
+ f0 = torch.FloatTensor(f0)
+ f0 = f0.to(device)
+ uv = uv.to(device)
+ return f0, uv
+
+
+def denorm_f0(f0, uv, pitch_norm='log', f0_mean=400, f0_std=100, pitch_padding=None, min=50, max=900):
+ is_torch = isinstance(f0, torch.Tensor)
+ if pitch_norm == 'standard':
+ f0 = f0 * f0_std + f0_mean
+ if pitch_norm == 'log':
+ f0 = 2 ** f0
+ f0 = f0.clamp(min=min, max=max) if is_torch else np.clip(f0, a_min=min, a_max=max)
+ if uv is not None:
+ f0[uv > 0] = 0
+ if pitch_padding is not None:
+ f0[pitch_padding] = 0
+ return f0
diff --git a/webui.py b/webui.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a0b79641a717e3ffcda6341f6c42e7f1cd8c1da
--- /dev/null
+++ b/webui.py
@@ -0,0 +1,789 @@
+import os
+import re
+import random
+import shutil
+import sys
+import traceback
+from pathlib import Path
+from typing import Literal, Tuple
+
+import numpy as np
+import torch
+import librosa
+import soundfile as sf
+import gradio as gr
+
+from preprocess.pipeline import PreprocessPipeline
+from soulxsinger.utils.file_utils import load_config
+from cli.inference import build_model as build_svs_model, process as svs_process
+
+
+ROOT = Path(__file__).parent
+
+ENGLISH_EXAMPLE_PROMPT_AUDIO = "example/audio/en_prompt.mp3"
+ENGLISH_EXAMPLE_PROMPT_META = "example/audio/en_prompt.json"
+ENGLISH_EXAMPLE_TARGET_AUDIO = "example/audio/en_target.mp3"
+ENGLISH_EXAMPLE_TARGET_META = "example/audio/en_target.json"
+
+MANDARIN_EXAMPLE_PROMPT_AUDIO = "example/audio/zh_prompt.mp3"
+MANDARIN_EXAMPLE_PROMPT_META = "example/audio/zh_prompt.json"
+MANDARIN_EXAMPLE_TARGET_AUDIO = "example/audio/zh_target.mp3"
+MANDARIN_EXAMPLE_TARGET_META = "example/audio/zh_target.json"
+
+CANTONESE_EXAMPLE_PROMPT_AUDIO = "example/audio/yue_prompt.mp3"
+CANTONESE_EXAMPLE_PROMPT_META = "example/audio/yue_prompt.json"
+CANTONESE_EXAMPLE_TARGET_AUDIO = "example/audio/yue_target.mp3"
+CANTONESE_EXAMPLE_TARGET_META = "example/audio/yue_target.json"
+
+MUSIC_EXAMPLE_TARGET_AUDIO = "example/audio/music.mp3"
+MUSIC_EXAMPLE_TARGET_META = "example/audio/music.json"
+
+# Lyric language: value (Mandarin/Cantonese/English) is passed to PreprocessPipeline; display labels from i18n via get_lyric_lang_choices()
+
+# Use absolute paths so Examples load correctly (including File components for metadata)
+EXAMPLES_LIST = [
+ [
+ str(ROOT / MANDARIN_EXAMPLE_PROMPT_AUDIO),
+ str(ROOT / MANDARIN_EXAMPLE_TARGET_AUDIO),
+ str(ROOT / MANDARIN_EXAMPLE_PROMPT_META),
+ str(ROOT / MANDARIN_EXAMPLE_TARGET_META),
+ "Mandarin",
+ "Mandarin",
+ "melody",
+ False,
+ True,
+ True,
+ 0,
+ ],
+ [
+ str(ROOT / MANDARIN_EXAMPLE_PROMPT_AUDIO),
+ str(ROOT / CANTONESE_EXAMPLE_TARGET_AUDIO),
+ str(ROOT / MANDARIN_EXAMPLE_PROMPT_META),
+ str(ROOT / CANTONESE_EXAMPLE_TARGET_META),
+ "Mandarin",
+ "Cantonese",
+ "melody",
+ False,
+ True,
+ True,
+ 0,
+ ],
+ [
+ str(ROOT / MANDARIN_EXAMPLE_PROMPT_AUDIO),
+ str(ROOT / ENGLISH_EXAMPLE_TARGET_AUDIO),
+ str(ROOT / MANDARIN_EXAMPLE_PROMPT_META),
+ str(ROOT / ENGLISH_EXAMPLE_TARGET_META),
+ "Mandarin",
+ "English",
+ "melody",
+ False,
+ True,
+ True,
+ 0,
+ ],
+ [
+ str(ROOT / MANDARIN_EXAMPLE_PROMPT_AUDIO),
+ str(ROOT / MUSIC_EXAMPLE_TARGET_AUDIO),
+ str(ROOT / MANDARIN_EXAMPLE_PROMPT_META),
+ str(ROOT / MUSIC_EXAMPLE_TARGET_META),
+ "Mandarin",
+ "Mandarin",
+ "score",
+ False,
+ True,
+ True,
+ 0,
+ ],
+]
+
+
+def _load_example(choice_value):
+ """Return 11 example values + skip_clear_count (2 when loading example so next 2 audio.change events don't clear metadata).
+ choice_value: selected dropdown string (or index in older flow); map to example index 0/1/2."""
+ if choice_value is None:
+ return [gr.update()] * 11 + [0]
+ idx = 0
+ if isinstance(choice_value, int):
+ idx = 0 if choice_value <= 0 else min(choice_value - 1, len(EXAMPLES_LIST) - 1)
+ else:
+ if choice_value == i18n("example_choice_1"):
+ idx = 1
+ elif choice_value == i18n("example_choice_2"):
+ idx = 2
+ elif choice_value == i18n("example_choice_3"):
+ idx = 3
+ elif choice_value == i18n("example_choice_4"):
+ idx = 4
+ if idx <= 0:
+ return [gr.update()] * 11 + [0]
+ list_idx = idx - 1
+ if list_idx >= len(EXAMPLES_LIST):
+ return [gr.update()] * 11 + [0]
+ row = EXAMPLES_LIST[list_idx]
+ return [
+ row[0], row[1], row[2], row[3], row[4], row[5], row[6], row[7], row[8], row[9], row[10],
+ 2, # skip_clear_metadata_count: next 2 audio.change events (prompt + target) will not clear metadata
+ ]
+
+
+def _clear_prompt_meta_unless_example(_audio, skip_count):
+ if skip_count and skip_count > 0:
+ return gr.skip(), max(0, skip_count - 1)
+ return None, 0
+
+
+def _clear_target_meta_unless_example(_audio, skip_count):
+ if skip_count and skip_count > 0:
+ return gr.skip(), max(0, skip_count - 1)
+ return None, 0
+
+
+def _get_device() -> str:
+ """Use CUDA if available, else CPU (e.g. for CI or CPU-only environments)."""
+ return "cuda:0" if torch.cuda.is_available() else "cpu"
+
+
+def _session_dir_from_target(target_audio_path: str) -> Path:
+ stem = Path(target_audio_path).stem
+ safe = re.sub(r"[^\w\-]", "_", stem)
+ safe = re.sub(r"_+", "_", safe).strip("_") or "session"
+ return ROOT / "outputs" / "gradio" / safe[:64]
+
+
+class AppState:
+ def __init__(self) -> None:
+ self.device = _get_device()
+ self.preprocess_pipeline = PreprocessPipeline(
+ device=self.device,
+ language="Mandarin",
+ save_dir=str(ROOT / "outputs" / "gradio" / "_placeholder" / "transcriptions"),
+ vocal_sep=True,
+ max_merge_duration=60000,
+ )
+ config = load_config("soulxsinger/config/soulxsinger.yaml")
+ self.svs_config = config
+ self.svs_model = build_svs_model(
+ model_path="pretrained_models/SoulX-Singer/model.pt",
+ config=config,
+ device=self.device,
+ )
+ self.phoneset_path = "soulxsinger/utils/phoneme/phone_set.json"
+
+ def run_preprocess(
+ self,
+ prompt_path: Path,
+ target_path: Path,
+ session_base: Path,
+ prompt_vocal_sep: bool,
+ target_vocal_sep: bool,
+ prompt_lyric_lang: str,
+ target_lyric_lang: str,
+ ) -> Tuple[bool, str]:
+ try:
+ self.preprocess_pipeline.save_dir = str(session_base / "transcriptions" / "prompt")
+ self.preprocess_pipeline.run(
+ audio_path=str(prompt_path),
+ vocal_sep=prompt_vocal_sep,
+ max_merge_duration=20000,
+ language=prompt_lyric_lang or "Mandarin",
+ )
+ self.preprocess_pipeline.save_dir = str(session_base / "transcriptions" / "target")
+ self.preprocess_pipeline.run(
+ audio_path=str(target_path),
+ vocal_sep=target_vocal_sep,
+ max_merge_duration=60000,
+ language=target_lyric_lang or "Mandarin",
+ )
+ return True, "preprocess done"
+ except Exception as e:
+ return False, f"preprocess failed: {e}"
+
+ def run_svs(
+ self,
+ control: str,
+ session_base: Path,
+ auto_shift: bool,
+ pitch_shift: int,
+ ) -> Tuple[bool, str, Path | None, Path | None, Path | None]:
+ if control not in ("melody", "score"):
+ control = "score"
+ save_dir = session_base / "generated"
+ save_dir.mkdir(parents=True, exist_ok=True)
+ class Args:
+ pass
+ args = Args()
+ args.device = self.device
+ args.model_path = "pretrained_models/SoulX-Singer/model.pt"
+ args.config = "soulxsinger/config/soulxsinger.yaml"
+ args.prompt_wav_path = str(session_base / "audio" / "prompt.wav")
+ prompt_meta_path = session_base / "transcriptions" / "prompt" / "metadata.json"
+ target_meta_path = session_base / "transcriptions" / "target" / "metadata.json"
+ args.prompt_metadata_path = str(prompt_meta_path)
+ args.target_metadata_path = str(target_meta_path)
+ args.phoneset_path = self.phoneset_path
+ args.save_dir = str(save_dir)
+ args.auto_shift = auto_shift
+ args.pitch_shift = int(pitch_shift)
+ args.control = control
+ try:
+ svs_process(args, self.svs_config, self.svs_model)
+ generated = save_dir / "generated.wav"
+ if not generated.exists():
+ return False, f"inference finished but {generated} not found", None, prompt_meta_path, target_meta_path
+ return True, "svs inference done", generated, prompt_meta_path, target_meta_path
+ except Exception as e:
+ return False, f"svs inference failed: {e}", None, prompt_meta_path, target_meta_path
+
+ def run_svs_from_paths(
+ self,
+ prompt_wav_path: str,
+ prompt_metadata_path: str,
+ target_metadata_path: str,
+ control: str,
+ auto_shift: bool,
+ pitch_shift: int,
+ save_dir: Path | None = None,
+ ) -> Tuple[bool, str, Path | None]:
+ """Run SVS from explicit prompt wav and metadata paths."""
+ if save_dir is None:
+ import uuid
+ save_dir = ROOT / "outputs" / "gradio" / "synthesis" / str(uuid.uuid4())[:8]
+ save_dir = Path(save_dir)
+ audio_dir = save_dir / "audio"
+ prompt_meta_dir = save_dir / "transcriptions" / "prompt"
+ target_meta_dir = save_dir / "transcriptions" / "target"
+ audio_dir.mkdir(parents=True, exist_ok=True)
+ prompt_meta_dir.mkdir(parents=True, exist_ok=True)
+ target_meta_dir.mkdir(parents=True, exist_ok=True)
+ shutil.copy2(prompt_wav_path, audio_dir / "prompt.wav")
+ shutil.copy2(prompt_metadata_path, prompt_meta_dir / "metadata.json")
+ shutil.copy2(target_metadata_path, target_meta_dir / "metadata.json")
+ ok, msg, merged, _, _ = self.run_svs(
+ control=control,
+ session_base=save_dir,
+ auto_shift=auto_shift,
+ pitch_shift=pitch_shift,
+ )
+ if not ok or merged is None:
+ return False, msg or "svs failed", None
+ return True, "svs inference done", merged
+
+
+APP_STATE = AppState()
+
+
+# i18n
+_i18n_key2lang_dict = dict(
+ display_lang_label=dict(en="Display Language", zh="显示语言"),
+ seed_label=dict(en="Seed", zh="种子"),
+ prompt_audio_label=dict(en="Prompt audio (reference voice), limit to 30 seconds", zh="Prompt 音频(参考音色),限制在 30 秒以内"),
+ target_audio_label=dict(en="Target audio (melody / lyrics source), limit to 60 seconds", zh="Target 音频(旋律/歌词来源),限制在 60 秒以内"),
+ generate_btn_label=dict(en="Start SVS", zh="开始 SVS"),
+ transcription_btn_label=dict(en="Run singing transcription", zh="开始歌声转录"),
+ synthesis_btn_label=dict(en="Run singing synthesis", zh="歌声合成"),
+ prompt_meta_label=dict(en="Prompt metadata", zh="Prompt metadata"),
+ target_meta_label=dict(en="Target metadata", zh="Target metadata"),
+ edit_tutorial_html=dict(
+ en='Refer to Edit Tutorial for metadata editing
',
+ zh='metadata 编辑请参考 编辑教程
',
+ ),
+ prompt_wav_label=dict(en="Prompt WAV (reference)", zh="Prompt WAV(参考音色)"),
+ generated_audio_label=dict(en="Generated merged audio", zh="合成结果音频"),
+ prompt_lyric_lang_label=dict(en="Prompt lyric language", zh="Prompt 歌词语种"),
+ target_lyric_lang_label=dict(en="Target lyric language", zh="Target 歌词语种"),
+ lyric_lang_mandarin=dict(en="Mandarin", zh="普通话"),
+ lyric_lang_cantonese=dict(en="Cantonese", zh="粤语"),
+ lyric_lang_english=dict(en="English", zh="英语"),
+ warn_missing_synthesis=dict(en="Please provide prompt WAV, prompt metadata, and target metadata", zh="请提供 Prompt WAV、Prompt metadata 与 Target metadata"),
+ prompt_vocal_sep_label=dict(en="Prompt vocal separation", zh="Prompt人声分离"),
+ target_vocal_sep_label=dict(en="Target vocal separation", zh="Target人声分离"),
+ auto_shift_label=dict(en="Auto pitch shift", zh="自动变调"),
+ pitch_shift_label=dict(en="Pitch shift (semitones)", zh="指定变调(半音)"),
+ control_type_label=dict(en="Control type", zh="控制类型"),
+ examples_label=dict(en="Reference examples (click to load)", zh="参考样例(点击加载)"),
+ example_choice_0=dict(en="—", zh="—"),
+ example_choice_1=dict(en="Example 1: Mandarin → Mandarin (melody), Start singing synthesis!", zh="样例 1: 普通话 → 普通话 (melody), 开始歌声合成吧!"),
+ example_choice_2=dict(en="Example 2: Mandarin → Cantonese (melody), Start singing synthesis!", zh="样例 2: 普通话 → 粤语 (melody), 开始歌声合成吧!"),
+ example_choice_3=dict(en="Example 3: Mandarin → English (melody), Start singing synthesis!", zh="样例 3: 普通话 → 英语 (melody), 开始歌声合成吧!"),
+ example_choice_4=dict(en="Example 4: Mandarin → Music (score), Start singing synthesis!", zh="样例 4: 普通话 → 音乐 (score), 开始歌声合成吧!"),
+ warn_missing_audio=dict(
+ en="Please upload both prompt audio and target audio",
+ zh="请上传 Prompt 音频与 Target 音频",
+ ),
+ # Instruction panel (workflow description)
+ instruction_title=dict(en="Usage", zh="使用说明"),
+ instruction_p1=dict(
+ en="After uploading prompt and target audio and clicking **Run singing transcription**, the system generates two metadata files (prompt and target).",
+ zh="上传 Prompt 与 Target 音频并点击「开始歌声转录」后,将生成 Prompt 与 Target 两份 metadata 文件。",
+ ),
+ instruction_p2=dict(
+ en="Auto-transcribed lyrics and notes are often misaligned. For better results, import the generated metadata into the **MIDI Editor** for manual adjustment: [SoulX-Singer-Midi-Editor](https://huggingface.co/spaces/Soul-AILab/SoulX-Singer-Midi-Editor).",
+ zh="自动转录的歌词与音高对齐效果通常不理想,建议将生成的 metadata 导入 **MIDI 编辑器** 进行手动调整:[SoulX-Singer-Midi-Editor](https://huggingface.co/spaces/Soul-AILab/SoulX-Singer-Midi-Editor)。",
+ ),
+ instruction_p3=dict(
+ en="Re-upload the adjusted metadata to the corresponding Prompt / Target Meta fields, then click **Run singing synthesis** to generate the final audio.",
+ zh="将调整后的 metadata 重新上传至对应的 Prompt / Target Meta 位置后,点击「歌声合成」开始最终生成。",
+ ),
+)
+
+def _detect_initial_lang() -> Literal["zh", "en"]:
+ """Detect initial UI language from server locale (browser language applied later via JS)."""
+ try:
+ import locale
+ loc = (locale.getdefaultlocale()[0] or os.environ.get("LANG", "") or "").lower()
+ return "en" if loc.startswith("en") else "zh"
+ except Exception:
+ return "zh"
+
+
+global_lang: Literal["zh", "en"] = _detect_initial_lang()
+
+
+def i18n(key: str) -> str:
+ return _i18n_key2lang_dict[key][global_lang]
+
+
+def get_lyric_lang_choices():
+ """Lyric language dropdown (display, value) for current UI language."""
+ return [
+ (i18n("lyric_lang_mandarin"), "Mandarin"),
+ (i18n("lyric_lang_cantonese"), "Cantonese"),
+ (i18n("lyric_lang_english"), "English"),
+ ]
+
+
+def _resolve_file_path(x):
+ """Gradio file input can be path string or (path, None) tuple."""
+ if x is None:
+ return None
+ if isinstance(x, tuple):
+ x = x[0]
+ return x if (x and os.path.isfile(x)) else None
+
+
+def transcription_function(
+ prompt_audio,
+ target_audio,
+ prompt_metadata,
+ target_metadata,
+ prompt_lyric_lang: str,
+ target_lyric_lang: str,
+ prompt_vocal_sep: bool,
+ target_vocal_sep: bool,
+):
+ """Step 1: Run transcription only; output (prompt_meta_path, target_meta_path)."""
+ try:
+ if isinstance(prompt_audio, tuple):
+ prompt_audio = prompt_audio[0]
+ if isinstance(target_audio, tuple):
+ target_audio = target_audio[0]
+ if prompt_audio is None or target_audio is None:
+ gr.Warning(message=i18n("warn_missing_audio"))
+ return None, None
+ prompt_meta_resolved = _resolve_file_path(prompt_metadata)
+ target_meta_resolved = _resolve_file_path(target_metadata)
+ use_input_metadata = prompt_meta_resolved is not None and target_meta_resolved is not None
+
+ session_base = _session_dir_from_target(target_audio)
+ audio_dir = session_base / "audio"
+ audio_dir.mkdir(parents=True, exist_ok=True)
+ transfer_prompt_path = audio_dir / "prompt.wav"
+ transfer_target_path = audio_dir / "target.wav"
+ SR = 44100
+ PROMPT_MAX_SEC = 30
+ TARGET_MAX_SEC = 60
+ prompt_audio_data, _ = librosa.load(prompt_audio, sr=SR, mono=True)
+ target_audio_data, _ = librosa.load(target_audio, sr=SR, mono=True)
+ prompt_audio_data = prompt_audio_data[: PROMPT_MAX_SEC * SR]
+ target_audio_data = target_audio_data[: TARGET_MAX_SEC * SR]
+ sf.write(transfer_prompt_path, prompt_audio_data, SR)
+ sf.write(transfer_target_path, target_audio_data, SR)
+
+ prompt_meta_path = session_base / "transcriptions" / "prompt" / "metadata.json"
+ target_meta_path = session_base / "transcriptions" / "target" / "metadata.json"
+ if use_input_metadata:
+ (session_base / "transcriptions" / "prompt").mkdir(parents=True, exist_ok=True)
+ (session_base / "transcriptions" / "target").mkdir(parents=True, exist_ok=True)
+ shutil.copy2(prompt_meta_resolved, prompt_meta_path)
+ shutil.copy2(target_meta_resolved, target_meta_path)
+ else:
+ ok, msg = APP_STATE.run_preprocess(
+ transfer_prompt_path,
+ transfer_target_path,
+ session_base,
+ prompt_vocal_sep=prompt_vocal_sep,
+ target_vocal_sep=target_vocal_sep,
+ prompt_lyric_lang=prompt_lyric_lang or "Mandarin",
+ target_lyric_lang=target_lyric_lang or "Mandarin",
+ )
+ if not ok:
+ print(msg, file=sys.stderr, flush=True)
+ return None, None
+
+ prompt_meta_file = str(prompt_meta_path) if prompt_meta_path.exists() else None
+ target_meta_file = str(target_meta_path) if target_meta_path.exists() else None
+ return prompt_meta_file, target_meta_file
+ except Exception:
+ print(traceback.format_exc(), file=sys.stderr, flush=True)
+ return None, None
+
+
+def synthesis_function(
+ prompt_audio,
+ prompt_metadata,
+ target_metadata,
+ control: str,
+ auto_shift: bool,
+ pitch_shift,
+ seed: int,
+):
+ """Step 2: Run SVS from top prompt_audio + prompt_metadata + target_metadata."""
+ try:
+ if isinstance(prompt_audio, tuple):
+ prompt_audio = prompt_audio[0]
+ prompt_wav_path = prompt_audio
+ prompt_meta_path = _resolve_file_path(prompt_metadata)
+ target_meta_path = _resolve_file_path(target_metadata)
+ if not prompt_wav_path or not os.path.isfile(prompt_wav_path):
+ gr.Warning(message=i18n("warn_missing_synthesis"))
+ return None
+ if not prompt_meta_path or not os.path.isfile(prompt_meta_path):
+ gr.Warning(message=i18n("warn_missing_synthesis"))
+ return None
+ if not target_meta_path or not os.path.isfile(target_meta_path):
+ gr.Warning(message=i18n("warn_missing_synthesis"))
+ return None
+ if control not in ("melody", "score"):
+ control = "score"
+ seed = int(seed)
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ ok, msg, merged = APP_STATE.run_svs_from_paths(
+ prompt_wav_path=prompt_wav_path,
+ prompt_metadata_path=prompt_meta_path,
+ target_metadata_path=target_meta_path,
+ control=control,
+ auto_shift=auto_shift,
+ pitch_shift=int(pitch_shift),
+ )
+ if not ok or merged is None:
+ print(msg or "synthesis failed", file=sys.stderr, flush=True)
+ return None
+ return str(merged)
+ except Exception:
+ print(traceback.format_exc(), file=sys.stderr, flush=True)
+ return None
+
+
+def _instruction_md() -> str:
+ """Markdown content for the instruction panel (supports links)."""
+ return "\n\n".join([
+ f"**1.** {i18n('instruction_p1')}",
+ f"**2.** {i18n('instruction_p2')}",
+ f"**3.** {i18n('instruction_p3')}",
+ ])
+
+
+def render_interface() -> gr.Blocks:
+ with gr.Blocks(title="SoulX-Singer 歌声合成Demo", theme=gr.themes.Default()) as page:
+ gr.HTML(
+ ''
+ '
SoulX-Singer
'
+ '
'
+ '
'
+ )
+ # Auto-detect browser language: run after Gradio mounts
+ gr.HTML(
+ '',
+ visible=False,
+ )
+ with gr.Row(equal_height=True):
+ lang_choice = gr.Radio(
+ choices=["中文", "English"],
+ value="中文",
+ label=i18n("display_lang_label"),
+ type="index",
+ interactive=True,
+ elem_id="lang_choice_radio",
+ )
+
+ # Instruction panel (usage workflow); updates on language change
+ instruction_md = gr.Markdown(f"### {i18n('instruction_title')}\n\n{_instruction_md()}")
+
+ # Reference examples — at the front of operations (handler registered after components exist)
+ skip_clear_metadata_count = gr.State(0)
+ with gr.Row():
+ _example_choices = [i18n("example_choice_0"), i18n("example_choice_1"), i18n("example_choice_2"), i18n("example_choice_3"), i18n("example_choice_4")]
+ example_choice = gr.Dropdown(
+ label=i18n("examples_label"),
+ choices=_example_choices,
+ value=_example_choices[0],
+ interactive=True,
+ )
+
+ # Step 1: Transcription (audio → metadata)
+ with gr.Row(equal_height=True):
+ with gr.Column(scale=1):
+ prompt_audio = gr.Audio(
+ label=i18n("prompt_audio_label"),
+ type="filepath",
+ editable=False,
+ interactive=True,
+ )
+ with gr.Column(scale=1):
+ target_audio = gr.Audio(
+ label=i18n("target_audio_label"),
+ type="filepath",
+ editable=False,
+ interactive=True,
+ )
+ with gr.Row(equal_height=True):
+ prompt_lyric_lang = gr.Dropdown(
+ label=i18n("prompt_lyric_lang_label"),
+ choices=get_lyric_lang_choices(),
+ value="Mandarin",
+ interactive=True,
+ scale=1,
+ )
+ target_lyric_lang = gr.Dropdown(
+ label=i18n("target_lyric_lang_label"),
+ choices=get_lyric_lang_choices(),
+ value="Mandarin",
+ interactive=True,
+ scale=1,
+ )
+ prompt_vocal_sep = gr.Checkbox(
+ label=i18n("prompt_vocal_sep_label"),
+ value=False,
+ interactive=True,
+ scale=1,
+ )
+ target_vocal_sep = gr.Checkbox(
+ label=i18n("target_vocal_sep_label"),
+ value=True,
+ interactive=True,
+ scale=1,
+ )
+ with gr.Row():
+ transcription_btn = gr.Button(
+ value=i18n("transcription_btn_label"),
+ variant="primary",
+ size="lg",
+ )
+
+ # Edit tutorial link (gr.HTML supports links; component labels do not)
+ metadata_tutorial_html = gr.HTML(value=i18n("edit_tutorial_html"))
+ # Synthesis: params row, then synthesis button on next row
+ with gr.Row(equal_height=True):
+ prompt_metadata = gr.File(
+ label=i18n("prompt_meta_label"),
+ type="filepath",
+ file_types=[".json"],
+ interactive=True,
+ )
+ target_metadata = gr.File(
+ label=i18n("target_meta_label"),
+ type="filepath",
+ file_types=[".json"],
+ interactive=True,
+ )
+ control_radio = gr.Radio(
+ choices=["melody", "score"],
+ value="score",
+ label=i18n("control_type_label"),
+ scale=1,
+ )
+ auto_shift = gr.Checkbox(
+ label=i18n("auto_shift_label"),
+ value=True,
+ interactive=True,
+ scale=1,
+ )
+ pitch_shift = gr.Number(
+ label=i18n("pitch_shift_label"),
+ value=0,
+ minimum=-36,
+ maximum=36,
+ step=1,
+ interactive=True,
+ scale=1,
+ )
+ seed_input = gr.Number(
+ label=i18n("seed_label"),
+ value=12306,
+ step=1,
+ interactive=True,
+ scale=1,
+ )
+ with gr.Row():
+ synthesis_btn = gr.Button(
+ value=i18n("synthesis_btn_label"),
+ variant="primary",
+ size="lg",
+ )
+ with gr.Row():
+ output_audio = gr.Audio(
+ label=i18n("generated_audio_label"),
+ type="filepath",
+ interactive=False,
+ )
+
+ example_choice.change(
+ fn=_load_example,
+ inputs=[example_choice],
+ outputs=[
+ prompt_audio,
+ target_audio,
+ prompt_metadata,
+ target_metadata,
+ prompt_lyric_lang,
+ target_lyric_lang,
+ control_radio,
+ prompt_vocal_sep,
+ target_vocal_sep,
+ auto_shift,
+ pitch_shift,
+ skip_clear_metadata_count,
+ ],
+ )
+
+ def _change_component_language(lang):
+ global global_lang
+ global_lang = ["zh", "en"][lang]
+ choices = get_lyric_lang_choices()
+ return [
+ gr.update(label=i18n("prompt_audio_label")),
+ gr.update(label=i18n("target_audio_label")),
+ gr.update(label=i18n("prompt_lyric_lang_label"), choices=choices),
+ gr.update(label=i18n("target_lyric_lang_label"), choices=choices),
+ gr.update(label=i18n("prompt_vocal_sep_label")),
+ gr.update(label=i18n("target_vocal_sep_label")),
+ gr.update(value=i18n("transcription_btn_label")),
+ gr.update(label=i18n("prompt_meta_label")),
+ gr.update(label=i18n("target_meta_label")),
+ gr.update(value=i18n("edit_tutorial_html")),
+ gr.update(label=i18n("control_type_label")),
+ gr.update(label=i18n("auto_shift_label")),
+ gr.update(label=i18n("pitch_shift_label")),
+ gr.update(label=i18n("seed_label")),
+ gr.update(value=i18n("synthesis_btn_label")),
+ gr.update(label=i18n("generated_audio_label")),
+ gr.update(label=i18n("display_lang_label")),
+ gr.update(
+ label=i18n("examples_label"),
+ choices=[i18n("example_choice_0"), i18n("example_choice_1"), i18n("example_choice_2"), i18n("example_choice_3"), i18n("example_choice_4")],
+ value=i18n("example_choice_0"),
+ ),
+ gr.update(value=f"### {i18n('instruction_title')}\n\n{_instruction_md()}"),
+ ]
+
+ lang_choice.change(
+ fn=_change_component_language,
+ inputs=[lang_choice],
+ outputs=[
+ prompt_audio,
+ target_audio,
+ prompt_lyric_lang,
+ target_lyric_lang,
+ prompt_vocal_sep,
+ target_vocal_sep,
+ transcription_btn,
+ prompt_metadata,
+ target_metadata,
+ metadata_tutorial_html,
+ control_radio,
+ auto_shift,
+ pitch_shift,
+ seed_input,
+ synthesis_btn,
+ output_audio,
+ lang_choice,
+ example_choice,
+ instruction_md,
+ ],
+ )
+
+ # Upload new prompt/target audio → clear corresponding metadata; skip clear when change came from load example
+ prompt_audio.change(
+ fn=_clear_prompt_meta_unless_example,
+ inputs=[prompt_audio, skip_clear_metadata_count],
+ outputs=[prompt_metadata, skip_clear_metadata_count],
+ )
+ target_audio.change(
+ fn=_clear_target_meta_unless_example,
+ inputs=[target_audio, skip_clear_metadata_count],
+ outputs=[target_metadata, skip_clear_metadata_count],
+ )
+
+ transcription_btn.click(
+ fn=transcription_function,
+ inputs=[
+ prompt_audio,
+ target_audio,
+ prompt_metadata,
+ target_metadata,
+ prompt_lyric_lang,
+ target_lyric_lang,
+ prompt_vocal_sep,
+ target_vocal_sep,
+ ],
+ outputs=[prompt_metadata, target_metadata],
+ )
+
+ synthesis_btn.click(
+ fn=synthesis_function,
+ inputs=[
+ prompt_audio,
+ prompt_metadata,
+ target_metadata,
+ control_radio,
+ auto_shift,
+ pitch_shift,
+ seed_input,
+ ],
+ outputs=[output_audio],
+ )
+
+ return page
+
+
+if __name__ == "__main__":
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--port", type=int, default=7860, help="Gradio server port")
+ parser.add_argument("--share", action="store_true", help="Create public link")
+ args = parser.parse_args()
+
+ page = render_interface()
+ page.queue()
+ page.launch(share=args.share, server_name="0.0.0.0", server_port=args.port)