Spaces:
Running
on
Zero
Running
on
Zero
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +6 -0
- .gitignore +38 -0
- DEPLOY.md +201 -0
- LICENSE +201 -0
- README.md +226 -9
- app.py +63 -0
- assets/performance_radar.png +3 -0
- assets/soul_wechat01.jpg +3 -0
- assets/soulx-logo.png +3 -0
- assets/technical-report.pdf +3 -0
- cli/inference.py +147 -0
- deploy_to_hf.sh +70 -0
- example/audio/en_prompt.json +16 -0
- example/audio/en_prompt.mp3 +0 -0
- example/audio/en_target.json +16 -0
- example/audio/en_target.mp3 +0 -0
- example/audio/music.json +16 -0
- example/audio/music.mp3 +3 -0
- example/audio/yue_target.json +16 -0
- example/audio/yue_target.mp3 +3 -0
- example/audio/zh_prompt.json +16 -0
- example/audio/zh_prompt.mp3 +0 -0
- example/audio/zh_target.json +16 -0
- example/audio/zh_target.mp3 +0 -0
- example/infer.sh +28 -0
- example/preprocess.sh +41 -0
- preprocess/README.md +155 -0
- preprocess/pipeline.py +146 -0
- preprocess/requirements.txt +33 -0
- preprocess/tools/__init__.py +53 -0
- preprocess/tools/f0_extraction.py +527 -0
- preprocess/tools/g2p.py +72 -0
- preprocess/tools/lyric_transcription.py +279 -0
- preprocess/tools/midi_parser.py +669 -0
- preprocess/tools/note_transcription/__init__.py +0 -0
- preprocess/tools/note_transcription/model.py +522 -0
- preprocess/tools/note_transcription/modules/__init__.py +1 -0
- preprocess/tools/note_transcription/modules/commons/__init__.py +1 -0
- preprocess/tools/note_transcription/modules/commons/conformer/__init__.py +1 -0
- preprocess/tools/note_transcription/modules/commons/conformer/conformer.py +96 -0
- preprocess/tools/note_transcription/modules/commons/conformer/espnet_positional_embedding.py +113 -0
- preprocess/tools/note_transcription/modules/commons/conformer/espnet_transformer_attn.py +198 -0
- preprocess/tools/note_transcription/modules/commons/conformer/layers.py +260 -0
- preprocess/tools/note_transcription/modules/commons/conv.py +175 -0
- preprocess/tools/note_transcription/modules/commons/layers.py +85 -0
- preprocess/tools/note_transcription/modules/commons/rel_transformer.py +378 -0
- preprocess/tools/note_transcription/modules/commons/rnn.py +261 -0
- preprocess/tools/note_transcription/modules/commons/transformer.py +751 -0
- preprocess/tools/note_transcription/modules/commons/wavenet.py +109 -0
- preprocess/tools/note_transcription/modules/pe/__init__.py +1 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/performance_radar.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/soul_wechat01.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/soulx-logo.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
assets/technical-report.pdf filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
example/audio/music.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
example/audio/yue_target.mp3 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
|
| 4 |
+
dev/
|
| 5 |
+
results/
|
| 6 |
+
wandb/
|
| 7 |
+
.ipynb_checkpoints/
|
| 8 |
+
.vscode/
|
| 9 |
+
.cache
|
| 10 |
+
local/
|
| 11 |
+
outputs/
|
| 12 |
+
|
| 13 |
+
*.pt
|
| 14 |
+
*.ckpt
|
| 15 |
+
|
| 16 |
+
# Logs
|
| 17 |
+
logs/
|
| 18 |
+
*.log
|
| 19 |
+
results/
|
| 20 |
+
runs/
|
| 21 |
+
dev*
|
| 22 |
+
local/
|
| 23 |
+
generated/
|
| 24 |
+
|
| 25 |
+
.DS_Store
|
| 26 |
+
pretrained_models/
|
| 27 |
+
|
| 28 |
+
*.err
|
| 29 |
+
*.out
|
| 30 |
+
|
| 31 |
+
# Dev
|
| 32 |
+
dev/
|
| 33 |
+
|
| 34 |
+
# Data
|
| 35 |
+
data/
|
| 36 |
+
outputs/
|
| 37 |
+
deploy/
|
| 38 |
+
.gradio/
|
DEPLOY.md
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚀 部署到 Hugging Face Space 指南
|
| 2 |
+
|
| 3 |
+
本指南将帮助您将 SoulX-Singer 部署到 Hugging Face Space。
|
| 4 |
+
|
| 5 |
+
## 📋 前置要求
|
| 6 |
+
|
| 7 |
+
1. **Hugging Face 账号**:如果没有,请先注册 [huggingface.co](https://huggingface.co/join)
|
| 8 |
+
2. **Git**:确保已安装 Git
|
| 9 |
+
3. **Hugging Face CLI**(可选但推荐):`pip install huggingface_hub`
|
| 10 |
+
|
| 11 |
+
## 🎯 部署步骤
|
| 12 |
+
|
| 13 |
+
### 方法一:通过 Web 界面创建(推荐)
|
| 14 |
+
|
| 15 |
+
#### 步骤 1:准备代码仓库
|
| 16 |
+
|
| 17 |
+
确保您的代码已准备好:
|
| 18 |
+
- ✅ `app.py` - Space 入口文件
|
| 19 |
+
- ✅ `webui.py` - Gradio 界面代码
|
| 20 |
+
- ✅ `requirements.txt` - Python 依赖
|
| 21 |
+
- ✅ `README.md` - 包含 Space 配置的 YAML 头部
|
| 22 |
+
|
| 23 |
+
#### 步骤 2:创建 Space
|
| 24 |
+
|
| 25 |
+
1. 访问 [huggingface.co/spaces](https://huggingface.co/spaces)
|
| 26 |
+
2. 点击 **"Create new Space"** 按钮
|
| 27 |
+
3. 填写 Space 信息:
|
| 28 |
+
- **Space name**: 例如 `SoulX-Singer` 或 `soulx-singer-demo`
|
| 29 |
+
- **SDK**: 选择 **Gradio**
|
| 30 |
+
- **Hardware**: 推荐选择 **GPU T4 small**(推理更快,首次下载模型后缓存)
|
| 31 |
+
- **Visibility**: 选择 Public(公开)或 Private(私有)
|
| 32 |
+
4. 点击 **"Create Space"**
|
| 33 |
+
|
| 34 |
+
#### 步骤 3:上传代码
|
| 35 |
+
|
| 36 |
+
**选项 A:使用 Git 推送(推荐)**
|
| 37 |
+
|
| 38 |
+
```bash
|
| 39 |
+
# 1. 在本地代码目录初始化 Git(如果还没有)
|
| 40 |
+
git init
|
| 41 |
+
git add .
|
| 42 |
+
git commit -m "Initial commit for HF Space"
|
| 43 |
+
|
| 44 |
+
# 2. 添加 Hugging Face 远程仓库
|
| 45 |
+
# 替换 YOUR_USERNAME 和 YOUR_SPACE_NAME
|
| 46 |
+
git remote add origin https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
|
| 47 |
+
|
| 48 |
+
# 3. 推送代码
|
| 49 |
+
git push -u origin main
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
**选项 B:使用 Web 界面上传**
|
| 53 |
+
|
| 54 |
+
1. 在 Space 页面点击 **"Files and versions"** 标签
|
| 55 |
+
2. 点击 **"Add file"** → **"Upload files"**
|
| 56 |
+
3. 拖拽或选择以下必需文件:
|
| 57 |
+
- `app.py`
|
| 58 |
+
- `webui.py`
|
| 59 |
+
- `requirements.txt`
|
| 60 |
+
- `README.md`
|
| 61 |
+
- `soulxsinger/` 目录(整个文件夹)
|
| 62 |
+
- `preprocess/` 目录(整个文件夹)
|
| 63 |
+
- `cli/` 目录(整个文件夹)
|
| 64 |
+
- `example/` 目录(整个文件夹)
|
| 65 |
+
- `assets/` 目录(整个文件夹)
|
| 66 |
+
- 其他配置文件(如 `LICENSE`, `.gitignore` 等)
|
| 67 |
+
|
| 68 |
+
#### 步骤 4:等待构建和首次运行
|
| 69 |
+
|
| 70 |
+
1. Space 会自动检测到代码并开始构建
|
| 71 |
+
2. 查看 **"Logs"** 标签页监控构建进度
|
| 72 |
+
3. 首次运行会:
|
| 73 |
+
- 安装 `requirements.txt` 中的依赖
|
| 74 |
+
- 执行 `app.py`
|
| 75 |
+
- **自动下载** `Soul-AILab/SoulX-Singer` 和 `Soul-AILab/SoulX-Singer-Preprocess` 模型(可能需要 5-15 分钟,取决于网络速度)
|
| 76 |
+
4. 构建完成后,Space 会自动启动,您可以在 **"App"** 标签页看到界面
|
| 77 |
+
|
| 78 |
+
### 方法二:使用 Hugging Face CLI
|
| 79 |
+
|
| 80 |
+
```bash
|
| 81 |
+
# 1. 安装 Hugging Face Hub CLI
|
| 82 |
+
pip install huggingface_hub
|
| 83 |
+
|
| 84 |
+
# 2. 登录(会打开浏览器)
|
| 85 |
+
huggingface-cli login
|
| 86 |
+
|
| 87 |
+
# 3. 创建 Space(替换 YOUR_USERNAME 和 YOUR_SPACE_NAME)
|
| 88 |
+
huggingface-cli repo create YOUR_SPACE_NAME --type space --sdk gradio
|
| 89 |
+
|
| 90 |
+
# 4. 克隆 Space 仓库
|
| 91 |
+
git clone https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
|
| 92 |
+
cd YOUR_SPACE_NAME
|
| 93 |
+
|
| 94 |
+
# 5. 复制代码文件到 Space 目录
|
| 95 |
+
# (将当前代码目录的所有文件复制过来)
|
| 96 |
+
|
| 97 |
+
# 6. 提交并推送
|
| 98 |
+
git add .
|
| 99 |
+
git commit -m "Deploy SoulX-Singer to HF Space"
|
| 100 |
+
git push
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
## ⚙️ Space 配置说明
|
| 104 |
+
|
| 105 |
+
Space 配置在 `README.md` 的 YAML 头部:
|
| 106 |
+
|
| 107 |
+
```yaml
|
| 108 |
+
---
|
| 109 |
+
title: SoulX-Singer
|
| 110 |
+
emoji: 🎤
|
| 111 |
+
sdk: gradio
|
| 112 |
+
sdk_version: "6.3.0"
|
| 113 |
+
app_file: app.py
|
| 114 |
+
python_version: "3.10"
|
| 115 |
+
suggested_hardware: t4-small # 取消注释以启用 GPU
|
| 116 |
+
---
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
### 硬件选择建议
|
| 120 |
+
|
| 121 |
+
- **CPU Basic**: 免费,但推理速度较慢,适合测试
|
| 122 |
+
- **GPU T4 Small**: 推荐,推理速度快,首次下载模型后缓存
|
| 123 |
+
- **GPU T4 Medium/Large**: 适合高并发或更复杂的推理
|
| 124 |
+
|
| 125 |
+
### 修改硬件配置
|
| 126 |
+
|
| 127 |
+
1. 进入 Space 页面
|
| 128 |
+
2. 点击 **"Settings"** 标签
|
| 129 |
+
3. 在 **"Hardware"** 部分选择所需硬件
|
| 130 |
+
4. 保存后 Space 会重启
|
| 131 |
+
|
| 132 |
+
## 🔍 故障排查
|
| 133 |
+
|
| 134 |
+
### 问题 1:构建失败
|
| 135 |
+
|
| 136 |
+
**检查点:**
|
| 137 |
+
- ✅ `requirements.txt` 中所有依赖版本是否兼容
|
| 138 |
+
- ✅ `app.py` 文件是否存在且可执行
|
| 139 |
+
- ✅ `README.md` 的 YAML 配置是否正确
|
| 140 |
+
|
| 141 |
+
**查看日志:**
|
| 142 |
+
- 在 Space 页面的 **"Logs"** 标签查看详细错误信息
|
| 143 |
+
|
| 144 |
+
### 问题 2:模型下载失败
|
| 145 |
+
|
| 146 |
+
**可能原因:**
|
| 147 |
+
- 网络连接问题
|
| 148 |
+
- Hugging Face Hub 认证问题
|
| 149 |
+
|
| 150 |
+
**解决方案:**
|
| 151 |
+
- 确保 Space 有网络访问权限(默认有)
|
| 152 |
+
- 如果使用私有模型,需要在 Space Settings 中添加 HF Token
|
| 153 |
+
|
| 154 |
+
### 问题 3:应用启动后无法访问
|
| 155 |
+
|
| 156 |
+
**检查点:**
|
| 157 |
+
- ✅ `app.py` 中 `server_name="0.0.0.0"` 已设置
|
| 158 |
+
- ✅ 端口使用环境变量 `PORT`(Space 会自动注入)
|
| 159 |
+
- ✅ 查看 **"Logs"** 确认应用是否成功启动
|
| 160 |
+
|
| 161 |
+
### 问题 4:内存不足
|
| 162 |
+
|
| 163 |
+
**解决方案:**
|
| 164 |
+
- 升级到更大的硬件(T4 Medium/Large)
|
| 165 |
+
- 或优化代码,减少内存占用
|
| 166 |
+
|
| 167 |
+
## 📝 重要提示
|
| 168 |
+
|
| 169 |
+
1. **首次运行时间**:首次部署时,模型下载可能需要 5-15 分钟,请耐心等待
|
| 170 |
+
2. **模型缓存**:下载的模型会缓存在 Space 的存��中,重启后无需重新下载
|
| 171 |
+
3. **存储限制**:免费 Space 有存储限制,确保模型文件不会超过限制
|
| 172 |
+
4. **自动重启**:Space 会在代码更新后自动重启
|
| 173 |
+
5. **日志查看**:遇到问题时,首先查看 **"Logs"** 标签页的详细日志
|
| 174 |
+
|
| 175 |
+
## 🔗 相关链接
|
| 176 |
+
|
| 177 |
+
- [Hugging Face Spaces 文档](https://huggingface.co/docs/hub/spaces)
|
| 178 |
+
- [Gradio 文档](https://gradio.app/docs/)
|
| 179 |
+
- [SoulX-Singer 模型页面](https://huggingface.co/Soul-AILab/SoulX-Singer)
|
| 180 |
+
- [SoulX-Singer-Preprocess 模型页面](https://huggingface.co/Soul-AILab/SoulX-Singer-Preprocess)
|
| 181 |
+
|
| 182 |
+
## ✅ 部署检查清单
|
| 183 |
+
|
| 184 |
+
部署前确认:
|
| 185 |
+
- [ ] `app.py` 文件存在且正确
|
| 186 |
+
- [ ] `requirements.txt` 包含所有依赖(包括 `huggingface_hub`)
|
| 187 |
+
- [ ] `README.md` 包含正确的 YAML 配置
|
| 188 |
+
- [ ] 所有必需的代码文件都已上传
|
| 189 |
+
- [ ] `.gitignore` 正确配置(排除 `pretrained_models/` 和 `outputs/`)
|
| 190 |
+
- [ ] Space 硬件配置合适(推荐 GPU T4 Small)
|
| 191 |
+
|
| 192 |
+
部署后验证:
|
| 193 |
+
- [ ] Space 构建成功(无错误日志)
|
| 194 |
+
- [ ] 模型自动下载完成
|
| 195 |
+
- [ ] Web 界面可以正常访问
|
| 196 |
+
- [ ] 可以上传音频文件进行测试
|
| 197 |
+
- [ ] 推理功能正常工作
|
| 198 |
+
|
| 199 |
+
---
|
| 200 |
+
|
| 201 |
+
**祝部署顺利!** 🎉
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,14 +1,231 @@
|
|
| 1 |
---
|
| 2 |
-
title: SoulX
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom: purple
|
| 5 |
-
colorTo: yellow
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 6.
|
| 8 |
app_file: app.py
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
---
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: SoulX-Singer
|
| 3 |
+
emoji: 🎤
|
|
|
|
|
|
|
| 4 |
sdk: gradio
|
| 5 |
+
sdk_version: "6.3.0"
|
| 6 |
app_file: app.py
|
| 7 |
+
python_version: "3.10"
|
| 8 |
+
# GPU recommended for inference speed (optional: use CPU for light usage)
|
| 9 |
+
# suggested_hardware: t4-small
|
| 10 |
---
|
| 11 |
|
| 12 |
+
<div align="center">
|
| 13 |
+
<h1>🎤 SoulX-Singer</h1>
|
| 14 |
+
<p>
|
| 15 |
+
Official inference code for<br>
|
| 16 |
+
<b><em>SoulX-Singer: Towards High-Quality Zero-Shot Singing Voice Synthesis</em></b>
|
| 17 |
+
</p>
|
| 18 |
+
<p>
|
| 19 |
+
<img src="assets/soulx-logo.png" alt="SoulX-Logo" style="height:80px;">
|
| 20 |
+
</p>
|
| 21 |
+
<p>
|
| 22 |
+
<a href="https://soul-ailab.github.io/soulx-singer/"><img src="https://img.shields.io/badge/Demo-Page-lightgrey" alt="Demo Page"></a>
|
| 23 |
+
<a href="https://huggingface.co/Soul-AILab/SoulX-Singer"><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue' alt="HF-model"></a>
|
| 24 |
+
<a href="assets/technical-report.pdf"><img src="https://img.shields.io/badge/Report-Github-red" alt="Technical Report"></a>
|
| 25 |
+
<a href="https://github.com/Soul-AILab/SoulX-Singer"><img src="https://img.shields.io/badge/License-Apache%202.0-blue" alt="License"></a>
|
| 26 |
+
</p>
|
| 27 |
+
</div>
|
| 28 |
+
|
| 29 |
+
---
|
| 30 |
+
|
| 31 |
+
## 🎵 Overview
|
| 32 |
+
|
| 33 |
+
**SoulX-Singer** is a high-fidelity, zero-shot singing voice synthesis model that enables users to generate realistic singing voices for unseen singers.
|
| 34 |
+
It supports **melody-conditioned (F0 contour)** and **score-conditioned (MIDI notes)** control for precise pitch, rhythm, and expression.
|
| 35 |
+
|
| 36 |
+
---
|
| 37 |
+
|
| 38 |
+
## ✨ Key Features
|
| 39 |
+
|
| 40 |
+
- **🎤 Zero-Shot Singing** – Generate high-fidelity voices for unseen singers, no fine-tuning needed.
|
| 41 |
+
- **🎵 Flexible Control Modes** – Melody (F0) and Score (MIDI) conditioning.
|
| 42 |
+
- **📚 Large-Scale Dataset** – 42,000+ hours of aligned vocals, lyrics, notes across Mandarin, English, Cantonese.
|
| 43 |
+
- **🧑🎤 Timbre Cloning** – Preserve singer identity across languages, styles, and edited lyrics.
|
| 44 |
+
- **✏️ Singing Voice Editing** – Modify lyrics while keeping natural prosody.
|
| 45 |
+
- **🌐 Cross-Lingual Synthesis** – High-fidelity synthesis by disentangling timbre from content.
|
| 46 |
+
|
| 47 |
+
---
|
| 48 |
+
|
| 49 |
+
<p align="center">
|
| 50 |
+
<img src="assets/performance_radar.png" width="80%" alt="Performance Radar"/>
|
| 51 |
+
</p>
|
| 52 |
+
|
| 53 |
+
---
|
| 54 |
+
|
| 55 |
+
## 🎬 Demo Examples
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
<div align="center">
|
| 59 |
+
|
| 60 |
+
<https://github.com/user-attachments/assets/13306f10-3a29-46ba-bcef-d6308d05cbcc>
|
| 61 |
+
|
| 62 |
+
</div>
|
| 63 |
+
<div align="center">
|
| 64 |
+
|
| 65 |
+
<https://github.com/user-attachments/assets/2eb260fe-6f0b-408c-aab8-5b81ddddb284>
|
| 66 |
+
|
| 67 |
+
</div>
|
| 68 |
+
|
| 69 |
+
---
|
| 70 |
+
|
| 71 |
+
## 📰 News
|
| 72 |
+
|
| 73 |
+
- **[2026-02-06]** SoulX-Singer inference code and models released.
|
| 74 |
+
|
| 75 |
+
---
|
| 76 |
+
|
| 77 |
+
## 🚀 Quick Start
|
| 78 |
+
|
| 79 |
+
**Note:** This repo does not ship pretrained weights. SVS and preprocessing models must be downloaded from Hugging Face (see step 3).
|
| 80 |
+
|
| 81 |
+
### 1. Clone Repository
|
| 82 |
+
|
| 83 |
+
```bash
|
| 84 |
+
git clone https://github.com/Soul-AILab/SoulX-Singer.git
|
| 85 |
+
cd SoulX-Singer
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
### 2. Set Up Environment
|
| 89 |
+
|
| 90 |
+
**1. Install Conda** (if not already installed): https://docs.conda.io/en/latest/miniconda.html
|
| 91 |
+
|
| 92 |
+
**2. Create and activate a Conda environment:**
|
| 93 |
+
```
|
| 94 |
+
conda create -n soulxsinger -y python=3.10
|
| 95 |
+
conda activate soulxsinger
|
| 96 |
+
```
|
| 97 |
+
**3. Install dependencies:**
|
| 98 |
+
```
|
| 99 |
+
pip install -r requirements.txt
|
| 100 |
+
```
|
| 101 |
+
⚠️ If you are in mainland China, use a PyPI mirror:
|
| 102 |
+
```
|
| 103 |
+
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
---
|
| 108 |
+
|
| 109 |
+
### 3. Download Pretrained Models
|
| 110 |
+
|
| 111 |
+
**This repository does not include pretrained models.** You must download them from Hugging Face:
|
| 112 |
+
|
| 113 |
+
- [Soul-AILab/SoulX-Singer](https://huggingface.co/Soul-AILab/SoulX-Singer) (SVS model)
|
| 114 |
+
- [Soul-AILab/SoulX-Singer-Preprocess](https://huggingface.co/Soul-AILab/SoulX-Singer-Preprocess) (preprocessing models)
|
| 115 |
+
|
| 116 |
+
Install Hugging Face Hub and download:
|
| 117 |
+
|
| 118 |
+
```sh
|
| 119 |
+
pip install -U huggingface_hub
|
| 120 |
+
|
| 121 |
+
# SoulX-Singer SVS model
|
| 122 |
+
huggingface-cli download Soul-AILab/SoulX-Singer --local-dir pretrained_models/SoulX-Singer
|
| 123 |
+
|
| 124 |
+
# Preprocessing models (vocal separation, F0, ASR, etc.)
|
| 125 |
+
huggingface-cli download Soul-AILab/SoulX-Singer-Preprocess --local-dir pretrained_models/SoulX-Singer-Preprocess
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
### 4. Run the Demo
|
| 130 |
+
|
| 131 |
+
Run the inference demo:
|
| 132 |
+
``` sh
|
| 133 |
+
bash example/infer.sh
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
This script relies on metadata generated from the preprocessing pipeline, including vocal separation and transcription. Users should follow the steps in [preprocess](preprocess/README.md) to prepare the necessary metadata before running the demo with their own data.
|
| 137 |
+
|
| 138 |
+
**⚠️ Important Note**
|
| 139 |
+
The metadata produced by the automatic preprocessing pipeline may not perfectly align the singing audio with the corresponding lyrics and musical notes. For best synthesis quality, we strongly recommend manually correcting the alignment using the 🎼 [Midi-Editor](https://huggingface.co/spaces/Soul-AILab/SoulX-Singer-Midi-Editor).
|
| 140 |
+
|
| 141 |
+
How to use the Midi-Editor:
|
| 142 |
+
- [Eiditing Metadata with Midi-Editor](preprocess/README.md#L104-L105)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
### 🌐 WebUI
|
| 146 |
+
|
| 147 |
+
You can launch the interactive interface with:
|
| 148 |
+
```
|
| 149 |
+
python webui.py
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
### 🚀 Deploy as Hugging Face Space
|
| 153 |
+
|
| 154 |
+
This repo is ready to deploy as a [Hugging Face Space](https://huggingface.co/spaces). **Pretrained models are not included;** `app.py` downloads them from the Hub on first run.
|
| 155 |
+
|
| 156 |
+
**📖 详细部署指南请查看:[DEPLOY.md](DEPLOY.md)**
|
| 157 |
+
|
| 158 |
+
**快速步骤:**
|
| 159 |
+
|
| 160 |
+
1. **创建 Space**:访问 [huggingface.co/spaces](https://huggingface.co/spaces),点击 "Create new Space",选择 **Gradio** SDK
|
| 161 |
+
2. **上传代码**:使用 Git 推送或 Web 界面上传代码文件
|
| 162 |
+
3. **配置硬件**:在 Space Settings 中选择 **GPU T4 Small**(推荐)以加快推理速度
|
| 163 |
+
4. **等待启动**:Space 会自动安装依赖、下载模型并启动应用(首次运行可能需要 5-15 分钟)
|
| 164 |
+
|
| 165 |
+
模型会自动从以下仓库下载:
|
| 166 |
+
- [Soul-AILab/SoulX-Singer](https://huggingface.co/Soul-AILab/SoulX-Singer) (SVS model)
|
| 167 |
+
- [Soul-AILab/SoulX-Singer-Preprocess](https://huggingface.co/Soul-AILab/SoulX-Singer-Preprocess) (preprocessing models)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
## 🚧 Roadmap
|
| 172 |
+
|
| 173 |
+
- [ ] 🖥️ Web-based UI for easy and interactive inference
|
| 174 |
+
- [ ] 🌐 Online demo deployment on Hugging Face Spaces
|
| 175 |
+
- [ ] 📊 Release the SoulX-Singer-Eval benchmark
|
| 176 |
+
- [ ] 📚 Comprehensive tutorials and usage documentation
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
## 🙏 Acknowledgements
|
| 180 |
+
|
| 181 |
+
Special thanks to the following open-source projects:
|
| 182 |
+
|
| 183 |
+
- [F5-TTS](https://github.com/SWivid/F5-TTS)
|
| 184 |
+
- [Amphion](https://github.com/open-mmlab/Amphion/tree/main)
|
| 185 |
+
- [Music Source Separation Training](https://github.com/ZFTurbo/Music-Source-Separation-Training)
|
| 186 |
+
- [Lead Vocal Separation](https://huggingface.co/becruily/mel-band-roformer-karaoke)
|
| 187 |
+
- [Vocal Dereverberation](https://huggingface.co/anvuew/dereverb_mel_band_roformer)
|
| 188 |
+
- [RMVPE](https://github.com/Dream-High/RMVPE)
|
| 189 |
+
[Paraformer](https://modelscope.cn/models/iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch)
|
| 190 |
+
- [Parakeet-tdt-0.6b-v2](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2)
|
| 191 |
+
- [ROSVOT](https://github.com/RickyL-2000/ROSVOT)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
## 📄 License
|
| 196 |
+
|
| 197 |
+
We use the Apache 2.0 license. Researchers and developers are free to use the codes and model weights of our SoulX-Singer. Check the license at [LICENSE](LICENSE) for more details.
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
## ⚠️ Usage Disclaimer
|
| 201 |
+
|
| 202 |
+
SoulX-Singer is intended for academic research, educational purposes, and legitimate applications such as personalized singing synthesis and assistive technologies.
|
| 203 |
+
|
| 204 |
+
Please note:
|
| 205 |
+
|
| 206 |
+
- 🎤 Respect intellectual property, privacy, and personal consent when generating singing content.
|
| 207 |
+
- 🚫 Do not use the model to impersonate individuals without authorization or to create deceptive audio.
|
| 208 |
+
- ⚠️ The developers assume no liability for any misuse of this model.
|
| 209 |
+
|
| 210 |
+
We advocate for the responsible development and use of AI and encourage the community to uphold safety and ethical principles. For ethics or misuse concerns, please contact us.
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
## 📬 Contact Us
|
| 214 |
+
|
| 215 |
+
We welcome your feedback, questions, and collaboration:
|
| 216 |
+
|
| 217 |
+
- **Email**: qianjiale@soulapp.cn | menghao@soulapp.cn | wangxinsheng@soulapp.cn
|
| 218 |
+
|
| 219 |
+
- **Join discussions**: WeChat or Soul APP groups for technical discussions and updates:
|
| 220 |
+
|
| 221 |
+
<p align="center">
|
| 222 |
+
<!-- <em>Due to group limits, if you can't scan the QR code, please add my WeChat for group access -->
|
| 223 |
+
<!-- : <strong>Tiamo James</strong></em> -->
|
| 224 |
+
<br>
|
| 225 |
+
<span style="display: inline-block; margin-right: 10px;">
|
| 226 |
+
<img src="assets/soul_wechat01.jpg" width="500" alt="WeChat Group QR Code"/>
|
| 227 |
+
</span>
|
| 228 |
+
<!-- <span style="display: inline-block;">
|
| 229 |
+
<img src="assets/wechat_tiamo.jpg" width="300" alt="WeChat QR Code"/>
|
| 230 |
+
</span> -->
|
| 231 |
+
</p>
|
app.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hugging Face Space entry point for SoulX-Singer.
|
| 3 |
+
Downloads pretrained models from the Hub if needed, then launches the Gradio app.
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
ROOT = Path(__file__).resolve().parent
|
| 10 |
+
PRETRAINED_DIR = ROOT / "pretrained_models"
|
| 11 |
+
MODEL_DIR_SVS = PRETRAINED_DIR / "SoulX-Singer"
|
| 12 |
+
MODEL_DIR_PREPROCESS = PRETRAINED_DIR / "SoulX-Singer-Preprocess"
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def ensure_pretrained_models():
|
| 16 |
+
"""Download SoulX-Singer and Preprocess models from Hugging Face Hub if not present."""
|
| 17 |
+
if (MODEL_DIR_SVS / "model.pt").exists() and MODEL_DIR_PREPROCESS.exists():
|
| 18 |
+
print("Pretrained models already present, skipping download.", flush=True)
|
| 19 |
+
return
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
from huggingface_hub import snapshot_download
|
| 23 |
+
except ImportError:
|
| 24 |
+
print(
|
| 25 |
+
"huggingface_hub not installed. Install with: pip install huggingface_hub",
|
| 26 |
+
file=sys.stderr,
|
| 27 |
+
flush=True,
|
| 28 |
+
)
|
| 29 |
+
raise
|
| 30 |
+
|
| 31 |
+
PRETRAINED_DIR.mkdir(parents=True, exist_ok=True)
|
| 32 |
+
|
| 33 |
+
if not (MODEL_DIR_SVS / "model.pt").exists():
|
| 34 |
+
print("Downloading SoulX-Singer model...", flush=True)
|
| 35 |
+
snapshot_download(
|
| 36 |
+
repo_id="Soul-AILab/SoulX-Singer",
|
| 37 |
+
local_dir=str(MODEL_DIR_SVS),
|
| 38 |
+
local_dir_use_symlinks=False,
|
| 39 |
+
)
|
| 40 |
+
print("SoulX-Singer model ready.", flush=True)
|
| 41 |
+
|
| 42 |
+
if not MODEL_DIR_PREPROCESS.exists():
|
| 43 |
+
print("Downloading SoulX-Singer-Preprocess models...", flush=True)
|
| 44 |
+
snapshot_download(
|
| 45 |
+
repo_id="Soul-AILab/SoulX-Singer-Preprocess",
|
| 46 |
+
local_dir=str(MODEL_DIR_PREPROCESS),
|
| 47 |
+
local_dir_use_symlinks=False,
|
| 48 |
+
)
|
| 49 |
+
print("SoulX-Singer-Preprocess models ready.", flush=True)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
if __name__ == "__main__":
|
| 53 |
+
os.chdir(ROOT)
|
| 54 |
+
ensure_pretrained_models()
|
| 55 |
+
|
| 56 |
+
from webui import render_interface
|
| 57 |
+
|
| 58 |
+
page = render_interface()
|
| 59 |
+
page.queue()
|
| 60 |
+
page.launch(
|
| 61 |
+
server_name="0.0.0.0",
|
| 62 |
+
server_port=int(os.environ.get("PORT", "7860")),
|
| 63 |
+
)
|
assets/performance_radar.png
ADDED
|
Git LFS Details
|
assets/soul_wechat01.jpg
ADDED
|
Git LFS Details
|
assets/soulx-logo.png
ADDED
|
Git LFS Details
|
assets/technical-report.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ab2876f8850ce09e2b8ce7e929f8b9adf7de10f13900cb013f548f9707b80061
|
| 3 |
+
size 7927691
|
cli/inference.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import json
|
| 4 |
+
import argparse
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import numpy as np
|
| 7 |
+
import soundfile as sf
|
| 8 |
+
from collections import OrderedDict
|
| 9 |
+
from omegaconf import DictConfig
|
| 10 |
+
|
| 11 |
+
from soulxsinger.utils.file_utils import load_config
|
| 12 |
+
from soulxsinger.models.soulxsinger import SoulXSinger
|
| 13 |
+
from soulxsinger.utils.data_processor import DataProcessor
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def build_model(
|
| 17 |
+
model_path: str,
|
| 18 |
+
config: DictConfig,
|
| 19 |
+
device: str = "cuda",
|
| 20 |
+
):
|
| 21 |
+
"""
|
| 22 |
+
Build the model from the pre-trained model path and model configuration.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
model_path (str): Path to the checkpoint file.
|
| 26 |
+
config (DictConfig): Model configuration.
|
| 27 |
+
device (str, optional): Device to use. Defaults to "cuda".
|
| 28 |
+
|
| 29 |
+
Returns:
|
| 30 |
+
Tuple[torch.nn.Module, torch.nn.Module]: The initialized model and vocoder.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
if not os.path.isfile(model_path):
|
| 34 |
+
raise FileNotFoundError(
|
| 35 |
+
f"Model checkpoint not found: {model_path}. "
|
| 36 |
+
"Please download the pretrained model and place it at the path, or set --model_path."
|
| 37 |
+
)
|
| 38 |
+
model = SoulXSinger(config).to(device)
|
| 39 |
+
print("Model initialized.")
|
| 40 |
+
print("Model parameters:", sum(p.numel() for p in model.parameters()) / 1e6, "M")
|
| 41 |
+
|
| 42 |
+
checkpoint = torch.load(model_path, weights_only=False, map_location=device)
|
| 43 |
+
if "state_dict" not in checkpoint:
|
| 44 |
+
raise KeyError(
|
| 45 |
+
f"Checkpoint at {model_path} has no 'state_dict' key. "
|
| 46 |
+
"Expected a checkpoint saved with model.state_dict()."
|
| 47 |
+
)
|
| 48 |
+
model.load_state_dict(checkpoint["state_dict"], strict=True)
|
| 49 |
+
|
| 50 |
+
model.eval()
|
| 51 |
+
model.to(device)
|
| 52 |
+
print("Model checkpoint loaded.")
|
| 53 |
+
|
| 54 |
+
return model
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def process(args, config, model: torch.nn.Module):
|
| 58 |
+
"""Run the full inference pipeline given a data_processor and model.
|
| 59 |
+
"""
|
| 60 |
+
if args.control not in ("melody", "score"):
|
| 61 |
+
raise ValueError(f"control must be 'melody' or 'score', got: {args.control}")
|
| 62 |
+
|
| 63 |
+
print(f"prompt_metadata_path: {args.prompt_metadata_path}")
|
| 64 |
+
print(f"target_metadata_path: {args.target_metadata_path}")
|
| 65 |
+
|
| 66 |
+
os.makedirs(args.save_dir, exist_ok=True)
|
| 67 |
+
data_processor = DataProcessor(
|
| 68 |
+
hop_size=config.audio.hop_size,
|
| 69 |
+
sample_rate=config.audio.sample_rate,
|
| 70 |
+
phoneset_path=args.phoneset_path,
|
| 71 |
+
device=args.device,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
with open(args.prompt_metadata_path, "r", encoding="utf-8") as f:
|
| 75 |
+
prompt_meta_list = json.load(f)
|
| 76 |
+
if not prompt_meta_list:
|
| 77 |
+
raise ValueError("Prompt metadata is empty. Please run preprocess on prompt audio first.")
|
| 78 |
+
prompt_meta = prompt_meta_list[0] # load the first segment as the prompt
|
| 79 |
+
with open(args.target_metadata_path, "r", encoding="utf-8") as f:
|
| 80 |
+
target_meta_list = json.load(f)
|
| 81 |
+
infer_prompt_data = data_processor.process(prompt_meta, args.prompt_wav_path)
|
| 82 |
+
|
| 83 |
+
assert len(target_meta_list) > 0, "No target segments found in the target metadata."
|
| 84 |
+
generated_len = int(target_meta_list[-1]["time"][1] / 1000 * config.audio.sample_rate)
|
| 85 |
+
generated_merged = np.zeros(generated_len, dtype=np.float32)
|
| 86 |
+
|
| 87 |
+
for idx, target_meta in enumerate(
|
| 88 |
+
tqdm(target_meta_list, total=len(target_meta_list), desc="Inferring segments"),
|
| 89 |
+
):
|
| 90 |
+
start_sample_idx = int(target_meta["time"][0] / 1000 * config.audio.sample_rate)
|
| 91 |
+
end_sample_idx = int(target_meta["time"][1] / 1000 * config.audio.sample_rate)
|
| 92 |
+
infer_target_data = data_processor.process(target_meta, None)
|
| 93 |
+
|
| 94 |
+
infer_data = {
|
| 95 |
+
"prompt": infer_prompt_data,
|
| 96 |
+
"target": infer_target_data,
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
with torch.no_grad():
|
| 100 |
+
generated_audio = model.infer(
|
| 101 |
+
infer_data,
|
| 102 |
+
auto_shift=args.auto_shift,
|
| 103 |
+
pitch_shift=args.pitch_shift,
|
| 104 |
+
n_steps=config.infer.n_steps,
|
| 105 |
+
cfg=config.infer.cfg,
|
| 106 |
+
control=args.control,
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
generated_audio = generated_audio.squeeze().cpu().numpy()
|
| 110 |
+
generated_merged[start_sample_idx : start_sample_idx + generated_audio.shape[0]] = generated_audio
|
| 111 |
+
|
| 112 |
+
merged_path = os.path.join(args.save_dir, "generated.wav")
|
| 113 |
+
sf.write(merged_path, generated_merged, 24000)
|
| 114 |
+
print(f"Generated audio saved to {merged_path}")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def main(args, config):
|
| 118 |
+
model = build_model(
|
| 119 |
+
model_path=args.model_path,
|
| 120 |
+
config=config,
|
| 121 |
+
device=args.device,
|
| 122 |
+
)
|
| 123 |
+
process(args, config, model)
|
| 124 |
+
|
| 125 |
+
if __name__ == "__main__":
|
| 126 |
+
parser = argparse.ArgumentParser()
|
| 127 |
+
parser.add_argument("--device", type=str, default="cuda")
|
| 128 |
+
parser.add_argument("--model_path", type=str, default='pretrained_models/soulx-singer/model.pt')
|
| 129 |
+
parser.add_argument("--config", type=str, default='soulxsinger/config/soulxsinger.yaml')
|
| 130 |
+
parser.add_argument("--prompt_wav_path", type=str, default='example/audio/zh_prompt.wav')
|
| 131 |
+
parser.add_argument("--prompt_metadata_path", type=str, default='example/metadata/zh_prompt.json')
|
| 132 |
+
parser.add_argument("--target_metadata_path", type=str, default='example/metadata/zh_target.json')
|
| 133 |
+
parser.add_argument("--phoneset_path", type=str, default='soulxsinger/utils/phoneme/phone_set.json')
|
| 134 |
+
parser.add_argument("--save_dir", type=str, default='outputs')
|
| 135 |
+
parser.add_argument("--auto_shift", action="store_true")
|
| 136 |
+
parser.add_argument("--pitch_shift", type=int, default=0)
|
| 137 |
+
parser.add_argument(
|
| 138 |
+
"--control",
|
| 139 |
+
type=str,
|
| 140 |
+
default="melody",
|
| 141 |
+
choices=["melody", "score"],
|
| 142 |
+
help="Control mode: melody or score only",
|
| 143 |
+
)
|
| 144 |
+
args = parser.parse_args()
|
| 145 |
+
|
| 146 |
+
config = load_config(args.config)
|
| 147 |
+
main(args, config)
|
deploy_to_hf.sh
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# 快速部署脚本:将 SoulX-Singer 部署到 Hugging Face Space
|
| 3 |
+
# 使用方法: ./deploy_to_hf.sh YOUR_USERNAME YOUR_SPACE_NAME
|
| 4 |
+
|
| 5 |
+
set -e
|
| 6 |
+
|
| 7 |
+
if [ $# -lt 2 ]; then
|
| 8 |
+
echo "用法: $0 <YOUR_USERNAME> <YOUR_SPACE_NAME>"
|
| 9 |
+
echo "示例: $0 myusername soulx-singer-demo"
|
| 10 |
+
exit 1
|
| 11 |
+
fi
|
| 12 |
+
|
| 13 |
+
USERNAME=$1
|
| 14 |
+
SPACE_NAME=$2
|
| 15 |
+
SPACE_REPO="https://huggingface.co/spaces/${USERNAME}/${SPACE_NAME}"
|
| 16 |
+
|
| 17 |
+
echo "🚀 开始部署到 Hugging Face Space..."
|
| 18 |
+
echo "Space: ${USERNAME}/${SPACE_NAME}"
|
| 19 |
+
echo ""
|
| 20 |
+
|
| 21 |
+
# 检查是否已安装 huggingface_hub
|
| 22 |
+
if ! command -v huggingface-cli &> /dev/null; then
|
| 23 |
+
echo "⚠️ 未检测到 huggingface-cli,正在安装..."
|
| 24 |
+
pip install -U huggingface_hub
|
| 25 |
+
fi
|
| 26 |
+
|
| 27 |
+
# 检查是否已登录
|
| 28 |
+
if ! huggingface-cli whoami &> /dev/null; then
|
| 29 |
+
echo "🔐 请先登录 Hugging Face..."
|
| 30 |
+
huggingface-cli login
|
| 31 |
+
fi
|
| 32 |
+
|
| 33 |
+
# 创建 Space(如果不存在)
|
| 34 |
+
echo "📦 检查 Space 是否存在..."
|
| 35 |
+
if ! huggingface-cli repo info "${USERNAME}/${SPACE_NAME}" --repo-type space &> /dev/null; then
|
| 36 |
+
echo "✨ 创建新的 Space..."
|
| 37 |
+
huggingface-cli repo create "${SPACE_NAME}" --type space --sdk gradio
|
| 38 |
+
else
|
| 39 |
+
echo "✅ Space 已存在"
|
| 40 |
+
fi
|
| 41 |
+
|
| 42 |
+
# 检查是否已初始化 Git
|
| 43 |
+
if [ ! -d ".git" ]; then
|
| 44 |
+
echo "📝 初始化 Git 仓库..."
|
| 45 |
+
git init
|
| 46 |
+
git add .
|
| 47 |
+
git commit -m "Initial commit for HF Space deployment" || echo "⚠️ 没有新文件需要提交"
|
| 48 |
+
fi
|
| 49 |
+
|
| 50 |
+
# 检查远程仓库
|
| 51 |
+
if git remote | grep -q "^origin$"; then
|
| 52 |
+
echo "🔄 更新远程仓库地址..."
|
| 53 |
+
git remote set-url origin "${SPACE_REPO}"
|
| 54 |
+
else
|
| 55 |
+
echo "➕ 添加远程仓库..."
|
| 56 |
+
git remote add origin "${SPACE_REPO}"
|
| 57 |
+
fi
|
| 58 |
+
|
| 59 |
+
# 推送代码
|
| 60 |
+
echo "📤 推送代码到 Hugging Face..."
|
| 61 |
+
git push -u origin main || git push -u origin master
|
| 62 |
+
|
| 63 |
+
echo ""
|
| 64 |
+
echo "✅ 部署完成!"
|
| 65 |
+
echo "🌐 Space 地址: ${SPACE_REPO}"
|
| 66 |
+
echo ""
|
| 67 |
+
echo "💡 提示:"
|
| 68 |
+
echo " - Space 会自动开始构建,请查看 Logs 标签页"
|
| 69 |
+
echo " - 首次运行会下载模型,可能需要 5-15 分钟"
|
| 70 |
+
echo " - 建议在 Space Settings 中选择 GPU T4 Small 硬件"
|
example/audio/en_prompt.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"index": "vocal_5220_10280",
|
| 4 |
+
"language": "English",
|
| 5 |
+
"time": [
|
| 6 |
+
5220,
|
| 7 |
+
10280
|
| 8 |
+
],
|
| 9 |
+
"duration": "0.24 0.36 0.30 0.78 0.24 0.56 0.19 0.53 0.36 0.20 0.32 0.57 0.19 0.22",
|
| 10 |
+
"text": "<SP> Ooh Ooh <SP> I wish nothing nothing more more the best best <SP>",
|
| 11 |
+
"phoneme": "<SP> en_UW1 en_UW1 <SP> en_AY1 en_W-IH1-SH en_N-AH1-TH-IH0-NG en_N-AH1-TH-IH0-NG en_M-AO1-R en_M-AO1-R en_DH-AH0 en_B-EH1-S-T en_B-EH1-S-T <SP>",
|
| 12 |
+
"note_pitch": "0 63 65 0 65 67 68 62 62 64 67 67 65 0",
|
| 13 |
+
"note_type": "1 2 3 1 2 2 2 3 2 3 2 2 3 1",
|
| 14 |
+
"f0": "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 345.2 343.1 341.6 339.8 337.8 331.9 319.5 312.1 310.8 312.6 315.1 316.1 315.3 314.6 315.3 317.9 322.0 329.6 337.5 344.7 347.5 347.2 344.3 339.5 338.2 341.7 342.8 342.2 340.7 343.0 342.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 347.0 345.3 348.7 350.2 350.9 350.3 344.7 340.3 338.3 338.0 342.8 347.4 348.3 346.7 343.4 339.5 340.5 345.2 350.4 357.7 367.3 376.6 385.9 392.6 393.6 389.9 384.7 381.8 382.0 383.0 380.6 373.5 367.9 377.0 385.4 391.4 393.8 395.6 396.1 397.2 399.8 406.0 413.5 416.1 416.0 414.4 413.5 412.9 415.5 418.9 417.5 408.8 389.2 373.9 0.0 0.0 0.0 288.5 286.0 284.2 285.6 288.9 291.3 293.5 294.5 295.2 297.8 299.5 301.0 303.0 305.9 306.8 306.0 304.4 301.8 301.0 300.8 301.8 310.2 309.8 308.2 305.9 303.6 301.5 299.3 298.5 300.0 302.1 303.5 303.6 302.2 299.7 297.5 296.3 296.4 296.8 298.6 302.6 311.8 322.0 333.8 349.0 368.8 393.3 407.1 410.7 407.0 402.3 401.2 401.7 403.9 405.7 403.5 396.8 387.4 378.6 377.8 381.4 384.0 384.7 383.5 382.5 380.8 377.3 378.4 383.5 390.0 392.7 390.5 387.6 385.3 382.7 381.0 382.8 383.9 382.2 379.6 379.3 380.2 383.1 386.0 386.5 385.4 384.3 383.7 384.4 386.2 388.2 388.5 385.0 378.6 360.4 333.7 328.2 332.4 340.2 348.9 339.6 334.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0"
|
| 15 |
+
}
|
| 16 |
+
]
|
example/audio/en_prompt.mp3
ADDED
|
Binary file (86.8 kB). View file
|
|
|
example/audio/en_target.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"index": "vocal_0_6900",
|
| 4 |
+
"language": "English",
|
| 5 |
+
"time": [
|
| 6 |
+
0,
|
| 7 |
+
6900
|
| 8 |
+
],
|
| 9 |
+
"duration": "0.16 0.24 0.32 0.15 0.17 0.24 0.15 0.44 0.29 0.32 0.24 0.32 0.22 0.18 0.24 0.25 1.01 0.26 0.48 0.29 0.79 0.14",
|
| 10 |
+
"text": "<SP> Who says you're you're not pretty <SP> pretty <SP> Who says you're you're not beautiful beautiful <SP> Who says says <SP>",
|
| 11 |
+
"phoneme": "<SP> en_HH-UW1 en_S-EH1-Z en_Y-UH1-R en_Y-UH1-R en_N-AA1-T en_P-R-IH1-T-IY0 <SP> en_P-R-IH1-T-IY0 <SP> en_HH-UW1 en_S-EH1-Z en_Y-UH1-R en_Y-UH1-R en_N-AA1-T en_B-Y-UW1-T-AH0-F-AH0-L en_B-Y-UW1-T-AH0-F-AH0-L <SP> en_HH-UW1 en_S-EH1-Z en_S-EH1-Z <SP>",
|
| 12 |
+
"note_pitch": "0 68 67 65 63 63 66 67 70 66 68 67 65 63 63 67 65 63 65 61 58 0",
|
| 13 |
+
"note_type": "1 2 2 2 3 2 2 1 3 1 2 2 2 3 2 2 3 1 2 2 3 1",
|
| 14 |
+
"f0": "0.0 0.0 382.7 387.7 385.9 379.8 376.0 380.9 390.1 403.2 415.3 423.6 421.6 402.6 385.2 381.1 0.0 0.0 425.8 419.0 409.6 397.8 392.2 389.0 388.5 391.4 389.1 381.4 375.9 0.0 0.0 0.0 0.0 359.0 354.7 353.8 353.7 354.7 353.1 351.1 350.4 349.0 348.9 346.3 337.4 328.0 312.8 303.1 298.4 296.0 298.9 302.0 306.3 307.9 307.3 307.5 307.3 302.9 301.8 0.0 0.0 0.0 0.0 0.0 343.7 364.3 375.9 368.5 358.1 359.1 365.9 378.4 393.1 406.0 412.5 410.9 407.0 404.1 403.5 403.4 401.5 399.4 397.7 395.4 394.4 394.8 395.5 396.5 397.5 400.8 407.9 415.1 417.8 453.1 472.2 481.0 482.3 481.9 480.8 478.7 477.4 476.8 474.8 467.5 446.0 390.4 382.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 374.3 375.5 370.9 370.7 373.1 378.3 392.2 407.6 418.5 423.9 423.2 415.9 395.4 0.0 0.0 0.0 421.5 416.2 405.3 391.0 383.1 380.8 383.0 388.3 388.8 378.3 371.7 0.0 0.0 0.0 371.4 365.1 362.7 358.5 353.0 352.0 353.5 356.1 356.4 353.6 348.3 341.1 330.6 317.7 303.8 293.3 296.5 297.7 301.4 305.3 308.8 308.8 308.2 308.2 306.3 305.6 285.0 269.8 265.6 280.0 304.4 331.2 351.0 357.9 364.2 370.6 381.1 392.9 399.0 399.1 395.0 389.5 379.9 363.0 338.9 318.5 305.6 300.3 299.6 296.3 292.2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 309.6 322.1 329.8 331.2 332.1 332.6 332.4 335.4 340.7 345.0 347.2 346.2 342.6 339.6 337.4 338.3 340.9 342.6 344.0 344.6 344.0 344.2 343.6 341.9 338.8 336.7 337.6 341.1 347.0 350.4 343.0 326.6 330.8 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 310.0 315.7 317.7 317.4 316.5 314.4 314.1 322.5 336.3 350.0 354.2 352.9 350.5 348.4 347.1 347.3 348.4 349.8 349.8 350.4 350.3 323.9 324.7 0.0 0.0 0.0 0.0 0.0 0.0 0.0 297.3 289.8 279.5 275.1 276.1 276.1 274.9 275.4 274.6 271.8 268.6 264.0 258.3 251.7 244.3 239.9 236.1 233.7 234.0 236.0 237.3 236.9 235.2 233.5 231.7 231.0 232.1 233.6 235.4 236.2 236.7 235.8 234.1 232.2 231.3 232.6 233.5 235.2 236.0 232.3 228.8 229.6 233.8 241.3 239.4 226.3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0"
|
| 15 |
+
}
|
| 16 |
+
]
|
example/audio/en_target.mp3
ADDED
|
Binary file (66.9 kB). View file
|
|
|
example/audio/music.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"index": "vocal_240_51240",
|
| 4 |
+
"language": "Mandarin",
|
| 5 |
+
"time": [
|
| 6 |
+
240,
|
| 7 |
+
51240
|
| 8 |
+
],
|
| 9 |
+
"duration": "0.21 0.18 0.34 0.38 0.26 0.22 0.50 0.20 0.33 0.13 0.60 0.22 0.38 0.52 0.30 0.12 1.82 0.98 0.20 0.26 0.38 0.36 0.58 0.54 0.26 0.50 1.46 3.03 0.24 0.24 0.29 0.25 0.20 0.14 0.80 0.22 0.54 0.30 0.24 0.16 0.58 0.28 0.38 1.73 0.79 0.24 0.30 0.34 0.30 0.30 0.34 0.21 0.21 0.36 0.34 0.23 1.85 1.90 0.23 0.39 0.68 0.50 0.31 0.43 0.76 0.38 2.00 1.87 0.68 0.72 0.56 0.62 0.80 0.40 0.42 1.68 1.79 0.70 0.66 0.54 0.24 0.48 0.68 0.40 2.34 0.14",
|
| 10 |
+
"text": "<SP> 只 是 因 为 为 在 人 群 群 中 多 看 了 你 一 眼 <SP> 再 也 没 能 忘 掉 你 容 颜 <SP> 梦 想 着 着 偶 偶 然 然 有 <SP> 一 一 天 再 相 见 <SP> 从 此 我 开 始 始 孤 孤 单 思 念 念 <SP> 想 想 你 时 你 你 在 天 边 <SP> 想 你 时 你 在 眼 前 前 <SP> 想 你 时 你 你 在 脑 海 <SP>",
|
| 11 |
+
"phoneme": "<SP> zh_zhi3 zh_shi4 zh_yin1 zh_wei4 zh_wei2 zh_zai4 zh_ren2 zh_qun2 zh_qun2 zh_zhong1 zh_duo1 zh_kan4 zh_le5 zh_ni3 zh_yi1 zh_yan3 <SP> zh_zai4 zh_ye3 zh_mei2 zh_neng2 zh_wang4 zh_diao4 zh_ni3 zh_rong2 zh_yan2 <SP> zh_meng4 zh_xiang3 zh_zhe5 zh_zhe5 zh_ou3 zh_ou3 zh_ran2 zh_ran2 zh_you3 <SP> zh_yi1 zh_yi1 zh_tian1 zh_zai4 zh_xiang1 zh_jian4 <SP> zh_cong2 zh_ci3 zh_wo3 zh_kai1 zh_shi3 zh_shi3 zh_gu1 zh_gu1 zh_dan1 zh_si1 zh_nian4 zh_nian4 <SP> zh_xiang3 zh_xiang3 zh_ni3 zh_shi2 zh_ni3 zh_ni3 zh_zai4 zh_tian1 zh_bian1 <SP> zh_xiang3 zh_ni3 zh_shi2 zh_ni3 zh_zai4 zh_yan3 zh_qian2 zh_qian2 <SP> zh_xiang3 zh_ni3 zh_shi2 zh_ni3 zh_ni3 zh_zai4 zh_nao3 zh_hai3 <SP>",
|
| 12 |
+
"note_pitch": "0 64 64 64 66 68 66 66 66 64 64 64 64 66 64 60 61 0 63 63 63 64 66 63 61 59 56 0 68 66 66 68 68 66 66 64 64 0 64 66 61 61 61 64 0 62 63 63 64 64 66 65 66 61 58 59 56 0 69 71 66 68 68 71 66 64 61 0 66 61 68 66 64 63 61 59 0 71 66 68 68 71 66 64 61 0",
|
| 13 |
+
"note_type": "1 2 2 2 2 3 2 2 2 3 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 1 2 2 2 3 2 3 2 3 2 1 2 3 2 2 2 2 1 2 2 2 2 2 3 2 3 2 2 2 3 1 2 3 2 2 2 3 2 2 2 1 2 2 2 2 2 2 2 3 1 2 2 2 2 3 2 2 2 1",
|
| 14 |
+
"f0": "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 318.9 332.2 331.2 323.9 312.3 0.0 0.0 0.0 0.0 0.0 0.0 340.9 342.3 340.2 337.4 333.5 329.3 327.8 328.3 329.1 330.5 331.9 331.5 329.2 327.6 327.4 327.7 329.6 330.5 329.5 328.1 328.5 330.1 330.4 329.5 328.9 331.0 333.6 332.8 331.7 330.2 330.6 330.9 329.5 327.6 332.4 343.5 363.9 368.7 369.3 368.2 366.7 371.3 382.0 392.9 402.3 409.8 413.8 416.3 416.4 415.4 414.1 413.4 415.3 417.4 418.3 413.6 416.0 0.0 0.0 0.0 0.0 368.2 362.5 361.3 362.0 362.9 364.3 366.4 367.5 368.4 368.9 368.1 367.6 369.7 371.1 371.1 369.5 368.0 367.3 366.6 367.1 369.5 372.0 371.3 368.2 368.1 369.6 369.9 373.4 376.5 360.2 285.6 0.0 0.0 0.0 0.0 0.0 0.0 0.0 384.0 376.7 372.0 371.7 373.9 376.3 375.5 361.2 333.0 317.8 319.0 329.1 333.5 333.2 333.2 333.3 328.0 314.5 0.0 0.0 0.0 320.6 326.5 332.9 335.7 334.1 329.9 327.4 326.3 326.6 329.1 330.5 330.0 328.0 326.7 328.6 331.1 332.4 332.1 332.8 333.2 333.1 331.8 328.6 318.9 293.2 323.6 327.5 330.0 332.8 331.1 319.4 272.2 0.0 0.0 0.0 0.0 278.5 294.5 310.0 318.5 322.7 327.4 334.1 337.2 336.3 328.5 321.0 324.2 341.5 362.2 374.7 375.8 370.5 366.0 365.2 368.7 371.7 374.4 373.0 368.7 370.4 374.8 375.1 372.5 368.7 363.0 357.1 359.0 366.8 377.1 379.5 371.5 359.9 351.0 358.4 371.7 377.9 375.8 367.8 359.7 363.6 375.3 379.9 377.5 372.9 359.8 345.2 337.1 334.5 333.4 332.8 329.6 319.2 292.8 261.7 253.4 262.8 273.0 278.1 279.3 278.6 277.6 277.9 278.3 278.0 277.3 276.7 275.6 275.3 276.1 277.8 278.2 277.7 277.6 277.2 276.4 275.5 273.9 271.8 270.7 274.7 282.7 285.5 281.4 273.1 264.4 262.2 268.9 278.4 284.3 283.6 277.3 268.5 262.2 263.2 270.7 278.8 285.0 287.3 282.0 272.1 265.2 266.0 272.1 278.2 283.8 285.1 281.4 272.6 267.2 269.1 276.1 282.7 289.3 290.5 285.6 273.9 267.5 267.4 271.8 277.7 282.6 283.1 277.8 269.8 265.6 268.6 273.8 281.1 286.2 287.2 282.3 270.4 264.5 263.3 263.0 268.5 268.3 268.7 270.9 277.2 277.4 284.3 285.7 282.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 311.5 306.4 306.0 306.8 305.8 305.0 307.2 310.8 312.4 313.5 314.4 314.6 313.7 312.2 311.1 311.1 312.8 314.6 312.7 306.5 299.7 302.7 308.7 315.2 314.2 310.9 309.4 309.8 310.6 310.7 310.8 311.4 312.2 313.0 312.5 312.2 312.2 311.9 312.0 312.2 311.4 306.1 303.3 309.3 315.0 325.3 324.4 323.9 325.0 326.4 327.7 327.7 328.1 328.6 327.8 328.3 328.3 327.6 328.6 336.1 347.2 359.4 371.0 374.9 371.3 366.5 369.5 375.5 376.2 370.2 365.7 367.3 372.9 376.9 375.3 368.4 360.1 358.0 365.4 380.8 383.5 380.4 374.4 367.6 365.9 372.0 376.3 377.9 375.7 372.6 368.6 361.2 353.7 0.0 309.4 302.3 299.2 300.6 305.7 308.5 308.6 309.0 310.9 310.9 309.5 308.6 309.0 311.3 313.0 314.4 314.5 313.0 312.1 311.2 308.0 299.5 295.8 295.0 285.7 278.0 280.1 281.2 279.9 278.3 278.3 279.5 279.7 272.2 259.1 0.0 0.0 0.0 240.7 236.0 233.8 233.5 234.6 235.7 235.6 236.2 239.8 245.4 248.8 250.1 250.4 249.9 249.1 247.2 241.7 231.7 216.0 197.6 192.1 197.2 204.3 206.5 205.4 201.7 200.8 203.5 208.2 210.9 210.3 207.1 203.5 202.3 202.1 202.8 205.1 207.3 208.1 207.3 205.4 202.1 199.3 200.5 206.1 212.9 212.9 206.5 197.1 191.1 191.2 196.8 203.6 207.8 208.7 206.4 201.7 197.3 197.4 200.7 206.5 209.9 209.9 206.9 203.3 201.2 203.4 207.7 210.5 208.9 207.6 206.5 204.2 202.0 203.6 205.7 210.8 213.1 214.3 210.6 204.1 199.7 202.3 211.9 217.7 215.1 215.0 215.3 213.3 0.0 0.0 216.3 0.0 0.0 195.2 209.0 205.3 201.3 196.8 195.7 195.8 195.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 431.9 425.8 421.0 417.7 419.3 422.7 423.4 421.1 417.9 411.9 408.6 411.5 414.9 419.6 426.7 424.2 0.0 0.0 0.0 280.8 279.4 369.7 371.8 370.6 369.8 370.5 370.5 371.4 374.6 374.0 361.0 332.7 328.0 0.0 359.0 365.0 371.1 373.8 371.9 369.7 369.6 371.0 378.3 390.4 403.4 414.1 417.9 417.5 417.0 416.6 415.7 414.9 414.5 413.8 413.5 413.6 413.8 415.2 416.9 417.3 415.4 415.0 414.1 411.4 409.2 403.5 392.7 375.6 367.5 365.8 368.8 370.7 370.8 369.0 367.0 366.0 366.2 367.8 369.2 368.8 367.2 366.7 367.3 367.3 368.8 368.5 366.8 365.3 364.0 363.0 365.0 367.3 367.7 365.5 364.0 364.9 367.4 370.3 371.3 369.9 366.5 364.1 368.3 385.8 406.7 409.4 399.7 376.0 357.2 359.9 372.4 377.5 366.9 345.1 320.6 315.8 321.8 329.0 329.9 326.4 324.5 324.8 324.9 325.3 325.1 324.7 325.9 326.8 326.6 326.6 326.4 325.4 323.2 322.5 323.9 326.2 328.6 329.9 329.5 329.0 328.2 327.7 327.9 329.6 331.2 331.4 334.2 340.1 335.5 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 309.0 320.1 329.6 334.4 334.4 334.5 338.2 346.0 355.1 361.4 366.1 369.5 372.7 373.3 372.9 375.2 375.8 374.7 367.6 360.1 0.0 0.0 0.0 0.0 0.0 0.0 283.6 280.4 277.2 274.3 271.8 271.0 273.3 277.2 278.8 278.1 277.2 277.4 278.8 279.3 280.2 281.4 279.2 274.9 268.3 254.1 239.9 0.0 0.0 267.7 267.7 270.1 271.8 273.1 276.3 278.3 269.1 256.1 0.0 0.0 0.0 0.0 286.0 281.1 275.7 272.7 271.7 271.7 272.5 274.0 275.5 277.9 281.1 287.1 314.0 342.2 363.0 376.8 375.4 365.8 351.8 323.3 0.0 0.0 0.0 0.0 299.6 322.7 336.5 336.9 335.4 332.9 330.0 326.6 323.4 322.3 324.4 326.7 328.5 328.5 326.4 324.0 322.5 323.2 325.5 327.7 328.6 328.1 325.4 320.9 316.5 316.1 320.8 327.7 333.7 333.2 325.6 315.5 306.9 305.5 315.1 331.0 343.3 341.6 331.4 319.5 308.3 307.6 318.1 329.5 338.2 342.0 337.4 329.4 321.0 316.6 319.8 330.8 340.6 344.5 342.3 331.4 319.3 315.7 322.9 329.3 336.5 344.0 341.1 328.3 318.7 320.0 328.2 333.5 337.9 339.7 338.5 327.3 323.3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 285.3 284.6 284.3 288.2 295.1 305.6 312.2 315.5 287.8 0.0 0.0 0.0 0.0 0.0 309.8 312.4 313.8 312.8 310.8 310.7 312.5 313.2 310.7 305.7 305.9 308.1 309.9 311.3 311.0 309.8 309.7 310.6 311.1 311.1 311.0 311.9 313.7 316.4 319.5 311.6 281.3 283.4 306.6 0.0 0.0 0.0 334.4 326.7 322.6 321.6 321.7 324.5 332.0 336.7 333.1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 332.5 331.7 332.7 334.8 334.7 335.9 339.7 345.6 353.9 362.3 370.2 374.0 373.8 371.4 370.0 368.2 367.0 368.0 370.4 371.2 371.3 370.0 360.3 0.0 0.0 0.0 319.1 326.8 331.5 332.0 331.6 327.2 328.2 334.1 343.1 350.7 358.0 363.5 368.4 370.8 371.8 371.1 370.2 368.5 368.0 371.1 375.5 376.3 369.5 0.0 0.0 0.0 280.4 267.2 263.6 265.8 271.2 275.0 276.4 276.9 277.1 280.9 287.5 284.7 280.7 0.0 0.0 0.0 0.0 0.0 0.0 231.6 226.0 222.9 224.1 226.2 228.5 234.3 240.9 246.6 249.8 249.6 246.3 245.2 245.8 247.2 248.3 249.3 249.0 245.9 242.7 235.8 226.3 217.8 213.1 209.3 208.0 207.0 205.9 205.8 203.4 201.5 205.3 207.5 209.1 210.4 208.9 203.2 199.2 199.3 201.2 205.6 206.3 204.0 202.7 202.2 203.3 206.1 205.6 201.9 198.8 195.5 195.6 198.7 204.2 214.5 218.3 214.3 207.0 199.1 192.4 189.9 193.6 201.0 210.6 212.4 209.2 202.0 195.2 190.0 190.4 196.0 204.2 209.2 208.9 203.0 195.8 190.7 189.8 194.6 200.8 208.7 213.5 210.5 202.0 193.9 187.7 187.8 193.1 199.8 202.8 204.1 203.8 200.4 197.0 193.9 191.6 192.5 198.4 204.8 203.9 203.5 201.3 200.2 198.4 198.5 201.0 204.0 204.6 205.4 202.0 199.2 194.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 426.6 432.4 436.6 437.2 434.6 433.0 434.0 453.1 473.8 492.8 503.2 501.8 491.3 479.6 478.1 489.7 503.6 508.3 503.2 489.5 482.3 484.2 495.7 505.4 506.6 501.6 495.4 492.7 497.1 500.4 496.1 490.1 480.4 453.8 412.6 373.6 357.7 354.5 356.8 361.0 365.0 367.8 369.5 370.1 371.1 371.3 370.8 370.1 369.7 369.0 369.5 370.4 371.1 370.9 370.2 370.8 371.3 372.7 380.2 386.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 435.0 427.8 417.0 411.6 412.6 411.2 410.7 412.5 413.7 414.9 414.8 415.7 415.4 414.0 413.8 414.6 415.2 416.3 415.2 413.6 413.4 414.9 416.8 416.4 414.7 413.8 413.2 414.7 417.5 420.7 428.2 443.1 459.8 481.7 502.3 510.8 509.1 497.8 490.6 493.5 500.8 508.3 511.8 504.6 493.5 487.3 492.3 507.2 518.8 514.8 496.4 489.6 490.4 491.8 493.3 0.0 0.0 0.0 0.0 0.0 391.9 368.8 360.2 358.3 360.0 362.0 361.8 360.9 360.3 361.8 365.7 368.7 369.3 367.7 364.7 365.6 368.4 370.4 371.7 369.0 366.8 367.5 367.9 369.4 374.6 384.6 396.8 388.2 382.8 0.0 0.0 0.0 0.0 0.0 352.6 339.5 328.2 326.0 326.5 326.9 328.2 329.4 329.4 329.3 337.3 349.1 350.9 343.6 333.4 326.5 327.1 331.5 333.4 326.8 318.4 0.0 0.0 0.0 282.0 283.6 281.6 279.5 273.2 267.9 268.8 273.2 277.5 279.8 278.7 276.8 273.3 268.7 265.0 266.1 272.4 282.3 284.9 279.4 268.5 255.3 253.0 259.9 270.2 281.5 284.3 279.7 270.1 260.2 256.4 260.1 266.6 273.0 279.6 283.6 282.3 272.3 262.1 255.8 259.8 269.6 274.9 278.7 278.4 271.5 261.7 255.8 256.8 263.9 272.1 279.7 278.4 269.9 259.0 255.9 262.1 269.2 274.5 279.6 282.1 280.7 274.3 269.9 270.9 270.7 273.8 276.9 274.9 271.1 266.5 266.7 269.4 276.9 282.9 282.0 279.5 274.9 271.2 267.9 263.1 270.2 281.3 285.5 283.9 280.6 271.1 262.8 263.9 263.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 362.9 364.0 366.8 366.6 363.6 359.8 357.7 359.9 367.6 371.6 365.4 359.3 358.6 359.5 365.5 375.7 378.1 371.5 359.5 350.8 358.6 374.7 381.7 376.8 363.6 353.3 356.9 367.8 378.3 380.8 375.7 372.3 372.6 374.7 372.9 365.2 346.4 314.8 285.6 277.3 275.4 275.3 277.5 279.2 279.2 278.8 277.6 276.4 276.5 277.6 278.8 279.9 281.8 282.7 281.0 279.2 278.6 279.4 279.9 279.2 278.7 281.7 285.8 289.9 288.7 284.4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 428.0 417.9 410.0 409.5 411.2 414.0 417.5 417.9 418.1 419.3 419.0 417.8 417.1 417.2 416.6 415.8 415.0 415.1 414.5 410.9 401.5 388.2 371.5 359.2 363.0 371.7 375.7 370.6 359.6 353.4 358.7 367.0 374.5 378.2 373.7 362.8 358.0 361.6 368.5 373.7 377.7 375.8 368.8 361.7 360.5 366.7 372.5 375.6 380.1 384.9 383.5 376.6 0.0 0.0 0.0 0.0 0.0 331.7 320.0 316.8 320.7 325.1 327.0 327.4 326.6 326.4 325.5 325.1 326.2 327.6 328.5 328.4 327.7 328.7 329.9 329.8 329.7 327.7 326.5 327.8 328.9 329.3 329.8 329.2 328.8 329.4 330.4 331.0 331.0 330.5 329.2 328.2 328.3 328.7 329.4 327.5 322.9 314.8 299.3 290.8 292.3 299.0 305.4 309.2 312.1 313.9 316.8 320.4 327.6 333.1 338.0 341.2 342.1 338.4 331.0 0.0 0.0 0.0 0.0 0.0 0.0 311.7 292.5 281.8 276.1 275.1 276.6 277.6 277.4 274.0 264.8 249.6 239.1 238.8 243.0 245.7 245.0 241.9 237.9 234.6 237.1 242.7 250.4 252.5 248.3 240.7 231.6 228.3 234.0 240.6 247.0 250.6 248.5 242.1 233.6 226.8 227.2 234.0 240.7 246.9 248.9 244.4 237.6 230.8 228.8 236.4 246.9 250.2 249.2 245.4 239.9 232.5 224.7 226.5 238.5 252.0 257.4 255.9 248.4 237.4 230.5 231.6 238.8 247.6 252.6 253.9 253.0 248.9 241.8 236.4 234.1 236.5 246.2 257.7 255.4 248.2 238.5 234.8 237.8 244.4 250.5 253.8 251.7 246.7 240.0 238.3 244.8 251.6 256.4 256.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 484.0 484.5 487.5 490.6 487.2 479.7 478.9 483.8 494.1 505.7 510.3 501.1 485.4 472.4 472.3 495.6 515.8 517.1 503.8 486.5 475.1 479.9 491.4 507.9 509.9 505.0 496.2 485.7 484.5 490.8 495.4 497.7 496.6 489.9 469.1 424.3 387.2 373.0 367.3 364.4 364.6 366.2 366.6 368.8 371.9 373.1 371.4 368.9 367.1 367.0 367.7 368.8 370.1 369.8 369.2 369.8 370.0 369.0 369.0 371.0 375.2 385.3 391.6 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 425.2 418.5 409.4 409.5 408.1 406.3 408.4 410.3 412.8 413.8 412.6 411.2 410.8 410.4 410.2 410.5 412.2 413.6 413.4 412.6 412.6 413.4 415.6 417.9 418.5 417.0 415.0 415.0 416.8 421.5 437.9 455.0 472.6 493.0 502.2 496.7 484.6 480.6 487.2 498.6 506.3 506.2 499.4 489.7 489.5 496.2 503.8 507.2 501.8 491.1 489.9 501.1 520.5 533.7 525.8 0.0 0.0 0.0 0.0 0.0 421.8 365.2 351.0 352.0 358.0 362.6 365.6 365.4 363.2 362.5 362.2 362.4 364.1 365.2 364.9 363.5 363.4 366.3 368.5 370.0 369.7 368.5 366.9 365.7 366.3 368.6 370.8 372.2 370.2 373.1 377.0 376.5 371.0 346.8 326.1 325.2 328.4 331.4 334.0 332.2 327.2 324.4 331.3 349.7 363.2 362.0 346.1 327.4 321.8 329.8 340.7 348.4 349.9 346.1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 269.5 275.7 282.6 282.9 278.8 275.2 274.3 273.5 273.9 275.3 276.4 277.3 276.1 272.3 268.0 268.3 272.5 279.1 281.7 280.1 271.3 261.0 257.3 262.7 273.2 280.8 283.3 279.3 269.1 258.3 258.0 266.3 278.2 286.7 287.8 283.4 274.9 264.7 257.7 260.0 272.0 286.3 294.7 289.8 274.5 263.8 263.4 268.8 277.3 284.1 286.5 285.2 281.7 272.7 264.3 260.3 267.7 281.7 289.2 289.0 281.2 266.2 256.0 254.4 261.5 276.4 286.8 288.8 286.7 273.8 260.9 260.8 270.2 283.4 291.3 292.8 286.1 273.8 265.2 264.0 271.3 283.8 290.3 289.7 277.7 266.9 260.5 263.0 267.8 281.5 286.7 286.1 285.5 280.1 279.0 283.0 284.1 284.5 285.5 282.2 0.0 0.0 280.0 276.4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0"
|
| 15 |
+
}
|
| 16 |
+
]
|
example/audio/music.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:04b35a7b9d03adc494c304af5c4413aa33a02a54a7110016d6e3b559843d90de
|
| 3 |
+
size 1243961
|
example/audio/yue_target.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"index": "vocal_420_14370",
|
| 4 |
+
"language": "Cantonese",
|
| 5 |
+
"time": [
|
| 6 |
+
420,
|
| 7 |
+
14370
|
| 8 |
+
],
|
| 9 |
+
"duration": "0.31 0.26 0.28 0.26 0.40 0.20 0.42 0.24 0.36 0.24 0.32 0.26 0.94 0.32 0.24 0.30 0.34 0.22 0.34 0.90 0.22 0.36 0.32 0.30 0.22 0.36 0.22 0.32 0.34 0.20 0.40 0.24 0.30 0.38 0.22 0.32 0.28 0.36 0.24 0.34 0.26 0.60",
|
| 10 |
+
"text": "<SP> 我 的 心 情 又 像 真 该 等 被 揭 开 嘴 巴 却 再 仰 千 台 人 潮 内 越 文 静 越 变 得 不 受 理 睬 睬 自 己 己 要 交 出 意 外",
|
| 11 |
+
"phoneme": "<SP> yue_ngo5 yue_dik1 yue_sam1 yue_cing4 yue_jau6 yue_zoeng6 yue_zan1 yue_goi1 yue_dang2 yue_bei6 yue_kit3 yue_hoi1 yue_zeoi2 yue_baa1 yue_koek3 yue_zoi3 yue_joeng5 yue_cin1 yue_toi4 yue_jan4 yue_ciu4 yue_noi6 yue_jyut6 yue_man4 yue_zing6 yue_jyut6 yue_bin3 yue_dak1 yue_bat1 yue_sau6 yue_lei5 yue_coi2 yue_coi2 yue_zi6 yue_gei2 yue_gei2 yue_jiu3 yue_gaau1 yue_ceot1 yue_ji3 yue_ngoi6",
|
| 12 |
+
"note_pitch": "0 52 57 59 55 57 59 62 60 58 54 57 59 59 57 55 54 53 57 51 50 54 57 58 54 57 59 61 64 59 54 54 57 59 51 56 58 57 56 56 55 52",
|
| 13 |
+
"note_type": "1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 2 2 3 2 2 2 2 2",
|
| 14 |
+
"f0": "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 74.8 74.8 82.4 85.7 81.0 76.2 78.8 111.9 129.0 146.8 160.6 175.0 182.1 172.9 163.9 190.3 214.7 218.4 221.5 223.5 220.2 209.1 173.8 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 254.5 248.7 245.1 243.3 239.8 238.8 241.2 244.0 245.7 245.6 239.8 216.8 191.2 0.0 0.0 170.9 179.4 189.0 192.6 193.3 192.8 193.5 194.1 194.1 194.4 195.4 197.4 199.7 202.2 205.8 210.6 212.3 214.4 216.5 219.4 221.9 222.4 222.4 222.8 222.9 217.1 189.6 175.4 0.0 0.0 255.5 251.2 246.5 247.3 248.9 249.1 249.4 251.9 253.6 250.2 247.1 246.6 239.1 193.8 191.0 0.0 295.9 301.1 302.8 301.5 296.6 287.7 286.4 290.4 294.2 297.1 297.3 294.7 287.9 273.4 221.0 262.3 265.3 259.7 255.9 254.7 255.1 256.2 257.4 259.3 260.8 261.3 249.3 209.1 194.0 236.7 224.9 210.2 202.6 197.9 201.4 210.6 220.6 230.1 237.9 242.4 242.7 241.0 234.7 220.3 190.9 179.0 185.6 182.3 178.3 177.1 179.0 181.3 182.4 184.5 187.6 185.9 172.3 161.7 167.7 0.0 0.0 206.9 210.4 213.0 214.2 215.6 217.3 217.1 194.7 181.9 184.5 182.3 171.8 155.9 161.4 167.7 195.8 235.3 245.0 245.8 241.9 237.0 232.3 231.3 234.9 241.8 250.9 253.2 248.9 238.0 226.0 220.5 224.9 236.6 250.9 259.6 262.0 259.1 252.5 246.5 241.8 236.0 228.8 216.3 210.5 211.1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 209.1 197.4 191.2 191.2 200.4 222.5 235.6 248.0 252.9 249.4 243.0 235.6 216.1 182.1 175.8 213.4 214.1 214.7 215.1 215.2 215.9 215.7 214.8 211.6 202.2 187.0 183.6 0.0 0.0 0.0 191.3 192.1 191.4 192.4 193.3 192.6 188.9 171.2 0.0 0.0 0.0 0.0 0.0 0.0 192.1 186.5 181.8 178.2 176.1 176.8 178.2 179.4 179.7 180.4 183.3 184.4 182.7 179.0 176.1 174.6 168.7 166.5 167.2 168.8 171.8 175.6 185.4 194.7 197.2 192.6 181.9 0.0 0.0 0.0 192.6 204.8 217.5 220.0 220.4 218.7 216.0 214.2 216.0 216.8 213.0 200.5 183.9 0.0 0.0 158.5 156.8 159.6 161.4 160.7 159.4 159.5 159.3 157.1 152.4 152.1 156.4 162.9 167.3 166.4 160.3 149.6 146.1 149.5 155.4 161.3 164.1 163.3 161.3 154.8 148.5 144.6 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 99.4 97.6 102.0 115.3 131.4 142.5 150.1 155.1 159.7 162.5 161.1 152.3 0.0 0.0 0.0 0.0 174.2 183.7 189.8 191.6 190.4 189.1 189.5 190.4 191.9 192.1 192.5 196.2 203.3 212.2 214.1 216.7 217.4 217.2 216.8 216.4 214.9 214.7 216.3 218.9 220.4 221.5 224.6 231.2 237.9 242.2 242.6 240.3 239.0 240.4 241.7 242.3 238.4 184.9 178.6 194.1 203.6 194.7 185.4 180.5 181.9 185.4 188.7 191.6 193.9 194.6 192.2 191.0 189.5 186.7 181.0 0.0 0.0 219.2 217.8 216.8 215.3 213.3 212.0 213.6 215.2 215.0 214.7 215.8 218.2 221.0 224.4 229.9 237.5 244.9 246.6 246.6 248.0 250.2 249.8 193.6 186.3 193.0 0.0 0.0 0.0 288.3 289.9 287.8 287.2 287.7 285.0 283.4 281.9 280.1 283.4 287.8 290.9 291.6 287.6 240.0 238.5 0.0 333.3 328.9 325.3 323.1 321.6 317.4 298.0 274.9 0.0 0.0 0.0 0.0 0.0 0.0 243.7 248.7 245.5 243.3 243.8 246.2 244.5 226.2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 197.1 189.0 184.3 182.5 184.8 188.1 188.8 190.3 190.4 189.5 188.7 188.0 189.8 188.2 185.0 182.2 179.7 178.1 178.7 188.8 208.5 217.2 220.0 218.6 212.8 191.9 167.3 0.0 0.0 0.0 183.9 201.1 210.5 210.3 209.5 206.7 206.5 212.0 221.1 235.9 250.8 253.0 247.7 238.4 229.7 229.1 235.1 244.7 251.7 253.5 251.4 246.7 242.4 234.6 209.5 180.1 173.1 0.0 0.0 0.0 147.9 147.5 155.4 159.1 160.2 161.4 162.0 161.7 159.0 152.4 139.6 121.5 126.4 0.0 0.0 197.2 200.2 196.8 195.8 200.3 203.2 203.8 202.5 203.7 212.8 220.0 225.8 231.6 234.1 231.7 228.3 225.4 226.0 229.7 234.3 237.3 238.2 236.9 232.9 227.1 220.3 215.1 211.3 205.9 210.1 216.0 217.6 218.4 218.9 219.3 218.5 217.1 216.9 217.7 216.4 212.1 196.3 171.7 171.3 210.9 203.0 194.9 194.8 199.1 205.4 210.5 214.9 219.6 224.8 225.1 221.4 212.4 198.5 0.0 0.0 204.9 204.1 208.9 212.7 212.8 213.2 214.9 214.4 208.4 189.0 159.7 160.4 193.1 198.9 196.0 192.8 193.0 195.3 195.2 194.5 193.6 193.9 192.5 192.3 192.2 183.0 164.8 150.0 147.3 150.7 155.9 160.8 163.6 164.8 161.8 156.6 155.0 160.5 166.2 167.3 165.1 162.0 154.3 142.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0"
|
| 15 |
+
}
|
| 16 |
+
]
|
example/audio/yue_target.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a699c2649eec48ed1e9a6caae2af918bf7d49e5e4ad39cf3cca0916942bc7db2
|
| 3 |
+
size 353361
|
example/audio/zh_prompt.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"index": "vocal_320_10687",
|
| 4 |
+
"language": "Mandarin",
|
| 5 |
+
"time": [
|
| 6 |
+
320,
|
| 7 |
+
10687
|
| 8 |
+
],
|
| 9 |
+
"duration": "0.23 0.34 0.26 0.70 0.52 0.46 0.36 0.44 0.14 0.24 0.64 0.47 0.51 1.10 0.28 0.38 0.32 0.32 0.38 0.32 0.31 0.19 1.45",
|
| 10 |
+
"text": "<SP> 除 了 想 你 你 <SP> 除 了 了 爱 你 你 <SP> 我 什 么 什 么 都 愿 愿 意",
|
| 11 |
+
"phoneme": "<SP> zh_chu2 zh_le5 zh_xiang3 zh_ni3 zh_ni3 <SP> zh_chu2 zh_le5 zh_le5 zh_ai4 zh_ni3 zh_ni3 <SP> zh_wo3 zh_shen2 zh_me5 zh_shen2 zh_me5 zh_dou1 zh_yuan4 zh_yuan4 zh_yi4",
|
| 12 |
+
"note_pitch": "0 62 65 67 67 69 0 67 69 67 65 67 69 67 67 66 64 64 60 60 65 67 0",
|
| 13 |
+
"note_type": "1 2 2 2 2 3 1 2 2 3 2 2 3 1 2 2 2 2 2 2 2 3 2",
|
| 14 |
+
"f0": "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 294.1 288.2 290.6 294.6 295.7 292.8 291.5 294.4 295.4 294.7 293.5 292.2 294.4 295.8 293.2 297.7 320.1 338.1 348.3 348.7 344.4 342.6 346.2 354.8 356.9 353.1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 401.1 403.8 405.0 402.1 398.4 395.6 393.3 392.0 391.4 390.3 390.2 390.3 390.3 391.3 392.8 393.6 391.7 390.5 391.4 391.6 391.5 393.5 393.8 390.8 387.4 387.8 389.3 390.8 392.2 391.6 390.2 389.8 389.1 388.4 390.0 395.5 397.2 396.7 395.5 395.1 394.6 394.9 395.6 395.2 394.6 395.4 395.9 394.0 391.7 390.7 391.7 392.6 391.6 395.7 405.8 441.7 462.5 463.8 450.2 430.9 414.5 415.2 426.7 439.8 454.4 462.9 447.8 422.5 400.6 403.2 423.5 451.3 482.3 492.8 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 435.9 414.3 406.2 402.8 398.2 396.7 396.7 395.2 398.3 416.9 441.6 446.9 441.5 437.3 435.3 432.9 431.6 432.3 434.8 436.7 433.8 422.8 406.7 390.9 382.0 381.7 384.7 382.7 368.1 357.4 355.6 355.1 352.4 348.1 346.4 348.7 351.6 355.0 354.4 351.7 349.9 349.2 348.1 346.0 345.4 344.2 344.4 345.5 346.6 349.0 349.7 349.1 349.5 349.6 349.7 349.4 349.5 352.2 354.5 355.7 355.6 356.9 359.4 361.5 363.8 360.4 354.3 357.2 363.9 372.4 382.6 399.1 402.9 400.6 395.5 390.5 388.9 390.2 391.1 391.9 391.5 390.4 390.2 391.3 391.6 391.0 388.6 386.2 389.1 403.8 430.2 441.8 449.5 448.1 443.2 438.3 432.9 430.7 434.6 442.0 447.6 446.0 440.3 434.7 431.1 435.7 442.3 445.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 423.8 402.1 398.5 397.6 393.4 390.5 391.4 402.5 427.6 442.5 442.5 435.1 430.1 430.8 439.9 447.4 442.2 426.1 412.3 399.8 391.8 389.5 388.5 387.8 386.4 384.7 384.5 387.9 391.2 391.8 392.9 393.8 392.0 392.0 395.4 398.1 398.2 396.3 393.1 391.1 388.9 386.6 383.1 381.5 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 357.8 367.6 370.9 365.4 359.3 355.7 358.1 372.7 396.8 404.0 398.4 392.6 389.2 388.8 383.6 362.3 341.1 325.1 326.3 327.9 331.3 333.3 326.0 319.4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 356.4 353.5 349.0 346.5 338.4 328.4 323.7 326.2 334.2 338.4 331.6 309.8 280.6 256.4 252.8 256.1 258.3 256.9 257.6 261.1 259.7 258.7 258.6 260.2 262.2 262.4 262.9 264.2 263.4 260.2 256.4 253.2 238.7 223.3 231.6 257.7 258.1 258.4 258.8 258.0 256.8 255.8 254.0 255.1 258.1 261.9 263.7 262.6 256.7 253.1 250.2 246.7 258.7 294.4 327.1 342.6 346.0 344.1 341.2 342.3 345.2 350.1 364.7 382.5 396.1 396.1 389.2 381.8 381.1 387.2 397.0 399.3 390.7 374.3 360.1 350.6 346.6 347.4 350.7 354.4 354.3 351.7 349.8 348.4 346.9 347.0 348.4 349.5 351.0 352.3 353.6 353.3 350.5 348.3 345.5 344.3 344.4 347.0 350.6 352.0 351.0 350.8 350.1 347.6 345.7 347.0 350.3 351.7 350.7 348.5 346.9 347.7 349.0 349.1 348.5 346.8 346.3 348.4 349.0 349.2 351.1 349.6 348.3 350.5 351.1 348.0 347.6 349.1 351.3 356.0 361.3 360.6 354.0 341.0 316.2 302.9 302.1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0"
|
| 15 |
+
}
|
| 16 |
+
]
|
example/audio/zh_prompt.mp3
ADDED
|
Binary file (86.1 kB). View file
|
|
|
example/audio/zh_target.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"index": "vocal_0_6710",
|
| 4 |
+
"language": "Mandarin",
|
| 5 |
+
"time": [
|
| 6 |
+
0,
|
| 7 |
+
6710
|
| 8 |
+
],
|
| 9 |
+
"duration": "0.13 0.26 0.24 0.22 0.24 0.33 0.13 0.24 0.22 0.46 0.69 0.84 0.26 0.30 0.16 0.26 0.26 0.20 0.32 0.94",
|
| 10 |
+
"text": "<SP> 像 我 这 样 懦 懦 弱 的 人 人 <SP> 凡 事 都 要 留 留 几 分",
|
| 11 |
+
"phoneme": "<SP> zh_xiang4 zh_wo3 zh_zhe4 zh_yang4 zh_nuo4 zh_nuo4 zh_ruo4 zh_de5 zh_ren2 zh_ren2 <SP> zh_fan2 zh_shi4 zh_dou1 zh_yao4 zh_liu2 zh_liu2 zh_ji3 zh_fen1",
|
| 12 |
+
"note_pitch": "0 50 53 55 53 56 54 53 50 51 53 0 51 53 55 53 54 56 51 53",
|
| 13 |
+
"note_type": "1 2 2 2 2 2 3 2 2 2 3 1 2 2 2 2 2 3 2 2",
|
| 14 |
+
"f0": "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 132.7 137.2 144.9 147.0 148.0 148.6 148.9 148.5 147.9 147.4 148.1 149.1 154.3 166.4 173.1 175.0 176.3 175.9 172.7 173.6 175.9 172.9 159.1 165.8 0.0 214.2 213.4 210.2 201.5 198.1 197.1 197.2 197.8 200.8 206.4 206.1 200.5 189.9 180.6 172.0 170.8 171.4 176.2 180.9 182.3 182.2 181.0 180.5 183.4 192.6 211.8 220.6 223.7 219.3 211.9 207.0 203.6 202.6 204.2 204.4 204.1 202.1 198.0 192.9 185.9 177.9 174.1 174.6 174.5 173.8 173.6 172.4 168.3 168.3 172.7 173.2 171.9 170.7 170.2 169.9 170.6 173.1 172.4 164.2 148.2 147.8 152.3 148.5 143.8 145.6 149.2 149.9 150.1 152.5 153.6 154.7 156.1 155.0 152.4 152.1 153.7 155.3 156.4 156.8 157.3 157.7 157.1 156.8 157.8 158.9 157.9 157.5 157.1 157.0 159.1 162.0 167.7 172.1 174.9 176.2 174.5 172.0 170.9 171.3 172.5 173.1 173.5 173.1 174.1 174.6 175.2 176.7 177.2 177.3 176.9 175.9 174.1 172.4 174.1 174.8 171.8 172.1 176.5 177.3 176.0 179.4 179.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 137.5 138.5 145.8 151.0 153.3 154.3 156.2 158.5 161.0 158.1 148.6 151.3 157.0 163.5 172.7 176.4 176.1 175.8 175.6 174.7 171.4 169.3 161.7 194.2 199.3 199.7 201.0 202.7 201.4 199.4 198.5 196.6 194.0 190.6 186.1 181.9 179.4 177.4 177.1 176.3 175.8 175.5 174.8 173.6 172.4 170.8 165.0 160.6 175.9 179.3 179.8 180.4 180.3 178.8 178.6 181.5 185.0 190.7 198.3 206.3 210.6 210.9 207.4 203.5 203.4 204.6 203.5 195.6 182.5 0.0 0.0 0.0 0.0 144.7 144.1 146.6 150.2 151.9 153.4 155.0 156.0 155.9 155.4 155.2 153.5 147.0 144.8 0.0 0.0 0.0 0.0 0.0 0.0 181.1 178.9 178.0 177.5 176.0 173.2 172.9 172.9 174.1 176.2 177.3 178.7 178.8 176.0 175.2 175.1 176.3 178.1 177.6 177.4 177.9 177.6 177.3 177.3 177.5 176.6 175.7 176.6 177.5 177.2 175.9 174.8 173.5 174.0 175.7 177.4 177.8 174.7"
|
| 15 |
+
}
|
| 16 |
+
]
|
example/audio/zh_target.mp3
ADDED
|
Binary file (54.2 kB). View file
|
|
|
example/infer.sh
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
script_dir=$(dirname "$(realpath "$0")")
|
| 4 |
+
root_dir=$(dirname "$script_dir")
|
| 5 |
+
|
| 6 |
+
cd $root_dir || exit
|
| 7 |
+
export PYTHONPATH=$root_dir:$PYTHONPATH
|
| 8 |
+
|
| 9 |
+
model_path=pretrained_models/SoulX-Singer/model.pt
|
| 10 |
+
config=soulxsinger/config/soulxsinger.yaml
|
| 11 |
+
prompt_wav_path=example/audio/zh_prompt.mp3
|
| 12 |
+
prompt_metadata_path=example/audio/zh_prompt.json
|
| 13 |
+
target_metadata_path=example/audio/music.json
|
| 14 |
+
phoneset_path=soulxsinger/utils/phoneme/phone_set.json
|
| 15 |
+
save_dir=example/generated/music
|
| 16 |
+
control=score # melody or score
|
| 17 |
+
|
| 18 |
+
python -m cli.inference \
|
| 19 |
+
--device cuda \
|
| 20 |
+
--model_path $model_path \
|
| 21 |
+
--config $config \
|
| 22 |
+
--prompt_wav_path $prompt_wav_path \
|
| 23 |
+
--prompt_metadata_path $prompt_metadata_path \
|
| 24 |
+
--target_metadata_path $target_metadata_path \
|
| 25 |
+
--phoneset_path $phoneset_path \
|
| 26 |
+
--save_dir $save_dir \
|
| 27 |
+
--auto_shift \
|
| 28 |
+
--pitch_shift 0
|
example/preprocess.sh
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
script_dir=$(dirname "$(realpath "$0")")
|
| 4 |
+
root_dir=$(dirname "$script_dir")
|
| 5 |
+
|
| 6 |
+
cd $root_dir || exit
|
| 7 |
+
export PYTHONPATH=$root_dir:$PYTHONPATH
|
| 8 |
+
|
| 9 |
+
device=cuda
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
####### Run Prompt Annotation #######
|
| 13 |
+
audio_path=example/audio/zh_prompt.mp3
|
| 14 |
+
save_dir=example/transcriptions/zh_prompt
|
| 15 |
+
language=Mandarin
|
| 16 |
+
vocal_sep=False
|
| 17 |
+
max_merge_duration=30000
|
| 18 |
+
|
| 19 |
+
python -m preprocess.pipeline \
|
| 20 |
+
--audio_path $audio_path \
|
| 21 |
+
--save_dir $save_dir \
|
| 22 |
+
--language $language \
|
| 23 |
+
--device $device \
|
| 24 |
+
--vocal_sep $vocal_sep \
|
| 25 |
+
--max_merge_duration $max_merge_duration
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
####### Run Target Annotation #######
|
| 29 |
+
audio_path=example/audio/music.mp3
|
| 30 |
+
save_dir=example/transcriptions/music
|
| 31 |
+
language=Mandarin
|
| 32 |
+
vocal_sep=True
|
| 33 |
+
max_merge_duration=60000
|
| 34 |
+
|
| 35 |
+
python -m preprocess.pipeline \
|
| 36 |
+
--audio_path $audio_path \
|
| 37 |
+
--save_dir $save_dir \
|
| 38 |
+
--language $language \
|
| 39 |
+
--device $device \
|
| 40 |
+
--vocal_sep $vocal_sep \
|
| 41 |
+
--max_merge_duration $max_merge_duration
|
preprocess/README.md
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🎵 SoulX-Singer-Preprocess
|
| 2 |
+
|
| 3 |
+
This part offers a comprehensive **singing transcription and editing toolkit** for real-world music audio. It provides the pipeline from vocal extraction to high-level annotation optimized for SVS dataset construction. By integrating state-of-the-art models, it transforms raw audio into structured singing data and supports the **customizable creation and editing of lyric-aligned MIDI scores**.
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
## ✨ Features
|
| 7 |
+
|
| 8 |
+
The toolkit includes the following core modules:
|
| 9 |
+
|
| 10 |
+
- 🎤 **Clean Dry Vocal Extraction**
|
| 11 |
+
Extracts the lead vocal track from polyphonic music audio and dereverberation.
|
| 12 |
+
|
| 13 |
+
- 📝 **Lyrics Transcription**
|
| 14 |
+
Automatically transcribes lyrics from clean vocal.
|
| 15 |
+
|
| 16 |
+
- 🎶 **Note Transcription**
|
| 17 |
+
Converts singing voice into note-level representations for SVS.
|
| 18 |
+
|
| 19 |
+
- 🎼 **MIDI Editor**
|
| 20 |
+
Supports customizable creation and editing of MIDI scores integrated with lyrics.
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
## 🔧 Python Environment
|
| 24 |
+
|
| 25 |
+
Before running the pipeline, set up the Python environment as follows:
|
| 26 |
+
|
| 27 |
+
1. **Install Conda** (if not already installed): https://docs.conda.io/en/latest/miniconda.html
|
| 28 |
+
|
| 29 |
+
2. **Activate or create a conda environment** (recommended Python 3.10):
|
| 30 |
+
|
| 31 |
+
- If you already have the `soulxsinger` environment:
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
conda activate soulxsinger
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
- Otherwise, create it first:
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
conda create -n soulxsinger -y python=3.10
|
| 41 |
+
conda activate soulxsinger
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
3. **Install dependencies** from the `preprocess` directory:
|
| 45 |
+
|
| 46 |
+
```bash
|
| 47 |
+
cd preprocess
|
| 48 |
+
pip install -r requirements.txt
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
## 📁 Data Preparation
|
| 52 |
+
|
| 53 |
+
Before running the pipeline, prepare the following inputs:
|
| 54 |
+
|
| 55 |
+
- **Prompt audio**
|
| 56 |
+
Reference audio that provides timbre and style
|
| 57 |
+
|
| 58 |
+
- **Target audio**
|
| 59 |
+
Original vocal or music audio to be processed and transcribed.
|
| 60 |
+
|
| 61 |
+
Configure the corresponding parameters in:
|
| 62 |
+
|
| 63 |
+
```
|
| 64 |
+
example/preprocess.sh
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
Typical configuration includes:
|
| 68 |
+
- Input / output paths
|
| 69 |
+
- Module enable switches
|
| 70 |
+
|
| 71 |
+
## 🚀 Usage
|
| 72 |
+
|
| 73 |
+
After configuring `preprocess.sh`, run the transcription pipeline with:
|
| 74 |
+
|
| 75 |
+
```bash
|
| 76 |
+
bash example/preprocess.sh
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
The script will automatically execute the following steps:
|
| 80 |
+
|
| 81 |
+
1. **Vocal separation and dereverberation**
|
| 82 |
+
2. **F0 extraction and voice activity detection (VAD)**
|
| 83 |
+
3. **Lyrics transcription**
|
| 84 |
+
4. **Note transcription**
|
| 85 |
+
|
| 86 |
+
---
|
| 87 |
+
|
| 88 |
+
After the pipeline completes, you will obtain **SoulX-Singer–style metadata** that can be directly used for Singing Voice Synthesis (SVS).
|
| 89 |
+
|
| 90 |
+
**Output paths:**
|
| 91 |
+
- The final metadata (**JSON file**) is written **in the same directory as your input audio**, with the **same filename** (e.g. `audio.mp3` → `audio.json`)
|
| 92 |
+
- All **intermediate results** (separated vocal and accompaniment, F0, VAD outputs, etc.) are also saved under the configured **`save_dir`**.
|
| 93 |
+
|
| 94 |
+
⚠️ **Important Note**
|
| 95 |
+
|
| 96 |
+
Transcription errors—especially in **lyrics** and **note annotations**—can significantly affect the final SVS quality. We **strongly recommend manually reviewing and correcting** the generated metadata before inference.
|
| 97 |
+
|
| 98 |
+
To support this, we provide a **MIDI Editor** for editing lyrics, phoneme alignment, note pitches, and durations. The workflow is:
|
| 99 |
+
|
| 100 |
+
**Export metadata to MIDI** → edit in the MIDI Editor → **Import edited MIDI back to metadata** for SVS.
|
| 101 |
+
|
| 102 |
+
---
|
| 103 |
+
|
| 104 |
+
#### Step 1: Metadata → MIDI (for editing)
|
| 105 |
+
|
| 106 |
+
Convert SoulX-Singer metadata to a MIDI file so you can open it in the MIDI Editor:
|
| 107 |
+
|
| 108 |
+
```bash
|
| 109 |
+
preprocess_root=example/transcriptions/music
|
| 110 |
+
|
| 111 |
+
python -m preprocess.tools.midi_parser \
|
| 112 |
+
--meta2midi \
|
| 113 |
+
--meta "${preprocess_root}/metadata.json" \
|
| 114 |
+
--midi "${preprocess_root}/vocal.mid"
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
#### Step 2: Edit in the MIDI Editor
|
| 118 |
+
|
| 119 |
+
Open the MIDI Editor (see [MIDI Editor Tutorial](tools/midi_editor/README.md)), load `vocal.mid`, and correct lyrics, pitches, or durations as needed. Save the result as e.g. `vocal_edited.mid`.
|
| 120 |
+
|
| 121 |
+
#### Step 3: MIDI → Metadata (for SoulX-Singer inference)
|
| 122 |
+
|
| 123 |
+
Convert the edited MIDI back into SoulX-Singer-style metadata (and cut wavs) for SVS:
|
| 124 |
+
|
| 125 |
+
```bash
|
| 126 |
+
python -m preprocess.tools.midi_parser \
|
| 127 |
+
--midi2meta \
|
| 128 |
+
--midi "${preprocess_root}/vocal_edited.mid" \
|
| 129 |
+
--meta "${preprocess_root}/edit_metadata.json" \
|
| 130 |
+
--vocal "${preprocess_root}/vocal.wav" \
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
Use `edit_metadata.json` (and the wavs under `edit_cut_wavs`) as the target metadata in your inference pipeline.
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
## 🔗 References & Dependencies
|
| 137 |
+
|
| 138 |
+
This project builds upon the following excellent open-source works:
|
| 139 |
+
|
| 140 |
+
### 🎧 Vocal Separation & Dereverberation
|
| 141 |
+
- [Music Source Separation Training](https://github.com/ZFTurbo/Music-Source-Separation-Training)
|
| 142 |
+
- [Lead Vocal Separation](https://huggingface.co/becruily/mel-band-roformer-karaoke)
|
| 143 |
+
- [Vocal Dereverberation](https://huggingface.co/anvuew/dereverb_mel_band_roformer)
|
| 144 |
+
|
| 145 |
+
### 🎼 F0 Extraction
|
| 146 |
+
- [RMVPE](https://github.com/Dream-High/RMVPE)
|
| 147 |
+
|
| 148 |
+
### 📝 Lyrics Transcription (ASR)
|
| 149 |
+
- [Paraformer](https://modelscope.cn/models/iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch)
|
| 150 |
+
- [Parakeet-tdt-0.6b-v2](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2)
|
| 151 |
+
|
| 152 |
+
### 🎶 Note Transcription
|
| 153 |
+
- [ROSVOT](https://github.com/RickyL-2000/ROSVOT)
|
| 154 |
+
|
| 155 |
+
We sincerely thank the authors of these repositories for their exceptional open-source contributions, which have been fundamental to the development of this toolkit.
|
preprocess/pipeline.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import shutil
|
| 3 |
+
import soundfile as sf
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import librosa
|
| 6 |
+
|
| 7 |
+
from preprocess.utils import convert_metadata, merge_short_segments
|
| 8 |
+
|
| 9 |
+
from preprocess.tools import (
|
| 10 |
+
F0Extractor,
|
| 11 |
+
VocalDetector,
|
| 12 |
+
VocalSeparator,
|
| 13 |
+
NoteTranscriber,
|
| 14 |
+
LyricTranscriber,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class PreprocessPipeline:
|
| 19 |
+
def __init__(self, device: str, language: str, save_dir: str, vocal_sep: bool = True, max_merge_duration: int = 60000):
|
| 20 |
+
self.device = device
|
| 21 |
+
self.language = language
|
| 22 |
+
self.save_dir = save_dir
|
| 23 |
+
self.vocal_sep = vocal_sep
|
| 24 |
+
self.max_merge_duration = max_merge_duration
|
| 25 |
+
|
| 26 |
+
if vocal_sep:
|
| 27 |
+
self.vocal_separator = VocalSeparator(
|
| 28 |
+
sep_model_path="pretrained_models/SoulX-Singer-Preprocess/mel-band-roformer-karaoke/mel_band_roformer_karaoke_becruily.ckpt",
|
| 29 |
+
sep_config_path="pretrained_models/SoulX-Singer-Preprocess/mel-band-roformer-karaoke/config_karaoke_becruily.yaml",
|
| 30 |
+
der_model_path="pretrained_models/SoulX-Singer-Preprocess/dereverb_mel_band_roformer/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt",
|
| 31 |
+
der_config_path="pretrained_models/SoulX-Singer-Preprocess/dereverb_mel_band_roformer/dereverb_mel_band_roformer_anvuew.yaml",
|
| 32 |
+
device=device
|
| 33 |
+
)
|
| 34 |
+
else:
|
| 35 |
+
self.vocal_separator = None
|
| 36 |
+
self.f0_extractor = F0Extractor(
|
| 37 |
+
model_path="pretrained_models/SoulX-Singer-Preprocess/rmvpe/rmvpe.pt",
|
| 38 |
+
device=device,
|
| 39 |
+
)
|
| 40 |
+
self.vocal_detector = VocalDetector(
|
| 41 |
+
cut_wavs_output_dir= f"{save_dir}/cut_wavs",
|
| 42 |
+
)
|
| 43 |
+
self.lyric_transcriber = LyricTranscriber(
|
| 44 |
+
zh_model_path="pretrained_models/SoulX-Singer-Preprocess/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
| 45 |
+
en_model_path="pretrained_models/SoulX-Singer-Preprocess/parakeet-tdt-0.6b-v2/parakeet-tdt-0.6b-v2.nemo",
|
| 46 |
+
device=device
|
| 47 |
+
)
|
| 48 |
+
self.note_transcriber = NoteTranscriber(
|
| 49 |
+
rosvot_model_path="pretrained_models/SoulX-Singer-Preprocess/rosvot/rosvot/model.pt",
|
| 50 |
+
rwbd_model_path="pretrained_models/SoulX-Singer-Preprocess/rosvot/rwbd/model.pt",
|
| 51 |
+
device=device
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
def run(
|
| 55 |
+
self,
|
| 56 |
+
audio_path: str,
|
| 57 |
+
vocal_sep: bool = True,
|
| 58 |
+
max_merge_duration: int = 60000,
|
| 59 |
+
language: str = "Mandarin"
|
| 60 |
+
) -> None:
|
| 61 |
+
vocal_sep = self.vocal_sep if vocal_sep is None else vocal_sep
|
| 62 |
+
max_merge_duration = self.max_merge_duration if max_merge_duration is None else max_merge_duration
|
| 63 |
+
language = self.language if language is None else language
|
| 64 |
+
output_dir = Path(self.save_dir)
|
| 65 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 66 |
+
|
| 67 |
+
if vocal_sep:
|
| 68 |
+
# Perform vocal/accompaniment separation
|
| 69 |
+
sep = self.vocal_separator.process(audio_path)
|
| 70 |
+
vocal = sep.vocals_dereverbed.T
|
| 71 |
+
acc = sep.accompaniment.T
|
| 72 |
+
sample_rate = sep.sample_rate
|
| 73 |
+
|
| 74 |
+
vocal_path = output_dir / "vocal.wav"
|
| 75 |
+
acc_path = output_dir / "acc.wav"
|
| 76 |
+
sf.write(vocal_path, vocal, sample_rate)
|
| 77 |
+
sf.write(acc_path, acc, sample_rate)
|
| 78 |
+
else:
|
| 79 |
+
# Use the original audio as vocal source (no separation)
|
| 80 |
+
vocal, sample_rate = librosa.load(audio_path, sr=None, mono=True)
|
| 81 |
+
vocal_path = output_dir / "vocal.wav"
|
| 82 |
+
sf.write(vocal_path, vocal, sample_rate)
|
| 83 |
+
|
| 84 |
+
vocal_f0 = self.f0_extractor.process(str(vocal_path))
|
| 85 |
+
segments = self.vocal_detector.process(str(vocal_path), f0=vocal_f0)
|
| 86 |
+
|
| 87 |
+
metadata = []
|
| 88 |
+
for seg in segments:
|
| 89 |
+
self.f0_extractor.process(seg["wav_fn"], f0_path=seg["wav_fn"].replace(".wav", "_f0.npy"))
|
| 90 |
+
words, durs = self.lyric_transcriber.process(
|
| 91 |
+
seg["wav_fn"], language
|
| 92 |
+
)
|
| 93 |
+
seg["words"] = words
|
| 94 |
+
seg["word_durs"] = durs
|
| 95 |
+
seg["language"] = language
|
| 96 |
+
metadata.append(
|
| 97 |
+
self.note_transcriber.process(seg, segment_info=seg)
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
merged = merge_short_segments(
|
| 101 |
+
vocal,
|
| 102 |
+
sample_rate,
|
| 103 |
+
metadata,
|
| 104 |
+
output_dir / "long_cut_wavs",
|
| 105 |
+
max_duration_ms=max_merge_duration,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
final_metadata = []
|
| 109 |
+
|
| 110 |
+
for item in merged:
|
| 111 |
+
self.f0_extractor.process(item.wav_fn, f0_path=item.wav_fn.replace(".wav", "_f0.npy"))
|
| 112 |
+
final_metadata.append(convert_metadata(item))
|
| 113 |
+
|
| 114 |
+
with open(output_dir / "metadata.json", "w", encoding="utf-8") as f:
|
| 115 |
+
json.dump(final_metadata, f, ensure_ascii=False, indent=2)
|
| 116 |
+
|
| 117 |
+
shutil.copy(output_dir / "metadata.json", audio_path.replace(".wav", ".json").replace(".mp3", ".json").replace(".flac", ".json"))
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def main(args):
|
| 121 |
+
pipeline = PreprocessPipeline(
|
| 122 |
+
device=args.device,
|
| 123 |
+
language=args.language,
|
| 124 |
+
save_dir=args.save_dir,
|
| 125 |
+
vocal_sep=args.vocal_sep,
|
| 126 |
+
max_merge_duration=args.max_merge_duration,
|
| 127 |
+
)
|
| 128 |
+
pipeline.run(
|
| 129 |
+
audio_path=args.audio_path,
|
| 130 |
+
language=args.language
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
if __name__ == "__main__":
|
| 135 |
+
import argparse
|
| 136 |
+
|
| 137 |
+
parser = argparse.ArgumentParser()
|
| 138 |
+
parser.add_argument("--audio_path", type=str, required=True, help="Path to the input audio file")
|
| 139 |
+
parser.add_argument("--save_dir", type=str, required=True, help="Directory to save the output files")
|
| 140 |
+
parser.add_argument("--language", type=str, default="Mandarin", help="Language of the audio")
|
| 141 |
+
parser.add_argument("--device", type=str, default="cuda:0", help="Device to run the models on")
|
| 142 |
+
parser.add_argument("--vocal_sep", type=bool, default=True, help="Whether to perform vocal separation")
|
| 143 |
+
parser.add_argument("--max_merge_duration", type=int, default=60000, help="Maximum merged segment duration in milliseconds")
|
| 144 |
+
args = parser.parse_args()
|
| 145 |
+
|
| 146 |
+
main(args)
|
preprocess/requirements.txt
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
beartype==0.22.9
|
| 2 |
+
einops==0.8.2
|
| 3 |
+
funasr==1.3.0
|
| 4 |
+
g2p_en==2.1.0
|
| 5 |
+
g2pM==0.1.2.5
|
| 6 |
+
librosa==0.11.0
|
| 7 |
+
loralib==0.1.2
|
| 8 |
+
matplotlib==3.10.8
|
| 9 |
+
mido==1.3.3
|
| 10 |
+
ml_collections==1.1.0
|
| 11 |
+
nemo_toolkit==2.6.1
|
| 12 |
+
nltk==3.9.2
|
| 13 |
+
numba==0.63.1
|
| 14 |
+
numpy==2.2.6
|
| 15 |
+
omegaconf==2.3.0
|
| 16 |
+
packaging==24.2
|
| 17 |
+
praat-parselmouth==0.4.7
|
| 18 |
+
pretty_midi==0.2.11
|
| 19 |
+
pyloudnorm==0.2.0
|
| 20 |
+
pyworld==0.3.5
|
| 21 |
+
rotary_embedding_torch==0.8.9
|
| 22 |
+
sageattention==1.0.6
|
| 23 |
+
scikit_learn==1.7.2
|
| 24 |
+
scipy==1.15.3
|
| 25 |
+
six==1.17.0
|
| 26 |
+
scikit_image==0.25.2
|
| 27 |
+
soundfile==0.13.1
|
| 28 |
+
ToJyutping==3.2.0
|
| 29 |
+
torch==2.10.0
|
| 30 |
+
torchaudio==2.10.0
|
| 31 |
+
tqdm==4.67.1
|
| 32 |
+
wandb==0.24.2
|
| 33 |
+
webrtcvad==2.0.10
|
preprocess/tools/__init__.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Preprocess tools.
|
| 2 |
+
|
| 3 |
+
This package provides a thin, stable import surface for common preprocess components.
|
| 4 |
+
|
| 5 |
+
Examples:
|
| 6 |
+
from preprocess.tools import (
|
| 7 |
+
F0Extractor,
|
| 8 |
+
PitchExtractor,
|
| 9 |
+
VocalDetectionModel,
|
| 10 |
+
VocalSeparationModel,
|
| 11 |
+
VocalExtractionModel,
|
| 12 |
+
NoteTranscriptionModel,
|
| 13 |
+
LyricTranscriptionModel,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
Note:
|
| 17 |
+
Keep these imports lightweight. If a tool pulls heavy dependencies at import time,
|
| 18 |
+
consider switching to lazy imports.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
# Core tools
|
| 24 |
+
from .f0_extraction import F0Extractor
|
| 25 |
+
from .vocal_detection import VocalDetector
|
| 26 |
+
|
| 27 |
+
# Some tools may live outside this package in different layouts across branches.
|
| 28 |
+
# Keep the public surface stable while avoiding hard import failures.
|
| 29 |
+
try:
|
| 30 |
+
from .vocal_separation.model import VocalSeparator # type: ignore
|
| 31 |
+
except Exception: # pragma: no cover
|
| 32 |
+
VocalSeparator = None # type: ignore
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
from .note_transcription.model import NoteTranscriber # type: ignore
|
| 36 |
+
except Exception: # pragma: no cover
|
| 37 |
+
NoteTranscriber = None # type: ignore
|
| 38 |
+
try:
|
| 39 |
+
from .lyric_transcription import LyricTranscriber
|
| 40 |
+
except Exception: # pragma: no cover
|
| 41 |
+
LyricTranscriber = None # type: ignore
|
| 42 |
+
|
| 43 |
+
__all__ = [
|
| 44 |
+
"F0Extractor",
|
| 45 |
+
"VocalDetector",
|
| 46 |
+
]
|
| 47 |
+
|
| 48 |
+
if VocalSeparator is not None:
|
| 49 |
+
__all__.append("VocalSeparator")
|
| 50 |
+
if LyricTranscriber is not None:
|
| 51 |
+
__all__.append("LyricTranscriber")
|
| 52 |
+
if NoteTranscriber is not None:
|
| 53 |
+
__all__.append("NoteTranscriber")
|
preprocess/tools/f0_extraction.py
ADDED
|
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/Dream-High/RMVPE
|
| 2 |
+
import math
|
| 3 |
+
import time
|
| 4 |
+
import librosa
|
| 5 |
+
import numpy as np
|
| 6 |
+
from librosa.filters import mel
|
| 7 |
+
from scipy.interpolate import interp1d
|
| 8 |
+
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BiGRU(nn.Module):
|
| 17 |
+
def __init__(self, input_features, hidden_features, num_layers):
|
| 18 |
+
super(BiGRU, self).__init__()
|
| 19 |
+
self.gru = nn.GRU(
|
| 20 |
+
input_features,
|
| 21 |
+
hidden_features,
|
| 22 |
+
num_layers=num_layers,
|
| 23 |
+
batch_first=True,
|
| 24 |
+
bidirectional=True,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
def forward(self, x):
|
| 28 |
+
return self.gru(x)[0]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class ConvBlockRes(nn.Module):
|
| 32 |
+
def __init__(self, in_channels, out_channels, momentum=0.01):
|
| 33 |
+
super(ConvBlockRes, self).__init__()
|
| 34 |
+
self.conv = nn.Sequential(
|
| 35 |
+
nn.Conv2d(
|
| 36 |
+
in_channels=in_channels,
|
| 37 |
+
out_channels=out_channels,
|
| 38 |
+
kernel_size=(3, 3),
|
| 39 |
+
stride=(1, 1),
|
| 40 |
+
padding=(1, 1),
|
| 41 |
+
bias=False,
|
| 42 |
+
),
|
| 43 |
+
nn.BatchNorm2d(out_channels, momentum=momentum),
|
| 44 |
+
nn.ReLU(),
|
| 45 |
+
nn.Conv2d(
|
| 46 |
+
in_channels=out_channels,
|
| 47 |
+
out_channels=out_channels,
|
| 48 |
+
kernel_size=(3, 3),
|
| 49 |
+
stride=(1, 1),
|
| 50 |
+
padding=(1, 1),
|
| 51 |
+
bias=False,
|
| 52 |
+
),
|
| 53 |
+
nn.BatchNorm2d(out_channels, momentum=momentum),
|
| 54 |
+
nn.ReLU(),
|
| 55 |
+
)
|
| 56 |
+
if in_channels != out_channels:
|
| 57 |
+
self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
|
| 58 |
+
|
| 59 |
+
def forward(self, x):
|
| 60 |
+
if not hasattr(self, "shortcut"):
|
| 61 |
+
return self.conv(x) + x
|
| 62 |
+
else:
|
| 63 |
+
return self.conv(x) + self.shortcut(x)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class ResEncoderBlock(nn.Module):
|
| 67 |
+
def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
|
| 68 |
+
super(ResEncoderBlock, self).__init__()
|
| 69 |
+
self.n_blocks = n_blocks
|
| 70 |
+
self.conv = nn.ModuleList()
|
| 71 |
+
self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
|
| 72 |
+
for i in range(n_blocks - 1):
|
| 73 |
+
self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
|
| 74 |
+
self.kernel_size = kernel_size
|
| 75 |
+
if self.kernel_size is not None:
|
| 76 |
+
self.pool = nn.AvgPool2d(kernel_size=kernel_size)
|
| 77 |
+
|
| 78 |
+
def forward(self, x):
|
| 79 |
+
for conv in self.conv:
|
| 80 |
+
x = conv(x)
|
| 81 |
+
if self.kernel_size is not None:
|
| 82 |
+
return x, self.pool(x)
|
| 83 |
+
else:
|
| 84 |
+
return x
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class Encoder(nn.Module):
|
| 88 |
+
def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
|
| 89 |
+
super(Encoder, self).__init__()
|
| 90 |
+
self.n_encoders = n_encoders
|
| 91 |
+
self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
|
| 92 |
+
self.layers = nn.ModuleList()
|
| 93 |
+
self.latent_channels = []
|
| 94 |
+
for i in range(self.n_encoders):
|
| 95 |
+
self.layers.append(
|
| 96 |
+
ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum)
|
| 97 |
+
)
|
| 98 |
+
self.latent_channels.append([out_channels, in_size])
|
| 99 |
+
in_channels = out_channels
|
| 100 |
+
out_channels *= 2
|
| 101 |
+
in_size //= 2
|
| 102 |
+
self.out_size = in_size
|
| 103 |
+
self.out_channel = out_channels
|
| 104 |
+
|
| 105 |
+
def forward(self, x):
|
| 106 |
+
concat_tensors = []
|
| 107 |
+
x = self.bn(x)
|
| 108 |
+
for layer in self.layers:
|
| 109 |
+
t, x = layer(x)
|
| 110 |
+
concat_tensors.append(t)
|
| 111 |
+
return x, concat_tensors
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class Intermediate(nn.Module):
|
| 115 |
+
def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
|
| 116 |
+
super(Intermediate, self).__init__()
|
| 117 |
+
self.n_inters = n_inters
|
| 118 |
+
self.layers = nn.ModuleList()
|
| 119 |
+
self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
|
| 120 |
+
for i in range(self.n_inters - 1):
|
| 121 |
+
self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))
|
| 122 |
+
|
| 123 |
+
def forward(self, x):
|
| 124 |
+
for layer in self.layers:
|
| 125 |
+
x = layer(x)
|
| 126 |
+
return x
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class ResDecoderBlock(nn.Module):
|
| 130 |
+
def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
|
| 131 |
+
super(ResDecoderBlock, self).__init__()
|
| 132 |
+
out_padding = (0, 1) if stride == (1, 2) else (1, 1)
|
| 133 |
+
self.n_blocks = n_blocks
|
| 134 |
+
self.conv1 = nn.Sequential(
|
| 135 |
+
nn.ConvTranspose2d(
|
| 136 |
+
in_channels=in_channels,
|
| 137 |
+
out_channels=out_channels,
|
| 138 |
+
kernel_size=(3, 3),
|
| 139 |
+
stride=stride,
|
| 140 |
+
padding=(1, 1),
|
| 141 |
+
output_padding=out_padding,
|
| 142 |
+
bias=False,
|
| 143 |
+
),
|
| 144 |
+
nn.BatchNorm2d(out_channels, momentum=momentum),
|
| 145 |
+
nn.ReLU(),
|
| 146 |
+
)
|
| 147 |
+
self.conv2 = nn.ModuleList()
|
| 148 |
+
self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
|
| 149 |
+
for i in range(n_blocks - 1):
|
| 150 |
+
self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
|
| 151 |
+
|
| 152 |
+
def forward(self, x, concat_tensor):
|
| 153 |
+
x = self.conv1(x)
|
| 154 |
+
x = torch.cat((x, concat_tensor), dim=1)
|
| 155 |
+
for conv2 in self.conv2:
|
| 156 |
+
x = conv2(x)
|
| 157 |
+
return x
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class Decoder(nn.Module):
|
| 161 |
+
def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
|
| 162 |
+
super(Decoder, self).__init__()
|
| 163 |
+
self.layers = nn.ModuleList()
|
| 164 |
+
self.n_decoders = n_decoders
|
| 165 |
+
for i in range(self.n_decoders):
|
| 166 |
+
out_channels = in_channels // 2
|
| 167 |
+
self.layers.append(
|
| 168 |
+
ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
|
| 169 |
+
)
|
| 170 |
+
in_channels = out_channels
|
| 171 |
+
|
| 172 |
+
def forward(self, x, concat_tensors):
|
| 173 |
+
for i, layer in enumerate(self.layers):
|
| 174 |
+
x = layer(x, concat_tensors[-1 - i])
|
| 175 |
+
return x
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class DeepUnet(nn.Module):
|
| 179 |
+
def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
|
| 180 |
+
super(DeepUnet, self).__init__()
|
| 181 |
+
self.encoder = Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels)
|
| 182 |
+
self.intermediate = Intermediate(
|
| 183 |
+
self.encoder.out_channel // 2,
|
| 184 |
+
self.encoder.out_channel,
|
| 185 |
+
inter_layers,
|
| 186 |
+
n_blocks,
|
| 187 |
+
)
|
| 188 |
+
self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
|
| 189 |
+
|
| 190 |
+
def forward(self, x):
|
| 191 |
+
x, concat_tensors = self.encoder(x)
|
| 192 |
+
x = self.intermediate(x)
|
| 193 |
+
x = self.decoder(x, concat_tensors)
|
| 194 |
+
return x
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class E2E(nn.Module):
|
| 198 |
+
def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
|
| 199 |
+
super(E2E, self).__init__()
|
| 200 |
+
self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
|
| 201 |
+
self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
|
| 202 |
+
if n_gru:
|
| 203 |
+
self.fc = nn.Sequential(
|
| 204 |
+
BiGRU(3 * 128, 256, n_gru),
|
| 205 |
+
nn.Linear(512, 360),
|
| 206 |
+
nn.Dropout(0.25),
|
| 207 |
+
nn.Sigmoid(),
|
| 208 |
+
)
|
| 209 |
+
else:
|
| 210 |
+
self.fc = nn.Sequential(
|
| 211 |
+
nn.Linear(3 * 128, 360),
|
| 212 |
+
nn.Dropout(0.25),
|
| 213 |
+
nn.Sigmoid()
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
def forward(self, mel):
|
| 217 |
+
mel = mel.transpose(-1, -2).unsqueeze(1)
|
| 218 |
+
x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
|
| 219 |
+
x = self.fc(x)
|
| 220 |
+
return x
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class MelSpectrogram(torch.nn.Module):
|
| 225 |
+
def __init__(self, is_half, n_mel_channels, sampling_rate, win_length, hop_length,
|
| 226 |
+
n_fft=None, mel_fmin=0, mel_fmax=None, clamp=1e-5):
|
| 227 |
+
super().__init__()
|
| 228 |
+
n_fft = win_length if n_fft is None else n_fft
|
| 229 |
+
self.hann_window = {}
|
| 230 |
+
mel_basis = mel(
|
| 231 |
+
sr=sampling_rate,
|
| 232 |
+
n_fft=n_fft,
|
| 233 |
+
n_mels=n_mel_channels,
|
| 234 |
+
fmin=mel_fmin,
|
| 235 |
+
fmax=mel_fmax,
|
| 236 |
+
htk=True,
|
| 237 |
+
)
|
| 238 |
+
mel_basis = torch.from_numpy(mel_basis).float()
|
| 239 |
+
self.register_buffer("mel_basis", mel_basis)
|
| 240 |
+
self.n_fft = win_length if n_fft is None else n_fft
|
| 241 |
+
self.hop_length = hop_length
|
| 242 |
+
self.win_length = win_length
|
| 243 |
+
self.sampling_rate = sampling_rate
|
| 244 |
+
self.n_mel_channels = n_mel_channels
|
| 245 |
+
self.clamp = clamp
|
| 246 |
+
self.is_half = is_half
|
| 247 |
+
|
| 248 |
+
def forward(self, audio, keyshift=0, speed=1, center=True):
|
| 249 |
+
factor = 2 ** (keyshift / 12)
|
| 250 |
+
n_fft_new = int(np.round(self.n_fft * factor))
|
| 251 |
+
win_length_new = int(np.round(self.win_length * factor))
|
| 252 |
+
hop_length_new = int(np.round(self.hop_length * speed))
|
| 253 |
+
|
| 254 |
+
keyshift_key = str(keyshift) + "_" + str(audio.device)
|
| 255 |
+
if keyshift_key not in self.hann_window:
|
| 256 |
+
self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
|
| 257 |
+
|
| 258 |
+
fft = torch.stft(
|
| 259 |
+
audio,
|
| 260 |
+
n_fft=n_fft_new,
|
| 261 |
+
hop_length=hop_length_new,
|
| 262 |
+
win_length=win_length_new,
|
| 263 |
+
window=self.hann_window[keyshift_key],
|
| 264 |
+
center=center,
|
| 265 |
+
return_complex=True,
|
| 266 |
+
)
|
| 267 |
+
magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
|
| 268 |
+
|
| 269 |
+
if keyshift != 0:
|
| 270 |
+
size = self.n_fft // 2 + 1
|
| 271 |
+
resize = magnitude.size(1)
|
| 272 |
+
if resize < size:
|
| 273 |
+
magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
|
| 274 |
+
magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
|
| 275 |
+
|
| 276 |
+
mel_output = torch.matmul(self.mel_basis, magnitude)
|
| 277 |
+
if self.is_half:
|
| 278 |
+
mel_output = mel_output.half()
|
| 279 |
+
log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
|
| 280 |
+
return log_mel_spec
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class RMVPE:
|
| 285 |
+
def __init__(self, model_path: str, is_half, device=None):
|
| 286 |
+
self.is_half = is_half
|
| 287 |
+
if device is None:
|
| 288 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 289 |
+
self.device = torch.device(device) if isinstance(device, str) else device
|
| 290 |
+
|
| 291 |
+
self.mel_extractor = MelSpectrogram(
|
| 292 |
+
is_half=is_half,
|
| 293 |
+
n_mel_channels=128,
|
| 294 |
+
sampling_rate=16000,
|
| 295 |
+
win_length=1024,
|
| 296 |
+
hop_length=160,
|
| 297 |
+
n_fft=None,
|
| 298 |
+
mel_fmin=30,
|
| 299 |
+
mel_fmax=8000
|
| 300 |
+
).to(self.device)
|
| 301 |
+
|
| 302 |
+
model = E2E(n_blocks=4, n_gru=1, kernel_size=(2, 2))
|
| 303 |
+
ckpt = torch.load(model_path, map_location=self.device)
|
| 304 |
+
model.load_state_dict(ckpt)
|
| 305 |
+
model.eval()
|
| 306 |
+
|
| 307 |
+
if is_half:
|
| 308 |
+
model = model.half()
|
| 309 |
+
else:
|
| 310 |
+
model = model.float()
|
| 311 |
+
|
| 312 |
+
self.model = model.to(self.device)
|
| 313 |
+
|
| 314 |
+
cents_mapping = 20 * np.arange(360) + 1997.3794084376191
|
| 315 |
+
self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368
|
| 316 |
+
|
| 317 |
+
def mel2hidden(self, mel):
|
| 318 |
+
with torch.no_grad():
|
| 319 |
+
n_frames = mel.shape[-1]
|
| 320 |
+
n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
|
| 321 |
+
if n_pad > 0:
|
| 322 |
+
mel = F.pad(mel, (0, n_pad), mode="constant")
|
| 323 |
+
mel = mel.half() if self.is_half else mel.float()
|
| 324 |
+
hidden = self.model(mel)
|
| 325 |
+
return hidden[:, :n_frames]
|
| 326 |
+
|
| 327 |
+
def decode(self, hidden, thred=0.03):
|
| 328 |
+
cents_pred = self.to_local_average_cents(hidden, thred=thred)
|
| 329 |
+
f0 = 10 * (2 ** (cents_pred / 1200))
|
| 330 |
+
f0[f0 == 10] = 0
|
| 331 |
+
return f0
|
| 332 |
+
|
| 333 |
+
def infer_from_audio(self, audio, thred=0.03):
|
| 334 |
+
if not torch.is_tensor(audio):
|
| 335 |
+
audio = torch.from_numpy(audio)
|
| 336 |
+
|
| 337 |
+
mel = self.mel_extractor(audio.float().to(self.device).unsqueeze(0), center=True)
|
| 338 |
+
hidden = self.mel2hidden(mel)
|
| 339 |
+
hidden = hidden.squeeze(0).cpu().numpy()
|
| 340 |
+
|
| 341 |
+
if self.is_half:
|
| 342 |
+
hidden = hidden.astype("float32")
|
| 343 |
+
|
| 344 |
+
f0 = self.decode(hidden, thred=thred)
|
| 345 |
+
return f0
|
| 346 |
+
|
| 347 |
+
def to_local_average_cents(self, salience, thred=0.05):
|
| 348 |
+
center = np.argmax(salience, axis=1)
|
| 349 |
+
salience = np.pad(salience, ((0, 0), (4, 4)))
|
| 350 |
+
center += 4
|
| 351 |
+
|
| 352 |
+
todo_salience = []
|
| 353 |
+
todo_cents_mapping = []
|
| 354 |
+
starts = center - 4
|
| 355 |
+
ends = center + 5
|
| 356 |
+
|
| 357 |
+
for idx in range(salience.shape[0]):
|
| 358 |
+
todo_salience.append(salience[:, starts[idx]:ends[idx]][idx])
|
| 359 |
+
todo_cents_mapping.append(self.cents_mapping[starts[idx]:ends[idx]])
|
| 360 |
+
|
| 361 |
+
todo_salience = np.array(todo_salience)
|
| 362 |
+
todo_cents_mapping = np.array(todo_cents_mapping)
|
| 363 |
+
product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
|
| 364 |
+
weight_sum = np.sum(todo_salience, 1)
|
| 365 |
+
devided = product_sum / weight_sum
|
| 366 |
+
|
| 367 |
+
maxx = np.max(salience, axis=1)
|
| 368 |
+
devided[maxx <= thred] = 0
|
| 369 |
+
|
| 370 |
+
return devided
|
| 371 |
+
|
| 372 |
+
class F0Extractor:
|
| 373 |
+
"""Extract frame-level f0 from singing voice.
|
| 374 |
+
|
| 375 |
+
Wrapper around an RMVPE network that:
|
| 376 |
+
1) loads the checkpoint once in ``__init__``
|
| 377 |
+
2) exposes a simple :py:meth:`process` API and optionally saves ``*_f0.npy``.
|
| 378 |
+
"""
|
| 379 |
+
def __init__(
|
| 380 |
+
self,
|
| 381 |
+
model_path: str,
|
| 382 |
+
device: str = "cpu",
|
| 383 |
+
*,
|
| 384 |
+
is_half: bool = False,
|
| 385 |
+
input_sr: int = 16000,
|
| 386 |
+
target_sr: int = 24000,
|
| 387 |
+
hop_size: int = 480,
|
| 388 |
+
max_duration: float = 300,
|
| 389 |
+
thred: float = 0.03,
|
| 390 |
+
verbose: bool = True,
|
| 391 |
+
):
|
| 392 |
+
"""Initialize the f0 extractor.
|
| 393 |
+
|
| 394 |
+
Args:
|
| 395 |
+
model_path: Path to RMVPE checkpoint.
|
| 396 |
+
device: Torch device string, e.g. ``"cuda:0"`` / ``"cpu"``.
|
| 397 |
+
is_half: Whether to run the model in fp16.
|
| 398 |
+
input_sr: Input resample rate used by RMVPE frontend.
|
| 399 |
+
target_sr: Target sample rate for the output f0 grid.
|
| 400 |
+
hop_size: Target hop size for the output f0 grid.
|
| 401 |
+
max_duration: Max duration (seconds) for interpolation grid.
|
| 402 |
+
thred: Voicing threshold used when decoding salience.
|
| 403 |
+
verbose: Whether to print verbose logs.
|
| 404 |
+
"""
|
| 405 |
+
self.model_path = model_path
|
| 406 |
+
self.input_sr = input_sr
|
| 407 |
+
self.target_sr = target_sr
|
| 408 |
+
self.hop_size = hop_size
|
| 409 |
+
self.max_duration = max_duration
|
| 410 |
+
self.thred = thred
|
| 411 |
+
|
| 412 |
+
self.verbose = verbose
|
| 413 |
+
|
| 414 |
+
self.model = RMVPE(model_path, is_half=is_half, device=device)
|
| 415 |
+
|
| 416 |
+
if self.verbose:
|
| 417 |
+
print(
|
| 418 |
+
"[f0 extraction] init success:",
|
| 419 |
+
f"device={device}",
|
| 420 |
+
f"model_path={model_path}",
|
| 421 |
+
f"is_half={is_half}",
|
| 422 |
+
f"input_sr={input_sr}",
|
| 423 |
+
f"target_sr={target_sr}",
|
| 424 |
+
f"hop_size={hop_size}",
|
| 425 |
+
f"thred={thred}",
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
@staticmethod
|
| 429 |
+
def interpolate_f0(
|
| 430 |
+
f0_16k: np.ndarray,
|
| 431 |
+
original_length: int,
|
| 432 |
+
original_sr: int,
|
| 433 |
+
*,
|
| 434 |
+
target_sr: int = 48000,
|
| 435 |
+
hop_size: int = 256,
|
| 436 |
+
max_duration: float = 20.0,
|
| 437 |
+
) -> np.ndarray:
|
| 438 |
+
"""Interpolate f0 from RMVPE's 16k hop grid to target mel hop grid."""
|
| 439 |
+
mel_target_sr = target_sr
|
| 440 |
+
mel_hop_size = hop_size
|
| 441 |
+
mel_max_duration = max_duration
|
| 442 |
+
|
| 443 |
+
batch_max_length = int(mel_max_duration * mel_target_sr / mel_hop_size)
|
| 444 |
+
duration_in_seconds = original_length / original_sr
|
| 445 |
+
effective_target_length = int(duration_in_seconds * mel_target_sr)
|
| 446 |
+
original_frames = math.ceil(effective_target_length / mel_hop_size)
|
| 447 |
+
target_frames = min(original_frames, batch_max_length)
|
| 448 |
+
|
| 449 |
+
rmvpe_hop = 160
|
| 450 |
+
t_16k = np.arange(len(f0_16k)) * (rmvpe_hop / 16000.0)
|
| 451 |
+
t_target = np.arange(target_frames) * (mel_hop_size / float(mel_target_sr))
|
| 452 |
+
|
| 453 |
+
if len(f0_16k) > 0:
|
| 454 |
+
f_interp = interp1d(
|
| 455 |
+
t_16k,
|
| 456 |
+
f0_16k,
|
| 457 |
+
kind="linear",
|
| 458 |
+
bounds_error=False,
|
| 459 |
+
fill_value=0.0,
|
| 460 |
+
assume_sorted=True,
|
| 461 |
+
)
|
| 462 |
+
f0 = f_interp(t_target)
|
| 463 |
+
else:
|
| 464 |
+
f0 = np.zeros(target_frames)
|
| 465 |
+
|
| 466 |
+
if len(f0) != target_frames:
|
| 467 |
+
f0 = (
|
| 468 |
+
f0[:target_frames]
|
| 469 |
+
if len(f0) > target_frames
|
| 470 |
+
else np.pad(f0, (0, target_frames - len(f0)), "constant")
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
return f0
|
| 474 |
+
|
| 475 |
+
def process(self, audio_path: str, *, f0_path: str | None = None, verbose: Optional[bool] = None) -> np.ndarray:
|
| 476 |
+
"""Run f0 extraction for a single wav.
|
| 477 |
+
|
| 478 |
+
Args:
|
| 479 |
+
audio_path: Path to the input wav file.
|
| 480 |
+
f0_path: if is not None, save the f0 data to this path.
|
| 481 |
+
verbose: Override instance-level verbose flag for this call.
|
| 482 |
+
|
| 483 |
+
Returns:
|
| 484 |
+
np.ndarray: shape ``[T]``, f0 in Hz (0 for unvoiced).
|
| 485 |
+
"""
|
| 486 |
+
verbose = self.verbose if verbose is None else verbose
|
| 487 |
+
if verbose:
|
| 488 |
+
print(f"[f0 extraction] process: start: {audio_path}")
|
| 489 |
+
t0 = time.time()
|
| 490 |
+
|
| 491 |
+
audio, _ = librosa.load(audio_path, sr=self.input_sr)
|
| 492 |
+
f0_16k = self.model.infer_from_audio(audio, thred=self.thred)
|
| 493 |
+
f0 = self.interpolate_f0(
|
| 494 |
+
f0_16k,
|
| 495 |
+
original_length=audio.shape[-1],
|
| 496 |
+
original_sr=self.input_sr,
|
| 497 |
+
target_sr=self.target_sr,
|
| 498 |
+
hop_size=self.hop_size,
|
| 499 |
+
max_duration=self.max_duration,
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
if verbose:
|
| 503 |
+
dt = time.time() - t0
|
| 504 |
+
voiced_ratio = float(np.mean(f0 > 0)) if len(f0) else 0.0
|
| 505 |
+
print(
|
| 506 |
+
"[f0 extraction] process: done:",
|
| 507 |
+
f"frames={len(f0)}",
|
| 508 |
+
f"voiced_ratio={voiced_ratio:.3f}",
|
| 509 |
+
f"time={dt:.3f}s",
|
| 510 |
+
)
|
| 511 |
+
if f0_path is not None:
|
| 512 |
+
np.save(f0_path, f0)
|
| 513 |
+
|
| 514 |
+
return f0
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
if __name__ == "__main__":
|
| 518 |
+
model_path = (
|
| 519 |
+
"pretrained_models/rmvpe/rmvpe.pt"
|
| 520 |
+
)
|
| 521 |
+
audio_path = "./outputs/transcription/test.wav"
|
| 522 |
+
|
| 523 |
+
pe = F0Extractor(
|
| 524 |
+
model_path,
|
| 525 |
+
device="cuda",
|
| 526 |
+
)
|
| 527 |
+
f0 = pe.process(audio_path)
|
preprocess/tools/g2p.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
import ToJyutping
|
| 4 |
+
from g2pM import G2pM
|
| 5 |
+
from g2p_en import G2p as G2pE
|
| 6 |
+
|
| 7 |
+
_EN_WORD_RE = re.compile(r"^[A-Za-z]+(?:'[A-Za-z]+)*$")
|
| 8 |
+
_ZH_WORD_RE = re.compile(r"[\u4e00-\u9fff]")
|
| 9 |
+
|
| 10 |
+
EN_FLAG = "en_"
|
| 11 |
+
YUE_FLAG = "yue_"
|
| 12 |
+
ZH_FLAG = "zh_"
|
| 13 |
+
|
| 14 |
+
g2p_zh = G2pM()
|
| 15 |
+
g2p_en = G2pE()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def is_chinese_char(word: str) -> bool:
|
| 19 |
+
if len(word) != 1:
|
| 20 |
+
return False
|
| 21 |
+
return bool(_ZH_WORD_RE.fullmatch(word))
|
| 22 |
+
|
| 23 |
+
def is_english_word(word: str) -> bool:
|
| 24 |
+
if not word:
|
| 25 |
+
return False
|
| 26 |
+
return bool(_EN_WORD_RE.fullmatch(word))
|
| 27 |
+
|
| 28 |
+
def g2p_cantonese(sent):
|
| 29 |
+
return ToJyutping.get_jyutping_list(sent) # with tone
|
| 30 |
+
|
| 31 |
+
def g2p_mandarin(sent):
|
| 32 |
+
return g2p_zh(sent, tone=True, char_split=False)
|
| 33 |
+
|
| 34 |
+
def g2p_english(word):
|
| 35 |
+
return g2p_en(word)
|
| 36 |
+
|
| 37 |
+
def g2p_transform(words, lang):
|
| 38 |
+
|
| 39 |
+
zh_words = []
|
| 40 |
+
transformed_words = [0] * len(words)
|
| 41 |
+
|
| 42 |
+
for idx, w in enumerate(words):
|
| 43 |
+
if w == "<SP>":
|
| 44 |
+
transformed_words[idx] = w
|
| 45 |
+
continue
|
| 46 |
+
|
| 47 |
+
w = w.replace("?", "").replace(".", "").replace("!", "").replace(",", "")
|
| 48 |
+
|
| 49 |
+
if is_chinese_char(w):
|
| 50 |
+
zh_words.append([idx, w])
|
| 51 |
+
else:
|
| 52 |
+
if is_english_word(w):
|
| 53 |
+
w = EN_FLAG + "-".join(g2p_english(w.lower()))
|
| 54 |
+
else:
|
| 55 |
+
w = "<SP>"
|
| 56 |
+
transformed_words[idx] = w
|
| 57 |
+
|
| 58 |
+
sent = "".join([k[1] for k in zh_words])
|
| 59 |
+
|
| 60 |
+
# zh (zh and yue) transformer to g2p
|
| 61 |
+
if len(sent) > 0:
|
| 62 |
+
if lang == "Cantonese":
|
| 63 |
+
g2pm_rst = g2p_cantonese(sent) # with tone
|
| 64 |
+
g2pm_rst = [YUE_FLAG + k[1] for k in g2pm_rst]
|
| 65 |
+
else:
|
| 66 |
+
g2pm_rst = g2p_mandarin(sent)
|
| 67 |
+
g2pm_rst = [ZH_FLAG + k for k in g2pm_rst]
|
| 68 |
+
for p, w in zip([k[0] for k in zh_words], g2pm_rst):
|
| 69 |
+
transformed_words[p] = w
|
| 70 |
+
|
| 71 |
+
return transformed_words
|
| 72 |
+
|
preprocess/tools/lyric_transcription.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://modelscope.cn/models/iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary
|
| 2 |
+
# https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import time
|
| 6 |
+
from typing import Any, Dict, List, Tuple
|
| 7 |
+
|
| 8 |
+
import librosa
|
| 9 |
+
import numpy as np
|
| 10 |
+
from funasr import AutoModel
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _build_words_with_gaps(raw_words, raw_timestamps, wav_fn: str):
|
| 14 |
+
words, word_durs = [], []
|
| 15 |
+
prev = 0.0
|
| 16 |
+
for w, t in zip(raw_words, raw_timestamps):
|
| 17 |
+
s, e = float(t[0]), float(t[1])
|
| 18 |
+
if s > prev:
|
| 19 |
+
words.append("<SP>")
|
| 20 |
+
word_durs.append(s - prev)
|
| 21 |
+
words.append(w)
|
| 22 |
+
word_durs.append(e - s)
|
| 23 |
+
prev = e
|
| 24 |
+
|
| 25 |
+
wav_len = librosa.get_duration(filename=wav_fn)
|
| 26 |
+
if wav_len > prev:
|
| 27 |
+
if len(words) == 0:
|
| 28 |
+
words.append("<SP>")
|
| 29 |
+
word_durs.append(wav_len)
|
| 30 |
+
return words, word_durs
|
| 31 |
+
if words[-1] != "<SP>":
|
| 32 |
+
words.append("<SP>")
|
| 33 |
+
word_durs.append(wav_len - prev)
|
| 34 |
+
else:
|
| 35 |
+
word_durs[-1] += wav_len - prev
|
| 36 |
+
|
| 37 |
+
return words, word_durs
|
| 38 |
+
|
| 39 |
+
def _word_dur_post_process(words, word_durs, f0):
|
| 40 |
+
"""Post-process word durations using f0 to better place silences.
|
| 41 |
+
"""
|
| 42 |
+
# f0 time grid parameters
|
| 43 |
+
sr = 24000 # f0 sample rate
|
| 44 |
+
hop_length = 480 # f0 hop length
|
| 45 |
+
|
| 46 |
+
# Convert word durations (seconds) to frame boundaries on the f0 grid.
|
| 47 |
+
boundaries = np.cumsum([
|
| 48 |
+
0,
|
| 49 |
+
*[
|
| 50 |
+
int(dur * sr / hop_length)
|
| 51 |
+
for dur in word_durs
|
| 52 |
+
],
|
| 53 |
+
]).tolist()
|
| 54 |
+
|
| 55 |
+
sil_tolerance = 5 # tolerance frames for silence detection
|
| 56 |
+
ext_tolerance = 5 # tolerance frames for vocal extension
|
| 57 |
+
|
| 58 |
+
new_words: list[str] = []
|
| 59 |
+
new_word_durs: list[float] = []
|
| 60 |
+
if words:
|
| 61 |
+
new_words.append(words[0])
|
| 62 |
+
new_word_durs.append(word_durs[0])
|
| 63 |
+
|
| 64 |
+
for i in range(1, len(words)):
|
| 65 |
+
word = words[i]
|
| 66 |
+
if word == "<SP>":
|
| 67 |
+
start_frame = boundaries[i]
|
| 68 |
+
end_frame = boundaries[i + 1]
|
| 69 |
+
|
| 70 |
+
num_frames = end_frame - start_frame
|
| 71 |
+
frame_idx = start_frame
|
| 72 |
+
|
| 73 |
+
# Find first region with at least 5 consecutive "unvoiced" frames.
|
| 74 |
+
unvoiced_count = 0
|
| 75 |
+
while frame_idx < end_frame:
|
| 76 |
+
if f0[frame_idx] <= 1: # unvoiced
|
| 77 |
+
unvoiced_count += 1
|
| 78 |
+
if unvoiced_count >= sil_tolerance:
|
| 79 |
+
frame_idx -= sil_tolerance - 1 # back to the last voiced frame
|
| 80 |
+
break
|
| 81 |
+
else:
|
| 82 |
+
unvoiced_count = 0
|
| 83 |
+
frame_idx += 1
|
| 84 |
+
|
| 85 |
+
voice_frames = frame_idx - start_frame
|
| 86 |
+
|
| 87 |
+
if voice_frames >= int(num_frames * 0.9): # over 90% voiced
|
| 88 |
+
# Treat the whole "<SP>" as silence and merge into previous word.
|
| 89 |
+
new_word_durs[-1] += word_durs[i]
|
| 90 |
+
elif voice_frames >= ext_tolerance: # over 5 frames voiced
|
| 91 |
+
# Split the "<SP>" into two parts: leading silence and tail kept as "<SP>".
|
| 92 |
+
dur = voice_frames * hop_length / sr
|
| 93 |
+
new_word_durs[-1] += dur
|
| 94 |
+
new_words.append("<SP>")
|
| 95 |
+
new_word_durs.append(word_durs[i] - dur)
|
| 96 |
+
else:
|
| 97 |
+
# Too short to adjust, keep as-is.
|
| 98 |
+
new_words.append(word)
|
| 99 |
+
new_word_durs.append(word_durs[i])
|
| 100 |
+
else:
|
| 101 |
+
new_words.append(word)
|
| 102 |
+
new_word_durs.append(word_durs[i])
|
| 103 |
+
|
| 104 |
+
return new_words, new_word_durs
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class _ASRZhModel:
|
| 108 |
+
"""Mandarin/Cantonese ASR wrapper."""
|
| 109 |
+
|
| 110 |
+
def __init__(self, model_path: str, device: str):
|
| 111 |
+
self.model = AutoModel(
|
| 112 |
+
model=model_path,
|
| 113 |
+
disable_update=True,
|
| 114 |
+
device=device,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
def process(self, wav_fn):
|
| 118 |
+
out = self.model.generate(wav_fn, output_timestamp=True)[0]
|
| 119 |
+
raw_words = out["text"].replace("@", "").split(" ")
|
| 120 |
+
raw_timestamps = [[t[0] / 1000, t[1] / 1000] for t in out["timestamp"]]
|
| 121 |
+
words, word_durs = _build_words_with_gaps(raw_words, raw_timestamps, wav_fn)
|
| 122 |
+
|
| 123 |
+
if os.path.exists(wav_fn.replace(".wav", "_f0.npy")):
|
| 124 |
+
words, word_durs = _word_dur_post_process(
|
| 125 |
+
words, word_durs, np.load(wav_fn.replace(".wav", "_f0.npy"))
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
return words, word_durs
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class _ASREnModel:
|
| 132 |
+
"""English ASR wrapper for NeMo Parakeet-TDT."""
|
| 133 |
+
|
| 134 |
+
def __init__(self, model_path: str, device: str):
|
| 135 |
+
try:
|
| 136 |
+
import nemo.collections.asr as nemo_asr # type: ignore
|
| 137 |
+
except Exception as e: # pragma: no cover
|
| 138 |
+
raise ImportError(
|
| 139 |
+
"NeMo (nemo_toolkit) is required for ASR English but is not available in this Python env. "
|
| 140 |
+
"Install it in the active environment, then retry."
|
| 141 |
+
) from e
|
| 142 |
+
|
| 143 |
+
self.model = nemo_asr.models.ASRModel.restore_from(
|
| 144 |
+
restore_path=model_path,
|
| 145 |
+
map_location=device,
|
| 146 |
+
)
|
| 147 |
+
self.model.eval()
|
| 148 |
+
|
| 149 |
+
@staticmethod
|
| 150 |
+
def _clean_word(word: str) -> str:
|
| 151 |
+
return re.sub(r"[\?\.,:]", "", word).strip()
|
| 152 |
+
|
| 153 |
+
@staticmethod
|
| 154 |
+
def _extract_word_segments(output: Any) -> List[Dict[str, Any]]:
|
| 155 |
+
ts = getattr(output, "timestamp", None)
|
| 156 |
+
if not ts or not isinstance(ts, dict):
|
| 157 |
+
return []
|
| 158 |
+
word_ts = ts.get("word")
|
| 159 |
+
return word_ts if isinstance(word_ts, list) else []
|
| 160 |
+
|
| 161 |
+
def process(self, wav_fn: str) -> Tuple[List[str], List[float]]:
|
| 162 |
+
outputs = self.model.transcribe(
|
| 163 |
+
[wav_fn],
|
| 164 |
+
timestamps=True,
|
| 165 |
+
batch_size=1,
|
| 166 |
+
num_workers=0,
|
| 167 |
+
)
|
| 168 |
+
output = outputs[0] if outputs else None
|
| 169 |
+
|
| 170 |
+
raw_words: List[str] = []
|
| 171 |
+
raw_timestamps: List[List[float]] = []
|
| 172 |
+
if output is not None:
|
| 173 |
+
for w in self._extract_word_segments(output):
|
| 174 |
+
s, e = float(w.get("start", 0.0)), float(w.get("end", 0.0))
|
| 175 |
+
word = self._clean_word(str(w.get("word", "")))
|
| 176 |
+
if word:
|
| 177 |
+
raw_words.append(word)
|
| 178 |
+
raw_timestamps.append([s, e])
|
| 179 |
+
|
| 180 |
+
words, durs = _build_words_with_gaps(raw_words, raw_timestamps, wav_fn)
|
| 181 |
+
|
| 182 |
+
if os.path.exists(wav_fn.replace(".wav", "_f0.npy")):
|
| 183 |
+
words, durs = _word_dur_post_process(
|
| 184 |
+
words, durs, np.load(wav_fn.replace(".wav", "_f0.npy"))
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
return words, durs
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class LyricTranscriber:
|
| 191 |
+
"""Transcribe lyrics from singing voice segment
|
| 192 |
+
"""
|
| 193 |
+
|
| 194 |
+
def __init__(
|
| 195 |
+
self,
|
| 196 |
+
zh_model_path: str,
|
| 197 |
+
en_model_path: str,
|
| 198 |
+
device: str = "cuda",
|
| 199 |
+
*,
|
| 200 |
+
verbose: bool = True,
|
| 201 |
+
):
|
| 202 |
+
"""Initialize lyric transcriber.
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
zh_model_path (str): Path to the Chinese model file.
|
| 206 |
+
en_model_path (str): Path to the English model file.
|
| 207 |
+
device (str): Device to use for tensor operations.
|
| 208 |
+
verbose (bool): Whether to print verbose logs.
|
| 209 |
+
"""
|
| 210 |
+
self.verbose = verbose
|
| 211 |
+
self.device = device
|
| 212 |
+
self.zh_model_path = zh_model_path
|
| 213 |
+
self.en_model_path = en_model_path
|
| 214 |
+
|
| 215 |
+
if self.verbose:
|
| 216 |
+
print(
|
| 217 |
+
"[lyric transcription] init: start:",
|
| 218 |
+
f"device={device}",
|
| 219 |
+
f"model_path={zh_model_path}",
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
# Always initialize Chinese ASR.
|
| 223 |
+
self.zh_model = _ASRZhModel(device=device, model_path=zh_model_path)
|
| 224 |
+
|
| 225 |
+
# English ASR will be lazily initialized on first English request to avoid long waiting cost when importing NeMo
|
| 226 |
+
self.en_model = None
|
| 227 |
+
|
| 228 |
+
if self.verbose:
|
| 229 |
+
print("[lyric transcription] init: success")
|
| 230 |
+
|
| 231 |
+
def process(self, wav_fn, language: str | None = "Mandarin", *, verbose: bool | None = None):
|
| 232 |
+
""" Lyric transcriber process
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
wav_fn (str): Path to the audio file.
|
| 236 |
+
language (str | None): Language of the audio. Defaults to "Mandarin". Supports "Mandarin", "Cantonese" and "English".
|
| 237 |
+
verbose (bool | None): Whether to print verbose logs. Defaults to None.
|
| 238 |
+
"""
|
| 239 |
+
v = self.verbose if verbose is None else verbose
|
| 240 |
+
if language not in {"Mandarin", "Cantonese", "English"}:
|
| 241 |
+
raise ValueError(f"Unsupported language: {language}, should be one of ['Mandarin', 'Cantonese', 'English']")
|
| 242 |
+
if v:
|
| 243 |
+
print(f"[lyric transcription] process: start: wav_fn={wav_fn} language={language}")
|
| 244 |
+
t0 = time.time()
|
| 245 |
+
|
| 246 |
+
lang = (language or "auto").lower()
|
| 247 |
+
if lang in {"english"}:
|
| 248 |
+
if self.en_model is None:
|
| 249 |
+
# Lazy-load NeMo model only when English is actually used.
|
| 250 |
+
if v:
|
| 251 |
+
print("[lyric transcription] init English ASR, please make sure NeMo is installed")
|
| 252 |
+
self.en_model = _ASREnModel(model_path=self.en_model_path, device=self.device)
|
| 253 |
+
out = self.en_model.process(wav_fn)
|
| 254 |
+
else:
|
| 255 |
+
out = self.zh_model.process(wav_fn)
|
| 256 |
+
|
| 257 |
+
if v:
|
| 258 |
+
words, durs = out
|
| 259 |
+
n_words = len(words) if isinstance(words, list) else 0
|
| 260 |
+
dur_sum = float(sum(durs)) if isinstance(durs, list) else 0.0
|
| 261 |
+
dt = time.time() - t0
|
| 262 |
+
print(
|
| 263 |
+
"[lyric transcription] process: done:",
|
| 264 |
+
f"n_words={n_words}",
|
| 265 |
+
f"dur_sum={dur_sum:.3f}s",
|
| 266 |
+
f"time={dt:.3f}s",
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
return out
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
if __name__ == "__main__":
|
| 273 |
+
m = LyricTranscriber(
|
| 274 |
+
zh_model_path="pretrained_models/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
| 275 |
+
en_model_path="pretrained_models/parakeet-tdt-0.6b-v2/parakeet-tdt-0.6b-v2.nemo",
|
| 276 |
+
device="cuda"
|
| 277 |
+
)
|
| 278 |
+
print(m.process("example/test/asr_zh.wav", language="Mandarin"))
|
| 279 |
+
print(m.process("example/test/asr_en.wav", language="English"))
|
preprocess/tools/midi_parser.py
ADDED
|
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SoulX-Singer MIDI <-> metadata converter.
|
| 3 |
+
|
| 4 |
+
Converts between SoulX-Singer-style metadata JSON (with note_text, note_dur,
|
| 5 |
+
note_pitch, note_type per segment) and standard MIDI files. Uses an internal
|
| 6 |
+
Note dataclass (start_s, note_dur, note_text, note_pitch, note_type) as the
|
| 7 |
+
intermediate representation.
|
| 8 |
+
"""
|
| 9 |
+
import os
|
| 10 |
+
import json
|
| 11 |
+
import shutil
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from typing import Any, List, Tuple, Union
|
| 14 |
+
|
| 15 |
+
import librosa
|
| 16 |
+
import mido
|
| 17 |
+
from soundfile import write
|
| 18 |
+
|
| 19 |
+
from .f0_extraction import F0Extractor
|
| 20 |
+
from .g2p import g2p_transform
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Audio and segmenting constants (used by _edit_data_to_meta)
|
| 24 |
+
SAMPLE_RATE = 44100
|
| 25 |
+
DEFAULT_LANGUAGE = "Mandarin"
|
| 26 |
+
MAX_GAP_SEC = 5.0 # gap (sec) above which we start a new segment
|
| 27 |
+
MAX_SEGMENT_DUR_SUM_SEC = 60.0 # max cumulative note duration per segment (sec)
|
| 28 |
+
MIN_GAP_THRESHOLD_SEC = 0.001 # ignore gaps smaller than this
|
| 29 |
+
LONG_SILENCE_THRESHOLD_SEC = 0.05 # treat as separate <SP> if gap larger
|
| 30 |
+
MAX_LEADING_SP_DUR_SEC = 2.0 # cap leading silence in a segment to this (sec)
|
| 31 |
+
DEFAULT_RMVPE_MODEL_PATH = "pretrained_models/SoulX-Singer-Preprocess/rmvpe/rmvpe.pt"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
|
| 35 |
+
class Note:
|
| 36 |
+
"""Single note: text, duration (seconds), pitch (MIDI), type. start_s is absolute start time in seconds (for ordering / MIDI)."""
|
| 37 |
+
start_s: float
|
| 38 |
+
note_dur: float
|
| 39 |
+
note_text: str
|
| 40 |
+
note_pitch: int
|
| 41 |
+
note_type: int
|
| 42 |
+
|
| 43 |
+
@property
|
| 44 |
+
def end_s(self) -> float:
|
| 45 |
+
return self.start_s + self.note_dur
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def remove_duplicate_segments(meta_data: List[dict]) -> None:
|
| 50 |
+
"""Merge consecutive identical notes (same text, pitch, type) within each segment. Mutates meta_data in place."""
|
| 51 |
+
for idx, segment in enumerate(meta_data):
|
| 52 |
+
texts = segment["note_text"]
|
| 53 |
+
durs = segment["note_dur"]
|
| 54 |
+
pitches = segment["note_pitch"]
|
| 55 |
+
types = segment["note_type"]
|
| 56 |
+
new_texts = []
|
| 57 |
+
new_durs = []
|
| 58 |
+
new_pitches = []
|
| 59 |
+
new_types = []
|
| 60 |
+
for i in range(len(texts)):
|
| 61 |
+
if i == 0:
|
| 62 |
+
new_texts.append(texts[i])
|
| 63 |
+
new_durs.append(durs[i])
|
| 64 |
+
new_pitches.append(pitches[i])
|
| 65 |
+
new_types.append(types[i])
|
| 66 |
+
continue
|
| 67 |
+
t, d, p, ty = texts[i], durs[i], pitches[i], types[i]
|
| 68 |
+
if t == "<SP>" and texts[i - 1] == "<SP>":
|
| 69 |
+
new_durs[-1] += d
|
| 70 |
+
continue
|
| 71 |
+
if t == texts[i - 1] and p == pitches[i - 1] and ty == types[i - 1]:
|
| 72 |
+
new_durs[-1] += d
|
| 73 |
+
else:
|
| 74 |
+
new_texts.append(t)
|
| 75 |
+
new_durs.append(d)
|
| 76 |
+
new_pitches.append(p)
|
| 77 |
+
new_types.append(ty)
|
| 78 |
+
meta_data[idx]["note_text"] = new_texts
|
| 79 |
+
meta_data[idx]["note_dur"] = new_durs
|
| 80 |
+
meta_data[idx]["note_pitch"] = new_pitches
|
| 81 |
+
meta_data[idx]["note_type"] = new_types
|
| 82 |
+
|
| 83 |
+
def meta2notes(meta_path: str) -> List[Note]:
|
| 84 |
+
"""Parse SoulX-Singer metadata JSON into a flat list of Note (absolute start_s)."""
|
| 85 |
+
with open(meta_path, "r", encoding="utf-8") as f:
|
| 86 |
+
segments = json.load(f)
|
| 87 |
+
if not isinstance(segments, list):
|
| 88 |
+
raise ValueError(f"Metadata must be a list of segments, got {type(segments).__name__}")
|
| 89 |
+
if not segments:
|
| 90 |
+
raise ValueError("Metadata has no segments.")
|
| 91 |
+
|
| 92 |
+
notes: List[Note] = []
|
| 93 |
+
for seg in segments:
|
| 94 |
+
offset_s = seg["time"][0] / 1000
|
| 95 |
+
words = [str(x).replace("<AP>", "<SP>") for i, x in enumerate(seg["text"].split())]
|
| 96 |
+
word_durs = [float(x) for x in seg["duration"].split()]
|
| 97 |
+
pitches = [int(x) for x in seg["note_pitch"].split()]
|
| 98 |
+
types = [int(x) if words[i] != "<SP>" else 1 for i, x in enumerate(seg["note_type"].split())]
|
| 99 |
+
if len(words) != len(word_durs) or len(word_durs) != len(pitches) or len(pitches) != len(types):
|
| 100 |
+
raise ValueError(
|
| 101 |
+
f"Length mismatch in segment {seg.get('item_name', '?')}: "
|
| 102 |
+
"note_text, note_dur, note_pitch, note_type must have same length"
|
| 103 |
+
)
|
| 104 |
+
current_s = offset_s
|
| 105 |
+
for text, dur, pitch, type_ in zip(words, word_durs, pitches, types):
|
| 106 |
+
notes.append(
|
| 107 |
+
Note(
|
| 108 |
+
start_s=current_s,
|
| 109 |
+
note_dur=float(dur),
|
| 110 |
+
note_text=str(text),
|
| 111 |
+
note_pitch=int(pitch),
|
| 112 |
+
note_type=int(type_),
|
| 113 |
+
)
|
| 114 |
+
)
|
| 115 |
+
current_s += float(dur)
|
| 116 |
+
return notes
|
| 117 |
+
|
| 118 |
+
def _append_segment_to_meta(
|
| 119 |
+
meta_path_str: str,
|
| 120 |
+
cut_wavs_output_dir: str,
|
| 121 |
+
vocal_file: str,
|
| 122 |
+
audio_data: Any,
|
| 123 |
+
meta_data: List[dict],
|
| 124 |
+
note_start: List[float],
|
| 125 |
+
note_end: List[float],
|
| 126 |
+
note_text: List[Any],
|
| 127 |
+
note_pitch: List[Any],
|
| 128 |
+
note_type: List[Any],
|
| 129 |
+
note_dur: List[float],
|
| 130 |
+
end_time_ms_override: float | None = None,
|
| 131 |
+
) -> None:
|
| 132 |
+
"""Write one segment wav and append one segment dict to meta_data. Caller clears note_* lists after."""
|
| 133 |
+
base_name = os.path.splitext(os.path.basename(meta_path_str))[0]
|
| 134 |
+
item_name = f"{base_name}_{len(meta_data)}"
|
| 135 |
+
wav_fn = os.path.join(cut_wavs_output_dir, f"{item_name}.wav")
|
| 136 |
+
start_ms = int(note_start[0] * 1000)
|
| 137 |
+
end_ms = (
|
| 138 |
+
int(end_time_ms_override)
|
| 139 |
+
if end_time_ms_override is not None
|
| 140 |
+
else int(note_end[-1] * 1000)
|
| 141 |
+
)
|
| 142 |
+
start_sample = int(note_start[0] * SAMPLE_RATE)
|
| 143 |
+
end_sample = int(note_end[-1] * SAMPLE_RATE)
|
| 144 |
+
write(wav_fn, audio_data[start_sample:end_sample], SAMPLE_RATE)
|
| 145 |
+
meta_data.append({
|
| 146 |
+
"item_name": item_name,
|
| 147 |
+
"wav_fn": wav_fn,
|
| 148 |
+
"origin_wav_fn": vocal_file,
|
| 149 |
+
"start_time_ms": start_ms,
|
| 150 |
+
"end_time_ms": end_ms,
|
| 151 |
+
"language": DEFAULT_LANGUAGE,
|
| 152 |
+
"note_text": list(note_text),
|
| 153 |
+
"note_pitch": list(note_pitch),
|
| 154 |
+
"note_type": list(note_type),
|
| 155 |
+
"note_dur": list(note_dur),
|
| 156 |
+
})
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def convert_meta(meta_data: List[dict], rmvpe_model_path, device="cuda"):
|
| 160 |
+
pitch_extractor = F0Extractor(rmvpe_model_path, device=device, verbose=False)
|
| 161 |
+
converted_data = []
|
| 162 |
+
|
| 163 |
+
for item in meta_data:
|
| 164 |
+
wav_fn = item.get("wav_fn")
|
| 165 |
+
if not wav_fn or not os.path.isfile(wav_fn):
|
| 166 |
+
raise FileNotFoundError(f"Segment wav file not found: {wav_fn}")
|
| 167 |
+
f0 = pitch_extractor.process(wav_fn)
|
| 168 |
+
converted_item = {
|
| 169 |
+
"index": item.get("item_name"),
|
| 170 |
+
"language": item.get("language"),
|
| 171 |
+
"time": [item.get("start_time_ms", 0), item.get("end_time_ms", sum(item["note_dur"]) * 1000)],
|
| 172 |
+
"duration": " ".join(str(round(x, 2)) for x in item.get("note_dur", [])),
|
| 173 |
+
"text": " ".join(item.get("note_text", [])),
|
| 174 |
+
"phoneme": " ".join(g2p_transform(item.get("note_text", []), DEFAULT_LANGUAGE)),
|
| 175 |
+
"note_pitch": " ".join(str(x) for x in item.get("note_pitch", [])),
|
| 176 |
+
"note_type": " ".join(str(x) for x in item.get("note_type", [])),
|
| 177 |
+
"f0": " ".join(str(round(float(x), 1)) for x in f0),
|
| 178 |
+
}
|
| 179 |
+
converted_data.append(converted_item)
|
| 180 |
+
|
| 181 |
+
return converted_data
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def _edit_data_to_meta(
|
| 185 |
+
meta_path_str: str,
|
| 186 |
+
edit_data: List[dict],
|
| 187 |
+
vocal_file: str,
|
| 188 |
+
rmvpe_model_path: str | None = None,
|
| 189 |
+
device: str = "cuda",
|
| 190 |
+
) -> None:
|
| 191 |
+
"""Write SoulX-Singer metadata JSON from edit_data (list of {start, end, note_text, note_pitch, note_type})."""
|
| 192 |
+
# Use a fixed temporary directory for cut wavs
|
| 193 |
+
cut_wavs_output_dir = os.path.join(os.path.dirname(vocal_file), "cut_wavs_tmp")
|
| 194 |
+
os.makedirs(cut_wavs_output_dir, exist_ok=True)
|
| 195 |
+
|
| 196 |
+
note_text: List[Any] = []
|
| 197 |
+
note_pitch: List[Any] = []
|
| 198 |
+
note_type: List[Any] = []
|
| 199 |
+
note_dur: List[float] = []
|
| 200 |
+
note_start: List[float] = []
|
| 201 |
+
note_end: List[float] = []
|
| 202 |
+
prev_end = 0.0
|
| 203 |
+
meta_data: List[dict] = []
|
| 204 |
+
audio_data, _ = librosa.load(vocal_file, sr=SAMPLE_RATE, mono=True)
|
| 205 |
+
dur_sum = 0.0
|
| 206 |
+
|
| 207 |
+
for entry in edit_data:
|
| 208 |
+
start = float(entry["start"])
|
| 209 |
+
end = float(entry["end"])
|
| 210 |
+
text = entry["note_text"]
|
| 211 |
+
pitch = entry["note_pitch"]
|
| 212 |
+
type_ = entry["note_type"]
|
| 213 |
+
|
| 214 |
+
if text == "" or pitch == "" or type_ == "":
|
| 215 |
+
note_text.append("<SP>")
|
| 216 |
+
note_pitch.append(0)
|
| 217 |
+
note_type.append(1)
|
| 218 |
+
note_dur.append(end - start)
|
| 219 |
+
note_start.append(start)
|
| 220 |
+
note_end.append(end)
|
| 221 |
+
prev_end = end
|
| 222 |
+
dur_sum += end - start
|
| 223 |
+
continue
|
| 224 |
+
|
| 225 |
+
if (
|
| 226 |
+
len(note_text) > 0
|
| 227 |
+
and note_text[-1] == "<SP>"
|
| 228 |
+
and note_dur[-1] > MAX_LEADING_SP_DUR_SEC
|
| 229 |
+
):
|
| 230 |
+
cut_time = note_dur[-1] - MAX_LEADING_SP_DUR_SEC
|
| 231 |
+
note_dur[-1] = MAX_LEADING_SP_DUR_SEC
|
| 232 |
+
end_ms_override = note_end[-1] * 1000 - cut_time * 1000
|
| 233 |
+
_append_segment_to_meta(
|
| 234 |
+
meta_path_str,
|
| 235 |
+
cut_wavs_output_dir,
|
| 236 |
+
vocal_file,
|
| 237 |
+
audio_data,
|
| 238 |
+
meta_data,
|
| 239 |
+
note_start,
|
| 240 |
+
note_end,
|
| 241 |
+
note_text,
|
| 242 |
+
note_pitch,
|
| 243 |
+
note_type,
|
| 244 |
+
note_dur,
|
| 245 |
+
end_time_ms_override=end_ms_override,
|
| 246 |
+
)
|
| 247 |
+
note_text = []
|
| 248 |
+
note_pitch = []
|
| 249 |
+
note_type = []
|
| 250 |
+
note_dur = []
|
| 251 |
+
note_start = []
|
| 252 |
+
note_end = []
|
| 253 |
+
prev_end = start
|
| 254 |
+
dur_sum = 0.0
|
| 255 |
+
|
| 256 |
+
gap_from_prev = start - prev_end
|
| 257 |
+
gap_from_last_note = (start - note_end[-1]) if note_end else 0.0
|
| 258 |
+
if (
|
| 259 |
+
gap_from_prev >= MAX_GAP_SEC
|
| 260 |
+
or gap_from_last_note >= MAX_GAP_SEC
|
| 261 |
+
or dur_sum >= MAX_SEGMENT_DUR_SUM_SEC
|
| 262 |
+
):
|
| 263 |
+
if len(note_text) > 0:
|
| 264 |
+
_append_segment_to_meta(
|
| 265 |
+
meta_path_str,
|
| 266 |
+
cut_wavs_output_dir,
|
| 267 |
+
vocal_file,
|
| 268 |
+
audio_data,
|
| 269 |
+
meta_data,
|
| 270 |
+
note_start,
|
| 271 |
+
note_end,
|
| 272 |
+
note_text,
|
| 273 |
+
note_pitch,
|
| 274 |
+
note_type,
|
| 275 |
+
note_dur,
|
| 276 |
+
)
|
| 277 |
+
note_text = []
|
| 278 |
+
note_pitch = []
|
| 279 |
+
note_type = []
|
| 280 |
+
note_dur = []
|
| 281 |
+
note_start = []
|
| 282 |
+
note_end = []
|
| 283 |
+
prev_end = start
|
| 284 |
+
dur_sum = 0.0
|
| 285 |
+
|
| 286 |
+
if start - prev_end > MIN_GAP_THRESHOLD_SEC:
|
| 287 |
+
if start - prev_end > LONG_SILENCE_THRESHOLD_SEC or len(note_text) == 0:
|
| 288 |
+
note_text.append("<SP>")
|
| 289 |
+
note_pitch.append(0)
|
| 290 |
+
note_type.append(1)
|
| 291 |
+
note_dur.append(start - prev_end)
|
| 292 |
+
note_start.append(prev_end)
|
| 293 |
+
note_end.append(start)
|
| 294 |
+
else:
|
| 295 |
+
if len(note_dur) > 0:
|
| 296 |
+
note_dur[-1] += start - prev_end
|
| 297 |
+
note_end[-1] = start
|
| 298 |
+
|
| 299 |
+
prev_end = end
|
| 300 |
+
note_text.append(text)
|
| 301 |
+
note_pitch.append(int(pitch))
|
| 302 |
+
note_type.append(int(type_))
|
| 303 |
+
note_dur.append(end - start)
|
| 304 |
+
note_start.append(start)
|
| 305 |
+
note_end.append(end)
|
| 306 |
+
dur_sum += end - start
|
| 307 |
+
|
| 308 |
+
if len(note_text) > 0:
|
| 309 |
+
_append_segment_to_meta(
|
| 310 |
+
meta_path_str,
|
| 311 |
+
cut_wavs_output_dir,
|
| 312 |
+
vocal_file,
|
| 313 |
+
audio_data,
|
| 314 |
+
meta_data,
|
| 315 |
+
note_start,
|
| 316 |
+
note_end,
|
| 317 |
+
note_text,
|
| 318 |
+
note_pitch,
|
| 319 |
+
note_type,
|
| 320 |
+
note_dur,
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
remove_duplicate_segments(meta_data)
|
| 324 |
+
|
| 325 |
+
_rmvpe_path = rmvpe_model_path or DEFAULT_RMVPE_MODEL_PATH
|
| 326 |
+
converted_data = convert_meta(meta_data, _rmvpe_path, device)
|
| 327 |
+
|
| 328 |
+
with open(meta_path_str, "w", encoding="utf-8") as f:
|
| 329 |
+
json.dump(converted_data, f, ensure_ascii=False, indent=2)
|
| 330 |
+
|
| 331 |
+
# Clean up temporary cut wavs directory
|
| 332 |
+
try:
|
| 333 |
+
shutil.rmtree(cut_wavs_output_dir, ignore_errors=True)
|
| 334 |
+
except Exception:
|
| 335 |
+
pass
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def notes2meta(
|
| 339 |
+
notes: List[Note],
|
| 340 |
+
meta_path: str,
|
| 341 |
+
vocal_file: str,
|
| 342 |
+
rmvpe_model_path: str | None = None,
|
| 343 |
+
device: str = "cuda",
|
| 344 |
+
) -> None:
|
| 345 |
+
"""Write SoulX-Singer metadata JSON from a list of Note (segmenting + wav cuts)."""
|
| 346 |
+
edit_data = [
|
| 347 |
+
{
|
| 348 |
+
"start": n.start_s,
|
| 349 |
+
"end": n.end_s,
|
| 350 |
+
"note_text": n.note_text,
|
| 351 |
+
"note_pitch": str(n.note_pitch),
|
| 352 |
+
"note_type": str(n.note_type),
|
| 353 |
+
}
|
| 354 |
+
for n in notes
|
| 355 |
+
]
|
| 356 |
+
_edit_data_to_meta(
|
| 357 |
+
str(meta_path),
|
| 358 |
+
edit_data,
|
| 359 |
+
vocal_file,
|
| 360 |
+
rmvpe_model_path=rmvpe_model_path,
|
| 361 |
+
device=device,
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
@dataclass(frozen=True)
|
| 366 |
+
class MidiDefaults:
|
| 367 |
+
ticks_per_beat: int = 500
|
| 368 |
+
tempo: int = 500000 # microseconds per beat (120 BPM)
|
| 369 |
+
time_signature: Tuple[int, int] = (4, 4)
|
| 370 |
+
velocity: int = 64
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def _seconds_to_ticks(seconds: float, ticks_per_beat: int, tempo: int) -> int:
|
| 374 |
+
return int(round(seconds * ticks_per_beat * 1_000_000 / tempo))
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def notes2midi(
|
| 378 |
+
notes: List[Note],
|
| 379 |
+
midi_path: str,
|
| 380 |
+
defaults: MidiDefaults | None = None,
|
| 381 |
+
) -> None:
|
| 382 |
+
"""Write MIDI file from a list of Note."""
|
| 383 |
+
defaults = defaults or MidiDefaults()
|
| 384 |
+
if not notes:
|
| 385 |
+
raise ValueError("Empty note list.")
|
| 386 |
+
|
| 387 |
+
events: List[Tuple[int, int, Union[mido.Message, mido.MetaMessage]]] = []
|
| 388 |
+
for n in notes:
|
| 389 |
+
start_s = n.start_s
|
| 390 |
+
end_s = n.end_s
|
| 391 |
+
if end_s <= start_s:
|
| 392 |
+
continue
|
| 393 |
+
|
| 394 |
+
start_ticks = _seconds_to_ticks(
|
| 395 |
+
start_s, defaults.ticks_per_beat, defaults.tempo
|
| 396 |
+
)
|
| 397 |
+
end_ticks = _seconds_to_ticks(
|
| 398 |
+
end_s, defaults.ticks_per_beat, defaults.tempo
|
| 399 |
+
)
|
| 400 |
+
if end_ticks <= start_ticks:
|
| 401 |
+
end_ticks = start_ticks + 1
|
| 402 |
+
|
| 403 |
+
lyric = n.note_text
|
| 404 |
+
try:
|
| 405 |
+
lyric = lyric.encode("utf-8").decode("latin1")
|
| 406 |
+
except (UnicodeEncodeError, UnicodeDecodeError):
|
| 407 |
+
pass
|
| 408 |
+
if n.note_type == 3:
|
| 409 |
+
lyric = "-"
|
| 410 |
+
|
| 411 |
+
events.append(
|
| 412 |
+
(start_ticks, 1, mido.MetaMessage("lyrics", text=lyric, time=0))
|
| 413 |
+
)
|
| 414 |
+
events.append(
|
| 415 |
+
(
|
| 416 |
+
start_ticks,
|
| 417 |
+
2,
|
| 418 |
+
mido.Message(
|
| 419 |
+
"note_on",
|
| 420 |
+
note=n.note_pitch,
|
| 421 |
+
velocity=defaults.velocity,
|
| 422 |
+
time=0,
|
| 423 |
+
),
|
| 424 |
+
)
|
| 425 |
+
)
|
| 426 |
+
events.append(
|
| 427 |
+
(
|
| 428 |
+
end_ticks,
|
| 429 |
+
0,
|
| 430 |
+
mido.Message("note_off", note=n.note_pitch, velocity=0, time=0),
|
| 431 |
+
)
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
events.sort(key=lambda x: (x[0], x[1]))
|
| 435 |
+
|
| 436 |
+
mid = mido.MidiFile(ticks_per_beat=defaults.ticks_per_beat)
|
| 437 |
+
track = mido.MidiTrack()
|
| 438 |
+
mid.tracks.append(track)
|
| 439 |
+
|
| 440 |
+
track.append(mido.MetaMessage("set_tempo", tempo=defaults.tempo, time=0))
|
| 441 |
+
track.append(
|
| 442 |
+
mido.MetaMessage(
|
| 443 |
+
"time_signature",
|
| 444 |
+
numerator=defaults.time_signature[0],
|
| 445 |
+
denominator=defaults.time_signature[1],
|
| 446 |
+
time=0,
|
| 447 |
+
)
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
last_tick = 0
|
| 451 |
+
for tick, _, msg in events:
|
| 452 |
+
msg.time = max(0, tick - last_tick)
|
| 453 |
+
track.append(msg)
|
| 454 |
+
last_tick = tick
|
| 455 |
+
|
| 456 |
+
track.append(mido.MetaMessage("end_of_track", time=0))
|
| 457 |
+
mid.save(midi_path)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def midi2notes(midi_path: str) -> List[Note]:
|
| 461 |
+
"""Parse MIDI file into a list of Note. Merges all tracks; tempo from last set_tempo event."""
|
| 462 |
+
mid = mido.MidiFile(midi_path)
|
| 463 |
+
ticks_per_beat = mid.ticks_per_beat
|
| 464 |
+
tempo = 500000
|
| 465 |
+
|
| 466 |
+
raw_notes: List[dict] = []
|
| 467 |
+
lyrics: List[Tuple[int, str]] = []
|
| 468 |
+
|
| 469 |
+
for track in mid.tracks:
|
| 470 |
+
abs_ticks = 0
|
| 471 |
+
active = {}
|
| 472 |
+
for msg in track:
|
| 473 |
+
abs_ticks += msg.time
|
| 474 |
+
if msg.type == "set_tempo":
|
| 475 |
+
tempo = msg.tempo
|
| 476 |
+
elif msg.type == "lyrics":
|
| 477 |
+
text = msg.text
|
| 478 |
+
try:
|
| 479 |
+
text = text.encode("latin1").decode("utf-8")
|
| 480 |
+
except Exception:
|
| 481 |
+
pass
|
| 482 |
+
lyrics.append((abs_ticks, text))
|
| 483 |
+
elif msg.type == "note_on":
|
| 484 |
+
key = (msg.channel, msg.note)
|
| 485 |
+
if msg.velocity > 0:
|
| 486 |
+
active[key] = (abs_ticks, msg.velocity)
|
| 487 |
+
else:
|
| 488 |
+
if key in active:
|
| 489 |
+
start_ticks, vel = active.pop(key)
|
| 490 |
+
raw_notes.append(
|
| 491 |
+
{
|
| 492 |
+
"midi": msg.note,
|
| 493 |
+
"start_ticks": start_ticks,
|
| 494 |
+
"duration_ticks": abs_ticks - start_ticks,
|
| 495 |
+
"velocity": vel,
|
| 496 |
+
"lyric": "",
|
| 497 |
+
}
|
| 498 |
+
)
|
| 499 |
+
elif msg.type == "note_off":
|
| 500 |
+
key = (msg.channel, msg.note)
|
| 501 |
+
if key in active:
|
| 502 |
+
start_ticks, vel = active.pop(key)
|
| 503 |
+
raw_notes.append(
|
| 504 |
+
{
|
| 505 |
+
"midi": msg.note,
|
| 506 |
+
"start_ticks": start_ticks,
|
| 507 |
+
"duration_ticks": abs_ticks - start_ticks,
|
| 508 |
+
"velocity": vel,
|
| 509 |
+
"lyric": "",
|
| 510 |
+
}
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
if not raw_notes:
|
| 514 |
+
raise ValueError("No notes found in MIDI file")
|
| 515 |
+
|
| 516 |
+
for n in raw_notes:
|
| 517 |
+
n["end_ticks"] = n["start_ticks"] + n["duration_ticks"]
|
| 518 |
+
|
| 519 |
+
raw_notes.sort(key=lambda n: n["start_ticks"])
|
| 520 |
+
lyrics.sort(key=lambda x: x[0])
|
| 521 |
+
|
| 522 |
+
trimmed = []
|
| 523 |
+
for note in raw_notes:
|
| 524 |
+
while trimmed:
|
| 525 |
+
prev = trimmed[-1]
|
| 526 |
+
if note["start_ticks"] < prev["end_ticks"]:
|
| 527 |
+
prev["end_ticks"] = note["start_ticks"]
|
| 528 |
+
prev["duration_ticks"] = prev["end_ticks"] - prev["start_ticks"]
|
| 529 |
+
if prev["duration_ticks"] <= 0:
|
| 530 |
+
trimmed.pop()
|
| 531 |
+
continue
|
| 532 |
+
break
|
| 533 |
+
trimmed.append(note)
|
| 534 |
+
raw_notes = trimmed
|
| 535 |
+
|
| 536 |
+
tolerance = ticks_per_beat // 100
|
| 537 |
+
lyric_idx = 0
|
| 538 |
+
for note in raw_notes:
|
| 539 |
+
while lyric_idx < len(lyrics) and lyrics[lyric_idx][0] < note["start_ticks"] - tolerance:
|
| 540 |
+
lyric_idx += 1
|
| 541 |
+
if lyric_idx < len(lyrics):
|
| 542 |
+
lyric_ticks, lyric_text = lyrics[lyric_idx]
|
| 543 |
+
if abs(lyric_ticks - note["start_ticks"]) <= tolerance:
|
| 544 |
+
note["lyric"] = lyric_text
|
| 545 |
+
lyric_idx += 1
|
| 546 |
+
|
| 547 |
+
def ticks_to_seconds(ticks: int) -> float:
|
| 548 |
+
return (ticks / ticks_per_beat) * (tempo / 1_000_000)
|
| 549 |
+
|
| 550 |
+
result: List[Note] = []
|
| 551 |
+
prev_end_s = 0.0
|
| 552 |
+
for idx, n in enumerate(raw_notes):
|
| 553 |
+
start_s = ticks_to_seconds(n["start_ticks"])
|
| 554 |
+
end_s = ticks_to_seconds(n["end_ticks"])
|
| 555 |
+
if prev_end_s > start_s:
|
| 556 |
+
start_s = prev_end_s
|
| 557 |
+
dur_s = end_s - start_s
|
| 558 |
+
if dur_s <= 0:
|
| 559 |
+
continue
|
| 560 |
+
|
| 561 |
+
lyric = n.get("lyric", "")
|
| 562 |
+
if not lyric:
|
| 563 |
+
tp = 2
|
| 564 |
+
text = "啦"
|
| 565 |
+
elif lyric == "<SP>":
|
| 566 |
+
tp = 1
|
| 567 |
+
text = "<SP>"
|
| 568 |
+
elif lyric == "-":
|
| 569 |
+
tp = 3
|
| 570 |
+
text = raw_notes[idx - 1].get("lyric", "-") if idx > 0 else "-"
|
| 571 |
+
else:
|
| 572 |
+
tp = 2
|
| 573 |
+
text = lyric
|
| 574 |
+
|
| 575 |
+
result.append(
|
| 576 |
+
Note(
|
| 577 |
+
start_s=start_s,
|
| 578 |
+
note_dur=dur_s,
|
| 579 |
+
note_text=text,
|
| 580 |
+
note_pitch=n["midi"],
|
| 581 |
+
note_type=tp,
|
| 582 |
+
)
|
| 583 |
+
)
|
| 584 |
+
prev_end_s = end_s
|
| 585 |
+
|
| 586 |
+
return result
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
def meta2midi(meta_path: str, midi_path: str, defaults: MidiDefaults | None = None) -> None:
|
| 590 |
+
"""Convert SoulX-Singer metadata JSON to MIDI file (meta -> List[Note] -> midi)."""
|
| 591 |
+
notes = meta2notes(meta_path)
|
| 592 |
+
notes2midi(notes, midi_path, defaults)
|
| 593 |
+
print(f"Saved MIDI to {midi_path}")
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
def midi2meta(
|
| 597 |
+
midi_path: str,
|
| 598 |
+
meta_path: str,
|
| 599 |
+
vocal_file: str,
|
| 600 |
+
rmvpe_model_path: str | None = None,
|
| 601 |
+
device: str = "cuda",
|
| 602 |
+
) -> None:
|
| 603 |
+
"""Convert MIDI file to SoulX-Singer metadata JSON (midi -> List[Note] -> meta)."""
|
| 604 |
+
meta_dir = os.path.dirname(meta_path)
|
| 605 |
+
if meta_dir:
|
| 606 |
+
os.makedirs(meta_dir, exist_ok=True)
|
| 607 |
+
# cut_wavs will be written to a fixed temporary directory inside _edit_data_to_meta
|
| 608 |
+
notes = midi2notes(midi_path)
|
| 609 |
+
notes2meta(
|
| 610 |
+
notes,
|
| 611 |
+
meta_path,
|
| 612 |
+
vocal_file,
|
| 613 |
+
rmvpe_model_path=rmvpe_model_path,
|
| 614 |
+
device=device,
|
| 615 |
+
)
|
| 616 |
+
print(f"Saved Meta to {meta_path}")
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
if __name__ == "__main__":
|
| 620 |
+
import argparse
|
| 621 |
+
|
| 622 |
+
parser = argparse.ArgumentParser(
|
| 623 |
+
description="Convert SoulX-Singer metadata JSON <-> MIDI."
|
| 624 |
+
)
|
| 625 |
+
parser.add_argument("--meta", type=str, help="Path to metadata JSON")
|
| 626 |
+
parser.add_argument("--midi", type=str, help="Path to MIDI file")
|
| 627 |
+
parser.add_argument("--vocal", type=str, help="Path to vocal wav (for midi2meta)")
|
| 628 |
+
parser.add_argument(
|
| 629 |
+
"--meta2midi",
|
| 630 |
+
action="store_true",
|
| 631 |
+
help="Convert meta -> midi (requires --meta and --midi)",
|
| 632 |
+
)
|
| 633 |
+
parser.add_argument(
|
| 634 |
+
"--midi2meta",
|
| 635 |
+
action="store_true",
|
| 636 |
+
help="Convert midi -> meta (requires --midi, --meta, --vocal, --cut_wavs_dir)",
|
| 637 |
+
)
|
| 638 |
+
parser.add_argument(
|
| 639 |
+
"--rmvpe_model_path",
|
| 640 |
+
type=str,
|
| 641 |
+
help="Path to RMVPE model",
|
| 642 |
+
default="pretrained_models/SoulX-Singer-Preprocess/rmvpe/rmvpe.pt",
|
| 643 |
+
)
|
| 644 |
+
parser.add_argument(
|
| 645 |
+
"--device",
|
| 646 |
+
type=str,
|
| 647 |
+
help="Device to use for RMVPE",
|
| 648 |
+
default="cuda",
|
| 649 |
+
)
|
| 650 |
+
args = parser.parse_args()
|
| 651 |
+
|
| 652 |
+
if args.meta2midi:
|
| 653 |
+
if not args.meta or not args.midi:
|
| 654 |
+
parser.error("--meta2midi requires --meta and --midi")
|
| 655 |
+
meta2midi(args.meta, args.midi)
|
| 656 |
+
elif args.midi2meta:
|
| 657 |
+
if not args.midi or not args.meta or not args.vocal:
|
| 658 |
+
parser.error(
|
| 659 |
+
"--midi2meta requires --midi, --meta, --vocal"
|
| 660 |
+
)
|
| 661 |
+
midi2meta(
|
| 662 |
+
args.midi,
|
| 663 |
+
args.meta,
|
| 664 |
+
args.vocal,
|
| 665 |
+
rmvpe_model_path=args.rmvpe_model_path,
|
| 666 |
+
device=args.device,
|
| 667 |
+
)
|
| 668 |
+
else:
|
| 669 |
+
parser.print_help()
|
preprocess/tools/note_transcription/__init__.py
ADDED
|
File without changes
|
preprocess/tools/note_transcription/model.py
ADDED
|
@@ -0,0 +1,522 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/RickyL-2000/ROSVOT
|
| 2 |
+
import math
|
| 3 |
+
import sys
|
| 4 |
+
import traceback
|
| 5 |
+
import json
|
| 6 |
+
import time
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any, Dict, Optional
|
| 9 |
+
|
| 10 |
+
import librosa
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import matplotlib.pyplot as plt
|
| 14 |
+
|
| 15 |
+
from .utils.os_utils import safe_path
|
| 16 |
+
from .utils.commons.hparams import set_hparams
|
| 17 |
+
from .utils.commons.ckpt_utils import load_ckpt
|
| 18 |
+
from .utils.commons.dataset_utils import pad_or_cut_xd
|
| 19 |
+
from .utils.audio.mel import MelNet
|
| 20 |
+
from .utils.audio.pitch_utils import (
|
| 21 |
+
norm_interp_f0,
|
| 22 |
+
denorm_f0,
|
| 23 |
+
f0_to_coarse,
|
| 24 |
+
boundary2Interval,
|
| 25 |
+
save_midi,
|
| 26 |
+
midi_to_hz,
|
| 27 |
+
)
|
| 28 |
+
from .utils.rosvot_utils import (
|
| 29 |
+
get_mel_len,
|
| 30 |
+
align_word,
|
| 31 |
+
regulate_real_note_itv,
|
| 32 |
+
regulate_ill_slur,
|
| 33 |
+
bd_to_durs,
|
| 34 |
+
)
|
| 35 |
+
from .modules.pe.rmvpe import RMVPE
|
| 36 |
+
from .modules.rosvot.rosvot import MidiExtractor, WordbdExtractor
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@torch.no_grad()
|
| 40 |
+
def infer_sample(
|
| 41 |
+
item: Dict[str, Any],
|
| 42 |
+
hparams: Dict[str, Any],
|
| 43 |
+
models: Dict[str, Any],
|
| 44 |
+
device: torch.device,
|
| 45 |
+
*,
|
| 46 |
+
save_dir: Optional[str] = None,
|
| 47 |
+
apply_rwbd: Optional[bool] = None,
|
| 48 |
+
# outputs
|
| 49 |
+
save_plot: bool = False,
|
| 50 |
+
no_save_midi: bool = True,
|
| 51 |
+
no_save_npy: bool = True,
|
| 52 |
+
verbose: bool = False,
|
| 53 |
+
) -> Dict[str, Any]:
|
| 54 |
+
if "item_name" not in item or "wav_fn" not in item:
|
| 55 |
+
raise ValueError('item must contain keys: "item_name" and "wav_fn"')
|
| 56 |
+
|
| 57 |
+
item_name = item["item_name"]
|
| 58 |
+
wav_src = item["wav_fn"]
|
| 59 |
+
|
| 60 |
+
# Decide RWBD usage
|
| 61 |
+
if apply_rwbd is None:
|
| 62 |
+
apply_rwbd_ = ("word_durs" not in item)
|
| 63 |
+
else:
|
| 64 |
+
apply_rwbd_ = bool(apply_rwbd)
|
| 65 |
+
|
| 66 |
+
# Models
|
| 67 |
+
model = models["model"]
|
| 68 |
+
mel_net = models["mel_net"]
|
| 69 |
+
pe = models.get("pe")
|
| 70 |
+
wbd_predictor = models.get("wbd_predictor")
|
| 71 |
+
|
| 72 |
+
if wbd_predictor is None and apply_rwbd_:
|
| 73 |
+
raise ValueError("apply_rwbd is True but wbd_predictor model is not provided in models")
|
| 74 |
+
|
| 75 |
+
# ---- Prepare Data ----
|
| 76 |
+
if isinstance(wav_src, str):
|
| 77 |
+
wav, _ = librosa.core.load(wav_src, sr=hparams["audio_sample_rate"])
|
| 78 |
+
else:
|
| 79 |
+
wav = wav_src
|
| 80 |
+
if not isinstance(wav, np.ndarray):
|
| 81 |
+
wav = np.asarray(wav)
|
| 82 |
+
wav = wav.astype(np.float32)
|
| 83 |
+
|
| 84 |
+
# Calculate timestamps and alignment lengths
|
| 85 |
+
wav_len_samples = wav.shape[-1]
|
| 86 |
+
mel_len = get_mel_len(wav_len_samples, hparams["hop_size"])
|
| 87 |
+
|
| 88 |
+
# Word boundary preparation
|
| 89 |
+
mel2word = None
|
| 90 |
+
word_durs_filtered = None
|
| 91 |
+
|
| 92 |
+
if not apply_rwbd_:
|
| 93 |
+
if "word_durs" not in item:
|
| 94 |
+
raise ValueError('apply_rwbd=False but item has no "word_durs"')
|
| 95 |
+
|
| 96 |
+
wd_raw = list(item["word_durs"])
|
| 97 |
+
min_word_dur = hparams.get("min_word_dur", 20) / 1000
|
| 98 |
+
word_durs_filtered = []
|
| 99 |
+
|
| 100 |
+
for i, wd in enumerate(wd_raw):
|
| 101 |
+
if wd < min_word_dur:
|
| 102 |
+
if i == 0 and len(wd_raw) > 1:
|
| 103 |
+
wd_raw[i + 1] += wd
|
| 104 |
+
elif len(word_durs_filtered) > 0:
|
| 105 |
+
word_durs_filtered[-1] += wd
|
| 106 |
+
else:
|
| 107 |
+
word_durs_filtered.append(wd)
|
| 108 |
+
|
| 109 |
+
mel2word, _ = align_word(word_durs_filtered, mel_len, hparams["hop_size"], hparams["audio_sample_rate"])
|
| 110 |
+
mel2word = np.asarray(mel2word)
|
| 111 |
+
if mel2word.size > 0 and mel2word[0] == 0:
|
| 112 |
+
mel2word = mel2word + 1
|
| 113 |
+
|
| 114 |
+
mel2word_len = int(np.sum(mel2word > 0))
|
| 115 |
+
real_len = min(mel_len, mel2word_len)
|
| 116 |
+
else:
|
| 117 |
+
real_len = min(mel_len, hparams["max_frames"])
|
| 118 |
+
|
| 119 |
+
T = math.ceil(min(real_len, hparams["max_frames"]) / hparams["frames_multiple"]) * hparams["frames_multiple"]
|
| 120 |
+
|
| 121 |
+
# ---- Input Tensors & Padding ----
|
| 122 |
+
target_samples = T * hparams["hop_size"]
|
| 123 |
+
wav_t = torch.from_numpy(wav).float().to(device).unsqueeze(0) # [1, L]
|
| 124 |
+
if wav_t.shape[-1] < target_samples:
|
| 125 |
+
wav_t = pad_or_cut_xd(wav_t, target_samples, 1)
|
| 126 |
+
|
| 127 |
+
# ---- Pitch Extraction ----
|
| 128 |
+
if pe is not None:
|
| 129 |
+
f0s, uvs = pe.get_pitch_batch(
|
| 130 |
+
wav_t,
|
| 131 |
+
sample_rate=hparams["audio_sample_rate"],
|
| 132 |
+
hop_size=hparams["hop_size"],
|
| 133 |
+
lengths=[real_len],
|
| 134 |
+
fmax=hparams["f0_max"],
|
| 135 |
+
fmin=hparams["f0_min"],
|
| 136 |
+
)
|
| 137 |
+
f0_1d, uv_1d = norm_interp_f0(f0s[0][:T])
|
| 138 |
+
f0_t = pad_or_cut_xd(torch.FloatTensor(f0_1d).to(device), T, 0).unsqueeze(0)
|
| 139 |
+
uv_t = pad_or_cut_xd(torch.FloatTensor(uv_1d).to(device), T, 0).long().unsqueeze(0)
|
| 140 |
+
pitch_coarse = f0_to_coarse(denorm_f0(f0_t, uv_t)).to(device)
|
| 141 |
+
f0_np = denorm_f0(f0_t, uv_t)[0].detach().cpu().numpy()[:real_len]
|
| 142 |
+
else:
|
| 143 |
+
f0_t = uv_t = pitch_coarse = None
|
| 144 |
+
f0_np = None
|
| 145 |
+
|
| 146 |
+
# ---- Mel Extraction ----
|
| 147 |
+
mel = mel_net(wav_t) # [1, T_padded, C]
|
| 148 |
+
mel = pad_or_cut_xd(mel, T, 1)
|
| 149 |
+
|
| 150 |
+
# Construct non-padding mask
|
| 151 |
+
mel_nonpadding_mask = torch.zeros(1, T, device=device)
|
| 152 |
+
mel_nonpadding_mask[:, :real_len] = 1.0
|
| 153 |
+
|
| 154 |
+
# Apply mask to mel (zero out padding)
|
| 155 |
+
mel = (mel.transpose(1, 2) * mel_nonpadding_mask.unsqueeze(1)).transpose(1, 2)
|
| 156 |
+
# Re-calculate non_padding bool mask
|
| 157 |
+
mel_nonpadding = mel.abs().sum(-1) > 0
|
| 158 |
+
|
| 159 |
+
# ---- Word Boundary ----
|
| 160 |
+
word_durs_used = None
|
| 161 |
+
if apply_rwbd_:
|
| 162 |
+
mel_input = mel[:, :, : hparams.get("wbd_use_mel_bins", 80)]
|
| 163 |
+
wbd_outputs = wbd_predictor(
|
| 164 |
+
mel=mel_input,
|
| 165 |
+
pitch=pitch_coarse,
|
| 166 |
+
uv=uv_t,
|
| 167 |
+
non_padding=mel_nonpadding,
|
| 168 |
+
train=False,
|
| 169 |
+
)
|
| 170 |
+
word_bd = wbd_outputs["word_bd_pred"] # [1, T]
|
| 171 |
+
else:
|
| 172 |
+
# Construct word_bd from provided durs
|
| 173 |
+
mel2word_t = pad_or_cut_xd(torch.LongTensor(mel2word).to(device), T, 0)
|
| 174 |
+
word_bd = torch.zeros_like(mel2word_t)
|
| 175 |
+
# Vectorized check
|
| 176 |
+
word_bd[1:] = (mel2word_t[1:] != mel2word_t[:-1]).long()
|
| 177 |
+
word_bd[real_len:] = 0
|
| 178 |
+
word_bd = word_bd.unsqueeze(0) # [1, T]
|
| 179 |
+
|
| 180 |
+
word_durs_used = np.array(word_durs_filtered)
|
| 181 |
+
|
| 182 |
+
# ---- Main Inference ----
|
| 183 |
+
mel_input = mel[:, :, : hparams.get("use_mel_bins", 80)]
|
| 184 |
+
outputs = model(
|
| 185 |
+
mel=mel_input,
|
| 186 |
+
word_bd=word_bd,
|
| 187 |
+
pitch=pitch_coarse,
|
| 188 |
+
uv=uv_t,
|
| 189 |
+
non_padding=mel_nonpadding,
|
| 190 |
+
train=False,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
note_lengths = outputs["note_lengths"].detach().cpu().numpy()
|
| 194 |
+
note_bd_pred = outputs["note_bd_pred"][0].detach().cpu().numpy()[:real_len]
|
| 195 |
+
note_pred = outputs["note_pred"][0].detach().cpu().numpy()[: note_lengths[0]]
|
| 196 |
+
note_bd_logits = torch.sigmoid(outputs["note_bd_logits"])[0].detach().cpu().numpy()[:real_len]
|
| 197 |
+
|
| 198 |
+
if note_pred.shape == (0,):
|
| 199 |
+
if verbose:
|
| 200 |
+
print(f"skip {item_name}: no notes detected")
|
| 201 |
+
return {
|
| 202 |
+
"item_name": item_name,
|
| 203 |
+
"pitches": [],
|
| 204 |
+
"note_durs": [],
|
| 205 |
+
"note2words": None,
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
# ---- Post-Processing & Regulation ----
|
| 209 |
+
note_itv_pred = boundary2Interval(note_bd_pred)
|
| 210 |
+
note2words = None
|
| 211 |
+
|
| 212 |
+
if apply_rwbd_:
|
| 213 |
+
word_bd_np = outputs['word_bd_pred'][0].detach().cpu().numpy()[:real_len]
|
| 214 |
+
word_durs_derived = np.array(bd_to_durs(word_bd_np)) * hparams['hop_size'] / hparams['audio_sample_rate']
|
| 215 |
+
word_durs_for_reg = word_durs_derived
|
| 216 |
+
word_bd_for_reg = word_bd_np
|
| 217 |
+
else:
|
| 218 |
+
word_bd_for_reg = word_bd[0].detach().cpu().numpy()[:real_len]
|
| 219 |
+
word_durs_for_reg = word_durs_used
|
| 220 |
+
|
| 221 |
+
should_regulate = hparams.get("infer_regulate_real_note_itv", True) and (not apply_rwbd_)
|
| 222 |
+
|
| 223 |
+
if should_regulate and (word_durs_for_reg is not None):
|
| 224 |
+
try:
|
| 225 |
+
note_itv_pred_secs, note2words = regulate_real_note_itv(
|
| 226 |
+
note_itv_pred,
|
| 227 |
+
note_bd_pred,
|
| 228 |
+
word_bd_for_reg,
|
| 229 |
+
word_durs_for_reg,
|
| 230 |
+
hparams["hop_size"],
|
| 231 |
+
hparams["audio_sample_rate"],
|
| 232 |
+
)
|
| 233 |
+
note_pred, note_itv_pred_secs, note2words = regulate_ill_slur(note_pred, note_itv_pred_secs, note2words)
|
| 234 |
+
except Exception as err:
|
| 235 |
+
if verbose:
|
| 236 |
+
_, exc_value, exc_tb = sys.exc_info()
|
| 237 |
+
tb = traceback.extract_tb(exc_tb)[-1]
|
| 238 |
+
print(f"postprocess failed: {err}: {exc_value} in {tb[0]}:{tb[1]} '{tb[2]}' in {tb[3]}")
|
| 239 |
+
# Fallback
|
| 240 |
+
note_itv_pred_secs = note_itv_pred * hparams["hop_size"] / hparams["audio_sample_rate"]
|
| 241 |
+
note2words = None
|
| 242 |
+
else:
|
| 243 |
+
note_itv_pred_secs = note_itv_pred * hparams["hop_size"] / hparams["audio_sample_rate"]
|
| 244 |
+
|
| 245 |
+
# ---- Output ----
|
| 246 |
+
note_durs = [float((itv[1] - itv[0])) for itv in note_itv_pred_secs]
|
| 247 |
+
|
| 248 |
+
out = {
|
| 249 |
+
"item_name": item_name,
|
| 250 |
+
"pitches": note_pred.tolist(),
|
| 251 |
+
"note_durs": note_durs,
|
| 252 |
+
"note2words": note2words.tolist() if note2words is not None else None,
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
# ---- Saving ----
|
| 256 |
+
if save_dir is not None:
|
| 257 |
+
save_dir_path = Path(save_dir)
|
| 258 |
+
save_dir_path.mkdir(parents=True, exist_ok=True)
|
| 259 |
+
fn = str(item_name)
|
| 260 |
+
|
| 261 |
+
if not no_save_midi:
|
| 262 |
+
save_midi(note_pred, note_itv_pred_secs, safe_path(save_dir_path / "midi" / f"{fn}.mid"))
|
| 263 |
+
|
| 264 |
+
if not no_save_npy:
|
| 265 |
+
np.save(safe_path(save_dir_path / "npy" / f"[note]{fn}.npy"), out, allow_pickle=True)
|
| 266 |
+
|
| 267 |
+
if save_plot:
|
| 268 |
+
fig = plt.figure()
|
| 269 |
+
if f0_np is not None:
|
| 270 |
+
plt.plot(f0_np, color="red", label="f0")
|
| 271 |
+
|
| 272 |
+
midi_pred = np.zeros(note_bd_pred.shape[0], dtype=np.float32)
|
| 273 |
+
itvs = np.round(note_itv_pred_secs * hparams["audio_sample_rate"] / hparams["hop_size"]).astype(int)
|
| 274 |
+
for i, itv in enumerate(itvs):
|
| 275 |
+
midi_pred[itv[0] : itv[1]] = note_pred[i]
|
| 276 |
+
plt.plot(midi_to_hz(midi_pred), color="blue", label="pred midi")
|
| 277 |
+
plt.plot(note_bd_logits * 100, color="green", label="note bd logits x100")
|
| 278 |
+
plt.legend()
|
| 279 |
+
plt.tight_layout()
|
| 280 |
+
plt.savefig(safe_path(save_dir_path / "plot" / f"[MIDI]{fn}.png"), format="png")
|
| 281 |
+
plt.close(fig)
|
| 282 |
+
|
| 283 |
+
return out
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def load_rosvot_models(ckpt, config="", wbd_ckpt="", wbd_config="", device="cuda:0", verbose=False, thr=0.85):
|
| 287 |
+
"""
|
| 288 |
+
Load models once to reuse across multiple items.
|
| 289 |
+
"""
|
| 290 |
+
dev = torch.device(device)
|
| 291 |
+
|
| 292 |
+
# 1. Hparams
|
| 293 |
+
config_path = Path(ckpt).with_name("config.yaml") if config == "" else config
|
| 294 |
+
pe_ckpt = Path(ckpt).parent.parent / "rmvpe/model.pt"
|
| 295 |
+
hparams = set_hparams(
|
| 296 |
+
config=config_path,
|
| 297 |
+
print_hparams=verbose,
|
| 298 |
+
hparams_str=f"note_bd_threshold={thr}",
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
# 2. Main Model
|
| 302 |
+
model = MidiExtractor(hparams)
|
| 303 |
+
load_ckpt(model, ckpt, verbose=verbose)
|
| 304 |
+
model.eval().to(dev)
|
| 305 |
+
|
| 306 |
+
# 3. MelNet
|
| 307 |
+
mel_net = MelNet(hparams)
|
| 308 |
+
mel_net.to(dev)
|
| 309 |
+
|
| 310 |
+
# 4. Pitch Extractor
|
| 311 |
+
pe = None
|
| 312 |
+
if hparams.get("use_pitch_embed", False):
|
| 313 |
+
pe = RMVPE(pe_ckpt, device=dev)
|
| 314 |
+
|
| 315 |
+
# 5. Word Boundary Predictor (optional but we load if ckpt provided or needed)
|
| 316 |
+
wbd_predictor = None
|
| 317 |
+
if wbd_ckpt:
|
| 318 |
+
wbd_config_path = Path(wbd_ckpt).with_name("config.yaml") if wbd_config == "" else wbd_config
|
| 319 |
+
wbd_hparams = set_hparams(
|
| 320 |
+
config=wbd_config_path,
|
| 321 |
+
print_hparams=False,
|
| 322 |
+
hparams_str="",
|
| 323 |
+
)
|
| 324 |
+
hparams.update({
|
| 325 |
+
"wbd_use_mel_bins": wbd_hparams["use_mel_bins"],
|
| 326 |
+
"min_word_dur": wbd_hparams["min_word_dur"],
|
| 327 |
+
})
|
| 328 |
+
wbd_predictor = WordbdExtractor(wbd_hparams)
|
| 329 |
+
load_ckpt(wbd_predictor, wbd_ckpt, verbose=verbose)
|
| 330 |
+
wbd_predictor.eval().to(dev)
|
| 331 |
+
|
| 332 |
+
models = {
|
| 333 |
+
"model": model,
|
| 334 |
+
"mel_net": mel_net,
|
| 335 |
+
"pe": pe,
|
| 336 |
+
"wbd_predictor": wbd_predictor
|
| 337 |
+
}
|
| 338 |
+
return hparams, models
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
class NoteTranscriber:
|
| 342 |
+
"""Note transcription wrapper based on ROSVOT.
|
| 343 |
+
|
| 344 |
+
Loads ROSVOT and optional RWBD models once in ``__init__`` and
|
| 345 |
+
exposes a :py:meth:`process` API that turns an item dict into
|
| 346 |
+
aligned note metadata for downstream SVS.
|
| 347 |
+
"""
|
| 348 |
+
|
| 349 |
+
def __init__(
|
| 350 |
+
self,
|
| 351 |
+
rosvot_model_path: str,
|
| 352 |
+
rwbd_model_path: str,
|
| 353 |
+
*,
|
| 354 |
+
rosvot_config_path: str = "",
|
| 355 |
+
rwbd_config_path: str = "",
|
| 356 |
+
device: str = "cuda:0",
|
| 357 |
+
thr: float = 0.85,
|
| 358 |
+
verbose: bool = True,
|
| 359 |
+
):
|
| 360 |
+
"""Initialize the note transcriber.
|
| 361 |
+
|
| 362 |
+
Args:
|
| 363 |
+
ckpt: Path to the main ROSVOT checkpoint.
|
| 364 |
+
config: Optional config YAML path for ROSVOT.
|
| 365 |
+
wbd_ckpt: Optional word-boundary checkpoint path.
|
| 366 |
+
wbd_config: Optional config YAML path for RWBD.
|
| 367 |
+
device: Torch device string, e.g. ``"cuda:0"`` / ``"cpu"``.
|
| 368 |
+
thr: Note boundary threshold.
|
| 369 |
+
verbose: Whether to print verbose logs.
|
| 370 |
+
"""
|
| 371 |
+
self.verbose = verbose
|
| 372 |
+
self.device = torch.device(device)
|
| 373 |
+
self.hparams, self.models = load_rosvot_models(
|
| 374 |
+
ckpt=rosvot_model_path,
|
| 375 |
+
config=rosvot_config_path,
|
| 376 |
+
wbd_ckpt=rwbd_model_path,
|
| 377 |
+
wbd_config=rwbd_config_path,
|
| 378 |
+
device=device,
|
| 379 |
+
verbose=verbose,
|
| 380 |
+
thr=thr,
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
if self.verbose:
|
| 384 |
+
print(
|
| 385 |
+
"[note transcription] init success:",
|
| 386 |
+
f"device={self.device}",
|
| 387 |
+
f"rosvot_model_path={rosvot_model_path}",
|
| 388 |
+
f"rwbd_model_path={rwbd_model_path if rwbd_model_path else 'None'}",
|
| 389 |
+
f"thr={thr}",
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
def process(
|
| 393 |
+
self,
|
| 394 |
+
item: Dict[str, Any],
|
| 395 |
+
*,
|
| 396 |
+
segment_info: Optional[Dict[str, Any]] = None,
|
| 397 |
+
save_dir: Optional[str] = None,
|
| 398 |
+
apply_rwbd: Optional[bool] = None,
|
| 399 |
+
save_plot: bool = False,
|
| 400 |
+
no_save_midi: bool = True,
|
| 401 |
+
no_save_npy: bool = True,
|
| 402 |
+
verbose: Optional[bool] = None,
|
| 403 |
+
) -> Dict[str, Any]:
|
| 404 |
+
"""Run ROSVOT on a single item and post-process outputs.
|
| 405 |
+
|
| 406 |
+
Args:
|
| 407 |
+
item: Input metadata dict with at least ``item_name`` and ``wav_fn``.
|
| 408 |
+
segment_info: Optional segment metadata for sliced audio.
|
| 409 |
+
save_dir: Optional directory for debug artifacts (plots, midis).
|
| 410 |
+
apply_rwbd: Whether to run RWBD-based word boundary refinement.
|
| 411 |
+
save_plot: Whether to save diagnostic plots.
|
| 412 |
+
no_save_midi: If True, skip saving midi.
|
| 413 |
+
no_save_npy: If True, skip saving numpy intermediates.
|
| 414 |
+
verbose: Override instance-level verbose flag for this call.
|
| 415 |
+
|
| 416 |
+
Returns:
|
| 417 |
+
Dict with aligned note information for downstream SVS.
|
| 418 |
+
"""
|
| 419 |
+
v = self.verbose if verbose is None else verbose
|
| 420 |
+
if v:
|
| 421 |
+
item_name = item.get("item_name", "")
|
| 422 |
+
wav_fn = item.get("wav_fn", "")
|
| 423 |
+
print(f"[note transcription] process: start: item_name={item_name} wav_fn={wav_fn}")
|
| 424 |
+
t0 = time.time()
|
| 425 |
+
|
| 426 |
+
rosvot_out = infer_sample(
|
| 427 |
+
item,
|
| 428 |
+
self.hparams,
|
| 429 |
+
self.models,
|
| 430 |
+
device=self.device,
|
| 431 |
+
save_dir=save_dir,
|
| 432 |
+
apply_rwbd=apply_rwbd,
|
| 433 |
+
save_plot=save_plot,
|
| 434 |
+
no_save_midi=no_save_midi,
|
| 435 |
+
no_save_npy=no_save_npy,
|
| 436 |
+
verbose=v,
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
out = self.post_process(
|
| 440 |
+
metadata=item,
|
| 441 |
+
segment_info=segment_info,
|
| 442 |
+
rosvot_out=rosvot_out,
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
if v:
|
| 446 |
+
dt = time.time() - t0
|
| 447 |
+
print(
|
| 448 |
+
"[note transcription] process: done:",
|
| 449 |
+
f"item_name={out.get('item_name','')}",
|
| 450 |
+
f"n_notes={len(out.get('note_pitch', []) or [])}",
|
| 451 |
+
f"time={dt:.3f}s",
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
return out
|
| 455 |
+
|
| 456 |
+
@staticmethod
|
| 457 |
+
def _normalize_note2words(note2words: list[int]) -> list[int]:
|
| 458 |
+
if not note2words:
|
| 459 |
+
return []
|
| 460 |
+
normalized = [note2words[0]]
|
| 461 |
+
for idx in range(1, len(note2words)):
|
| 462 |
+
if note2words[idx] < normalized[-1]:
|
| 463 |
+
normalized.append(normalized[-1])
|
| 464 |
+
else:
|
| 465 |
+
normalized.append(note2words[idx])
|
| 466 |
+
return normalized
|
| 467 |
+
|
| 468 |
+
@staticmethod
|
| 469 |
+
def _build_ep_types(note2words: list[int], align_words: list[str]) -> list[int]:
|
| 470 |
+
ep_types: list[int] = []
|
| 471 |
+
prev = -1
|
| 472 |
+
for i, w in zip(note2words, align_words):
|
| 473 |
+
if w == "<SP>":
|
| 474 |
+
ep_types.append(1)
|
| 475 |
+
else:
|
| 476 |
+
ep_types.append(2 if i != prev else 3)
|
| 477 |
+
prev = i
|
| 478 |
+
return ep_types
|
| 479 |
+
|
| 480 |
+
def post_process(
|
| 481 |
+
self,
|
| 482 |
+
*,
|
| 483 |
+
metadata: Dict[str, Any],
|
| 484 |
+
segment_info: Dict[str, Any],
|
| 485 |
+
rosvot_out: Dict[str, Any],
|
| 486 |
+
) -> Dict[str, Any]:
|
| 487 |
+
"""Build aligned note metadata using ROSVOT outputs."""
|
| 488 |
+
note2words_raw = rosvot_out.get("note2words") or []
|
| 489 |
+
note2words = self._normalize_note2words(note2words_raw)
|
| 490 |
+
align_words = [
|
| 491 |
+
metadata["words"][idx - 1]
|
| 492 |
+
for idx in note2words_raw
|
| 493 |
+
if 0 < idx <= len(metadata["words"])
|
| 494 |
+
]
|
| 495 |
+
ep_types = self._build_ep_types(note2words, align_words) if align_words else []
|
| 496 |
+
|
| 497 |
+
return {
|
| 498 |
+
"item_name": rosvot_out.get("item_name", "") if not segment_info else segment_info["item_name"],
|
| 499 |
+
"wav_fn": metadata.get("wav_fn", "") if not segment_info else segment_info["wav_fn"],
|
| 500 |
+
"origin_wav_fn": metadata.get("origin_wav_fn", "") if not segment_info else segment_info["origin_wav_fn"],
|
| 501 |
+
"start_time_ms": "" if not segment_info else segment_info["start_time_ms"],
|
| 502 |
+
"end_time_ms": "" if not segment_info else segment_info["end_time_ms"],
|
| 503 |
+
"language": metadata.get("language", ""),
|
| 504 |
+
"note_text": align_words,
|
| 505 |
+
"note_dur": rosvot_out.get("note_durs", []),
|
| 506 |
+
"note_type": ep_types,
|
| 507 |
+
"note_pitch": rosvot_out.get("pitches", []),
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
if __name__ == "__main__":
|
| 511 |
+
|
| 512 |
+
items = json.load(open("example/test/rosvot_input.json", "r"))
|
| 513 |
+
item = items[0]
|
| 514 |
+
|
| 515 |
+
m = NoteTranscriber(
|
| 516 |
+
rosvot_model_path="pretrained_models/rosvot/rosvot/model.pt",
|
| 517 |
+
rwbd_model_path="pretrained_models/rosvot/rwbd/model.pt",
|
| 518 |
+
device="cuda"
|
| 519 |
+
)
|
| 520 |
+
out = m.process(item)
|
| 521 |
+
|
| 522 |
+
print(out)
|
preprocess/tools/note_transcription/modules/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""ROSVOT model submodules."""
|
preprocess/tools/note_transcription/modules/commons/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Common ROSVOT layers and utilities."""
|
preprocess/tools/note_transcription/modules/commons/conformer/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Conformer layers for ROSVOT."""
|
preprocess/tools/note_transcription/modules/commons/conformer/conformer.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
from .espnet_positional_embedding import RelPositionalEncoding, ScaledPositionalEncoding, PositionalEncoding
|
| 3 |
+
from .espnet_transformer_attn import RelPositionMultiHeadedAttention, MultiHeadedAttention
|
| 4 |
+
from .layers import Swish, ConvolutionModule, EncoderLayer, MultiLayeredConv1d
|
| 5 |
+
from ..layers import Embedding
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ConformerLayers(nn.Module):
|
| 9 |
+
def __init__(self, hidden_size, num_layers, kernel_size=9, dropout=0.0, num_heads=4,
|
| 10 |
+
use_last_norm=True, save_hidden=False):
|
| 11 |
+
super().__init__()
|
| 12 |
+
self.use_last_norm = use_last_norm
|
| 13 |
+
self.layers = nn.ModuleList()
|
| 14 |
+
positionwise_layer = MultiLayeredConv1d
|
| 15 |
+
positionwise_layer_args = (hidden_size, hidden_size * 4, 1, dropout)
|
| 16 |
+
self.pos_embed = RelPositionalEncoding(hidden_size, dropout)
|
| 17 |
+
self.encoder_layers = nn.ModuleList([EncoderLayer(
|
| 18 |
+
hidden_size,
|
| 19 |
+
RelPositionMultiHeadedAttention(num_heads, hidden_size, 0.0),
|
| 20 |
+
positionwise_layer(*positionwise_layer_args),
|
| 21 |
+
positionwise_layer(*positionwise_layer_args),
|
| 22 |
+
ConvolutionModule(hidden_size, kernel_size, Swish()),
|
| 23 |
+
dropout,
|
| 24 |
+
) for _ in range(num_layers)])
|
| 25 |
+
if self.use_last_norm:
|
| 26 |
+
self.layer_norm = nn.LayerNorm(hidden_size)
|
| 27 |
+
else:
|
| 28 |
+
self.layer_norm = nn.Linear(hidden_size, hidden_size)
|
| 29 |
+
self.save_hidden = save_hidden
|
| 30 |
+
if save_hidden:
|
| 31 |
+
self.hiddens = []
|
| 32 |
+
|
| 33 |
+
def forward(self, x, padding_mask=None):
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
:param x: [B, T, H]
|
| 37 |
+
:param padding_mask: [B, T]
|
| 38 |
+
:return: [B, T, H]
|
| 39 |
+
"""
|
| 40 |
+
self.hiddens = []
|
| 41 |
+
nonpadding_mask = x.abs().sum(-1) > 0
|
| 42 |
+
x = self.pos_embed(x)
|
| 43 |
+
for l in self.encoder_layers:
|
| 44 |
+
x, mask = l(x, nonpadding_mask[:, None, :])
|
| 45 |
+
if self.save_hidden:
|
| 46 |
+
self.hiddens.append(x[0])
|
| 47 |
+
x = x[0]
|
| 48 |
+
x = self.layer_norm(x) * nonpadding_mask.float()[:, :, None]
|
| 49 |
+
return x
|
| 50 |
+
|
| 51 |
+
class FastConformerLayers(ConformerLayers):
|
| 52 |
+
def __init__(self, hidden_size, num_layers, kernel_size=9, dropout=0.0, num_heads=4,
|
| 53 |
+
use_last_norm=True, save_hidden=False):
|
| 54 |
+
super(ConformerLayers, self).__init__()
|
| 55 |
+
self.use_last_norm = use_last_norm
|
| 56 |
+
self.layers = nn.ModuleList()
|
| 57 |
+
positionwise_layer = MultiLayeredConv1d
|
| 58 |
+
positionwise_layer_args = (hidden_size, hidden_size * 4, 1, dropout)
|
| 59 |
+
self.pos_embed = PositionalEncoding(hidden_size, dropout)
|
| 60 |
+
self.encoder_layers = nn.ModuleList([EncoderLayer(
|
| 61 |
+
hidden_size,
|
| 62 |
+
MultiHeadedAttention(num_heads, hidden_size, 0.0, flash=True),
|
| 63 |
+
positionwise_layer(*positionwise_layer_args),
|
| 64 |
+
positionwise_layer(*positionwise_layer_args),
|
| 65 |
+
ConvolutionModule(hidden_size, kernel_size, Swish()),
|
| 66 |
+
dropout,
|
| 67 |
+
) for _ in range(num_layers)])
|
| 68 |
+
if self.use_last_norm:
|
| 69 |
+
self.layer_norm = nn.LayerNorm(hidden_size)
|
| 70 |
+
else:
|
| 71 |
+
self.layer_norm = nn.Linear(hidden_size, hidden_size)
|
| 72 |
+
self.save_hidden = save_hidden
|
| 73 |
+
if save_hidden:
|
| 74 |
+
self.hiddens = []
|
| 75 |
+
|
| 76 |
+
class ConformerEncoder(ConformerLayers):
|
| 77 |
+
def __init__(self, hidden_size, dict_size, num_layers=None):
|
| 78 |
+
conformer_enc_kernel_size = 9
|
| 79 |
+
super().__init__(hidden_size, num_layers, conformer_enc_kernel_size)
|
| 80 |
+
self.embed = Embedding(dict_size, hidden_size, padding_idx=0)
|
| 81 |
+
|
| 82 |
+
def forward(self, x):
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
:param src_tokens: [B, T]
|
| 86 |
+
:return: [B x T x C]
|
| 87 |
+
"""
|
| 88 |
+
x = self.embed(x) # [B, T, H]
|
| 89 |
+
x = super(ConformerEncoder, self).forward(x)
|
| 90 |
+
return x
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class ConformerDecoder(ConformerLayers):
|
| 94 |
+
def __init__(self, hidden_size, num_layers):
|
| 95 |
+
conformer_dec_kernel_size = 9
|
| 96 |
+
super().__init__(hidden_size, num_layers, conformer_dec_kernel_size)
|
preprocess/tools/note_transcription/modules/commons/conformer/espnet_positional_embedding.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class PositionalEncoding(torch.nn.Module):
|
| 6 |
+
"""Positional encoding.
|
| 7 |
+
Args:
|
| 8 |
+
d_model (int): Embedding dimension.
|
| 9 |
+
dropout_rate (float): Dropout rate.
|
| 10 |
+
max_len (int): Maximum input length.
|
| 11 |
+
reverse (bool): Whether to reverse the input position.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
|
| 15 |
+
"""Construct an PositionalEncoding object."""
|
| 16 |
+
super(PositionalEncoding, self).__init__()
|
| 17 |
+
self.d_model = d_model
|
| 18 |
+
self.reverse = reverse
|
| 19 |
+
self.xscale = math.sqrt(self.d_model)
|
| 20 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
| 21 |
+
self.pe = None
|
| 22 |
+
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
| 23 |
+
|
| 24 |
+
def extend_pe(self, x):
|
| 25 |
+
"""Reset the positional encodings."""
|
| 26 |
+
if self.pe is not None:
|
| 27 |
+
if self.pe.size(1) >= x.size(1):
|
| 28 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
| 29 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
| 30 |
+
return
|
| 31 |
+
pe = torch.zeros(x.size(1), self.d_model)
|
| 32 |
+
if self.reverse:
|
| 33 |
+
position = torch.arange(
|
| 34 |
+
x.size(1) - 1, -1, -1.0, dtype=torch.float32
|
| 35 |
+
).unsqueeze(1)
|
| 36 |
+
else:
|
| 37 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
| 38 |
+
div_term = torch.exp(
|
| 39 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
| 40 |
+
* -(math.log(10000.0) / self.d_model)
|
| 41 |
+
)
|
| 42 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 43 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 44 |
+
pe = pe.unsqueeze(0)
|
| 45 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
| 46 |
+
|
| 47 |
+
def forward(self, x: torch.Tensor):
|
| 48 |
+
"""Add positional encoding.
|
| 49 |
+
Args:
|
| 50 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
| 51 |
+
Returns:
|
| 52 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
| 53 |
+
"""
|
| 54 |
+
self.extend_pe(x)
|
| 55 |
+
x = x * self.xscale + self.pe[:, : x.size(1)]
|
| 56 |
+
return self.dropout(x)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class ScaledPositionalEncoding(PositionalEncoding):
|
| 60 |
+
"""Scaled positional encoding module.
|
| 61 |
+
See Sec. 3.2 https://arxiv.org/abs/1809.08895
|
| 62 |
+
Args:
|
| 63 |
+
d_model (int): Embedding dimension.
|
| 64 |
+
dropout_rate (float): Dropout rate.
|
| 65 |
+
max_len (int): Maximum input length.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
def __init__(self, d_model, dropout_rate, max_len=5000):
|
| 69 |
+
"""Initialize class."""
|
| 70 |
+
super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
|
| 71 |
+
self.alpha = torch.nn.Parameter(torch.tensor(1.0))
|
| 72 |
+
|
| 73 |
+
def reset_parameters(self):
|
| 74 |
+
"""Reset parameters."""
|
| 75 |
+
self.alpha.data = torch.tensor(1.0)
|
| 76 |
+
|
| 77 |
+
def forward(self, x):
|
| 78 |
+
"""Add positional encoding.
|
| 79 |
+
Args:
|
| 80 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
| 81 |
+
Returns:
|
| 82 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
| 83 |
+
"""
|
| 84 |
+
self.extend_pe(x)
|
| 85 |
+
x = x + self.alpha * self.pe[:, : x.size(1)]
|
| 86 |
+
return self.dropout(x)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class RelPositionalEncoding(PositionalEncoding):
|
| 90 |
+
"""Relative positional encoding module.
|
| 91 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
| 92 |
+
Args:
|
| 93 |
+
d_model (int): Embedding dimension.
|
| 94 |
+
dropout_rate (float): Dropout rate.
|
| 95 |
+
max_len (int): Maximum input length.
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
def __init__(self, d_model, dropout_rate, max_len=5000):
|
| 99 |
+
"""Initialize class."""
|
| 100 |
+
super().__init__(d_model, dropout_rate, max_len, reverse=True)
|
| 101 |
+
|
| 102 |
+
def forward(self, x):
|
| 103 |
+
"""Compute positional encoding.
|
| 104 |
+
Args:
|
| 105 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
| 106 |
+
Returns:
|
| 107 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
| 108 |
+
torch.Tensor: Positional embedding tensor (1, time, `*`).
|
| 109 |
+
"""
|
| 110 |
+
self.extend_pe(x)
|
| 111 |
+
x = x * self.xscale
|
| 112 |
+
pos_emb = self.pe[:, : x.size(1)]
|
| 113 |
+
return self.dropout(x), self.dropout(pos_emb)
|
preprocess/tools/note_transcription/modules/commons/conformer/espnet_transformer_attn.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
# Copyright 2019 Shigeki Karita
|
| 5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
| 6 |
+
|
| 7 |
+
"""Multi-Head Attention layer definition."""
|
| 8 |
+
|
| 9 |
+
from packaging import version
|
| 10 |
+
import math
|
| 11 |
+
|
| 12 |
+
import numpy
|
| 13 |
+
import torch
|
| 14 |
+
from torch import nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MultiHeadedAttention(nn.Module):
|
| 18 |
+
"""Multi-Head Attention layer.
|
| 19 |
+
Args:
|
| 20 |
+
n_head (int): The number of heads.
|
| 21 |
+
n_feat (int): The number of features.
|
| 22 |
+
dropout_rate (float): Dropout rate.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, n_head, n_feat, dropout_rate, flash=False):
|
| 26 |
+
"""Construct an MultiHeadedAttention object."""
|
| 27 |
+
super(MultiHeadedAttention, self).__init__()
|
| 28 |
+
assert n_feat % n_head == 0
|
| 29 |
+
# We assume d_v always equals d_k
|
| 30 |
+
self.d_k = n_feat // n_head
|
| 31 |
+
self.h = n_head
|
| 32 |
+
self.linear_q = nn.Linear(n_feat, n_feat)
|
| 33 |
+
self.linear_k = nn.Linear(n_feat, n_feat)
|
| 34 |
+
self.linear_v = nn.Linear(n_feat, n_feat)
|
| 35 |
+
self.linear_out = nn.Linear(n_feat, n_feat)
|
| 36 |
+
self.attn = None
|
| 37 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
| 38 |
+
self.dropout_rate = dropout_rate
|
| 39 |
+
self.flash = flash
|
| 40 |
+
|
| 41 |
+
def forward_qkv(self, query, key, value):
|
| 42 |
+
"""Transform query, key and value.
|
| 43 |
+
Args:
|
| 44 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
| 45 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
| 46 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
| 47 |
+
Returns:
|
| 48 |
+
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
|
| 49 |
+
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
|
| 50 |
+
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
|
| 51 |
+
"""
|
| 52 |
+
n_batch = query.size(0)
|
| 53 |
+
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
| 54 |
+
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
| 55 |
+
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
| 56 |
+
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
| 57 |
+
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
| 58 |
+
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
| 59 |
+
|
| 60 |
+
return q, k, v
|
| 61 |
+
|
| 62 |
+
def forward_attention(self, value, scores, mask):
|
| 63 |
+
"""Compute attention context vector.
|
| 64 |
+
Args:
|
| 65 |
+
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
|
| 66 |
+
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
|
| 67 |
+
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
|
| 68 |
+
Returns:
|
| 69 |
+
torch.Tensor: Transformed value (#batch, time1, d_model)
|
| 70 |
+
weighted by the attention score (#batch, time1, time2).
|
| 71 |
+
"""
|
| 72 |
+
n_batch = value.size(0)
|
| 73 |
+
if mask is not None:
|
| 74 |
+
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
| 75 |
+
min_value = float(
|
| 76 |
+
numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
|
| 77 |
+
)
|
| 78 |
+
scores = scores.masked_fill(mask, min_value)
|
| 79 |
+
self.attn = torch.softmax(scores, dim=-1).masked_fill(
|
| 80 |
+
mask, 0.0
|
| 81 |
+
) # (batch, head, time1, time2)
|
| 82 |
+
else:
|
| 83 |
+
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
| 84 |
+
|
| 85 |
+
p_attn = self.dropout(self.attn)
|
| 86 |
+
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
| 87 |
+
x = (
|
| 88 |
+
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
|
| 89 |
+
) # (batch, time1, d_model)
|
| 90 |
+
|
| 91 |
+
return self.linear_out(x) # (batch, time1, d_model)
|
| 92 |
+
|
| 93 |
+
def forward(self, query, key, value, mask):
|
| 94 |
+
"""Compute scaled dot product attention.
|
| 95 |
+
Args:
|
| 96 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
| 97 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
| 98 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
| 99 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
| 100 |
+
(#batch, time1, time2).
|
| 101 |
+
Returns:
|
| 102 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
| 103 |
+
"""
|
| 104 |
+
q, k, v = self.forward_qkv(query, key, value)
|
| 105 |
+
if version.parse(torch.__version__) >= version.parse("2.0") and self.flash:
|
| 106 |
+
n_batch = value.size(0)
|
| 107 |
+
x = torch.nn.functional.scaled_dot_product_attention(
|
| 108 |
+
q, k, v, attn_mask=mask.unsqueeze(1) if mask is not None else None, dropout_p=self.dropout_rate)
|
| 109 |
+
x = (
|
| 110 |
+
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
|
| 111 |
+
) # (batch, time1, d_model)
|
| 112 |
+
return self.linear_out(x)
|
| 113 |
+
else:
|
| 114 |
+
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
| 115 |
+
return self.forward_attention(v, scores, mask)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
| 119 |
+
"""Multi-Head Attention layer with relative position encoding.
|
| 120 |
+
Paper: https://arxiv.org/abs/1901.02860
|
| 121 |
+
Args:
|
| 122 |
+
n_head (int): The number of heads.
|
| 123 |
+
n_feat (int): The number of features.
|
| 124 |
+
dropout_rate (float): Dropout rate.
|
| 125 |
+
"""
|
| 126 |
+
|
| 127 |
+
def __init__(self, n_head, n_feat, dropout_rate):
|
| 128 |
+
"""Construct an RelPositionMultiHeadedAttention object."""
|
| 129 |
+
super().__init__(n_head, n_feat, dropout_rate)
|
| 130 |
+
# linear transformation for positional ecoding
|
| 131 |
+
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
| 132 |
+
# these two learnable bias are used in matrix c and matrix d
|
| 133 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
| 134 |
+
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
| 135 |
+
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
| 136 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
| 137 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
| 138 |
+
|
| 139 |
+
def rel_shift(self, x, zero_triu=False):
|
| 140 |
+
"""Compute relative positinal encoding.
|
| 141 |
+
Args:
|
| 142 |
+
x (torch.Tensor): Input tensor (batch, time, size).
|
| 143 |
+
zero_triu (bool): If true, return the lower triangular part of the matrix.
|
| 144 |
+
Returns:
|
| 145 |
+
torch.Tensor: Output tensor.
|
| 146 |
+
"""
|
| 147 |
+
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
|
| 148 |
+
x_padded = torch.cat([zero_pad, x], dim=-1)
|
| 149 |
+
|
| 150 |
+
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
|
| 151 |
+
x = x_padded[:, :, 1:].view_as(x)
|
| 152 |
+
|
| 153 |
+
if zero_triu:
|
| 154 |
+
ones = torch.ones((x.size(2), x.size(3)))
|
| 155 |
+
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
|
| 156 |
+
|
| 157 |
+
return x
|
| 158 |
+
|
| 159 |
+
def forward(self, query, key, value, pos_emb, mask):
|
| 160 |
+
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
| 161 |
+
Args:
|
| 162 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
| 163 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
| 164 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
| 165 |
+
pos_emb (torch.Tensor): Positional embedding tensor (#batch, time2, size).
|
| 166 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
| 167 |
+
(#batch, time1, time2).
|
| 168 |
+
Returns:
|
| 169 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
| 170 |
+
"""
|
| 171 |
+
q, k, v = self.forward_qkv(query, key, value)
|
| 172 |
+
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
| 173 |
+
|
| 174 |
+
n_batch_pos = pos_emb.size(0)
|
| 175 |
+
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
| 176 |
+
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
| 177 |
+
|
| 178 |
+
# (batch, head, time1, d_k)
|
| 179 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
| 180 |
+
# (batch, head, time1, d_k)
|
| 181 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
| 182 |
+
|
| 183 |
+
# compute attention score
|
| 184 |
+
# first compute matrix a and matrix c
|
| 185 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
| 186 |
+
# (batch, head, time1, time2)
|
| 187 |
+
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
| 188 |
+
|
| 189 |
+
# compute matrix b and matrix d
|
| 190 |
+
# (batch, head, time1, time2)
|
| 191 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
| 192 |
+
matrix_bd = self.rel_shift(matrix_bd)
|
| 193 |
+
|
| 194 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
| 195 |
+
self.d_k
|
| 196 |
+
) # (batch, head, time1, time2)
|
| 197 |
+
|
| 198 |
+
return self.forward_attention(v, scores, mask)
|
preprocess/tools/note_transcription/modules/commons/conformer/layers.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from ..layers import LayerNorm
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ConvolutionModule(nn.Module):
|
| 8 |
+
"""ConvolutionModule in Conformer model.
|
| 9 |
+
Args:
|
| 10 |
+
channels (int): The number of channels of conv layers.
|
| 11 |
+
kernel_size (int): Kernerl size of conv layers.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
|
| 15 |
+
"""Construct an ConvolutionModule object."""
|
| 16 |
+
super(ConvolutionModule, self).__init__()
|
| 17 |
+
# kernerl_size should be a odd number for 'SAME' padding
|
| 18 |
+
assert (kernel_size - 1) % 2 == 0
|
| 19 |
+
|
| 20 |
+
self.pointwise_conv1 = nn.Conv1d(
|
| 21 |
+
channels,
|
| 22 |
+
2 * channels,
|
| 23 |
+
kernel_size=1,
|
| 24 |
+
stride=1,
|
| 25 |
+
padding=0,
|
| 26 |
+
bias=bias,
|
| 27 |
+
)
|
| 28 |
+
self.depthwise_conv = nn.Conv1d(
|
| 29 |
+
channels,
|
| 30 |
+
channels,
|
| 31 |
+
kernel_size,
|
| 32 |
+
stride=1,
|
| 33 |
+
padding=(kernel_size - 1) // 2,
|
| 34 |
+
groups=channels,
|
| 35 |
+
bias=bias,
|
| 36 |
+
)
|
| 37 |
+
self.norm = nn.BatchNorm1d(channels)
|
| 38 |
+
self.pointwise_conv2 = nn.Conv1d(
|
| 39 |
+
channels,
|
| 40 |
+
channels,
|
| 41 |
+
kernel_size=1,
|
| 42 |
+
stride=1,
|
| 43 |
+
padding=0,
|
| 44 |
+
bias=bias,
|
| 45 |
+
)
|
| 46 |
+
self.activation = activation
|
| 47 |
+
|
| 48 |
+
def forward(self, x):
|
| 49 |
+
"""Compute convolution module.
|
| 50 |
+
Args:
|
| 51 |
+
x (torch.Tensor): Input tensor (#batch, time, channels).
|
| 52 |
+
Returns:
|
| 53 |
+
torch.Tensor: Output tensor (#batch, time, channels).
|
| 54 |
+
"""
|
| 55 |
+
# exchange the temporal dimension and the feature dimension
|
| 56 |
+
x = x.transpose(1, 2)
|
| 57 |
+
|
| 58 |
+
# GLU mechanism
|
| 59 |
+
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
|
| 60 |
+
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
|
| 61 |
+
|
| 62 |
+
# 1D Depthwise Conv
|
| 63 |
+
x = self.depthwise_conv(x)
|
| 64 |
+
x = self.activation(self.norm(x))
|
| 65 |
+
|
| 66 |
+
x = self.pointwise_conv2(x)
|
| 67 |
+
|
| 68 |
+
return x.transpose(1, 2)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class MultiLayeredConv1d(torch.nn.Module):
|
| 72 |
+
"""Multi-layered conv1d for Transformer block.
|
| 73 |
+
This is a module of multi-leyered conv1d designed
|
| 74 |
+
to replace positionwise feed-forward network
|
| 75 |
+
in Transforner block, which is introduced in
|
| 76 |
+
`FastSpeech: Fast, Robust and Controllable Text to Speech`_.
|
| 77 |
+
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
|
| 78 |
+
https://arxiv.org/pdf/1905.09263.pdf
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
|
| 82 |
+
"""Initialize MultiLayeredConv1d module.
|
| 83 |
+
Args:
|
| 84 |
+
in_chans (int): Number of input channels.
|
| 85 |
+
hidden_chans (int): Number of hidden channels.
|
| 86 |
+
kernel_size (int): Kernel size of conv1d.
|
| 87 |
+
dropout_rate (float): Dropout rate.
|
| 88 |
+
"""
|
| 89 |
+
super(MultiLayeredConv1d, self).__init__()
|
| 90 |
+
self.w_1 = torch.nn.Conv1d(
|
| 91 |
+
in_chans,
|
| 92 |
+
hidden_chans,
|
| 93 |
+
kernel_size,
|
| 94 |
+
stride=1,
|
| 95 |
+
padding=(kernel_size - 1) // 2,
|
| 96 |
+
)
|
| 97 |
+
self.w_2 = torch.nn.Conv1d(
|
| 98 |
+
hidden_chans,
|
| 99 |
+
in_chans,
|
| 100 |
+
kernel_size,
|
| 101 |
+
stride=1,
|
| 102 |
+
padding=(kernel_size - 1) // 2,
|
| 103 |
+
)
|
| 104 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
| 105 |
+
|
| 106 |
+
def forward(self, x):
|
| 107 |
+
"""Calculate forward propagation.
|
| 108 |
+
Args:
|
| 109 |
+
x (torch.Tensor): Batch of input tensors (B, T, in_chans).
|
| 110 |
+
Returns:
|
| 111 |
+
torch.Tensor: Batch of output tensors (B, T, hidden_chans).
|
| 112 |
+
"""
|
| 113 |
+
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
|
| 114 |
+
return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class Swish(torch.nn.Module):
|
| 118 |
+
"""Construct an Swish object."""
|
| 119 |
+
|
| 120 |
+
def forward(self, x):
|
| 121 |
+
"""Return Swich activation function."""
|
| 122 |
+
return x * torch.sigmoid(x)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class EncoderLayer(nn.Module):
|
| 126 |
+
"""Encoder layer module.
|
| 127 |
+
Args:
|
| 128 |
+
size (int): Input dimension.
|
| 129 |
+
self_attn (torch.nn.Module): Self-attention module instance.
|
| 130 |
+
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
|
| 131 |
+
can be used as the argument.
|
| 132 |
+
feed_forward (torch.nn.Module): Feed-forward module instance.
|
| 133 |
+
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
|
| 134 |
+
can be used as the argument.
|
| 135 |
+
feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
|
| 136 |
+
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
|
| 137 |
+
can be used as the argument.
|
| 138 |
+
conv_module (torch.nn.Module): Convolution module instance.
|
| 139 |
+
`ConvlutionModule` instance can be used as the argument.
|
| 140 |
+
dropout_rate (float): Dropout rate.
|
| 141 |
+
normalize_before (bool): Whether to use layer_norm before the first block.
|
| 142 |
+
concat_after (bool): Whether to concat attention layer's input and output.
|
| 143 |
+
if True, additional linear will be applied.
|
| 144 |
+
i.e. x -> x + linear(concat(x, att(x)))
|
| 145 |
+
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
| 146 |
+
"""
|
| 147 |
+
|
| 148 |
+
def __init__(
|
| 149 |
+
self,
|
| 150 |
+
size,
|
| 151 |
+
self_attn,
|
| 152 |
+
feed_forward,
|
| 153 |
+
feed_forward_macaron,
|
| 154 |
+
conv_module,
|
| 155 |
+
dropout_rate,
|
| 156 |
+
normalize_before=True,
|
| 157 |
+
concat_after=False,
|
| 158 |
+
):
|
| 159 |
+
"""Construct an EncoderLayer object."""
|
| 160 |
+
super(EncoderLayer, self).__init__()
|
| 161 |
+
self.self_attn = self_attn
|
| 162 |
+
self.feed_forward = feed_forward
|
| 163 |
+
self.feed_forward_macaron = feed_forward_macaron
|
| 164 |
+
self.conv_module = conv_module
|
| 165 |
+
self.norm_ff = LayerNorm(size) # for the FNN module
|
| 166 |
+
self.norm_mha = LayerNorm(size) # for the MHA module
|
| 167 |
+
if feed_forward_macaron is not None:
|
| 168 |
+
self.norm_ff_macaron = LayerNorm(size)
|
| 169 |
+
self.ff_scale = 0.5
|
| 170 |
+
else:
|
| 171 |
+
self.ff_scale = 1.0
|
| 172 |
+
if self.conv_module is not None:
|
| 173 |
+
self.norm_conv = LayerNorm(size) # for the CNN module
|
| 174 |
+
self.norm_final = LayerNorm(size) # for the final output of the block
|
| 175 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 176 |
+
self.size = size
|
| 177 |
+
self.normalize_before = normalize_before
|
| 178 |
+
self.concat_after = concat_after
|
| 179 |
+
if self.concat_after:
|
| 180 |
+
self.concat_linear = nn.Linear(size + size, size)
|
| 181 |
+
|
| 182 |
+
def forward(self, x_input, mask, cache=None):
|
| 183 |
+
"""Compute encoded features.
|
| 184 |
+
Args:
|
| 185 |
+
x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
|
| 186 |
+
- w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
|
| 187 |
+
- w/o pos emb: Tensor (#batch, time, size).
|
| 188 |
+
mask (torch.Tensor): Mask tensor for the input (#batch, time).
|
| 189 |
+
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
|
| 190 |
+
Returns:
|
| 191 |
+
torch.Tensor: Output tensor (#batch, time, size).
|
| 192 |
+
torch.Tensor: Mask tensor (#batch, time).
|
| 193 |
+
"""
|
| 194 |
+
if isinstance(x_input, tuple):
|
| 195 |
+
x, pos_emb = x_input[0], x_input[1]
|
| 196 |
+
else:
|
| 197 |
+
x, pos_emb = x_input, None
|
| 198 |
+
|
| 199 |
+
# whether to use macaron style
|
| 200 |
+
if self.feed_forward_macaron is not None:
|
| 201 |
+
residual = x
|
| 202 |
+
if self.normalize_before:
|
| 203 |
+
x = self.norm_ff_macaron(x)
|
| 204 |
+
x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
|
| 205 |
+
if not self.normalize_before:
|
| 206 |
+
x = self.norm_ff_macaron(x)
|
| 207 |
+
|
| 208 |
+
# multi-headed self-attention module
|
| 209 |
+
residual = x
|
| 210 |
+
if self.normalize_before:
|
| 211 |
+
x = self.norm_mha(x)
|
| 212 |
+
|
| 213 |
+
if cache is None:
|
| 214 |
+
x_q = x
|
| 215 |
+
else:
|
| 216 |
+
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
|
| 217 |
+
x_q = x[:, -1:, :]
|
| 218 |
+
residual = residual[:, -1:, :]
|
| 219 |
+
mask = None if mask is None else mask[:, -1:, :]
|
| 220 |
+
|
| 221 |
+
if pos_emb is not None:
|
| 222 |
+
x_att = self.self_attn(x_q, x, x, pos_emb, mask)
|
| 223 |
+
else:
|
| 224 |
+
x_att = self.self_attn(x_q, x, x, mask)
|
| 225 |
+
|
| 226 |
+
if self.concat_after:
|
| 227 |
+
x_concat = torch.cat((x, x_att), dim=-1)
|
| 228 |
+
x = residual + self.concat_linear(x_concat)
|
| 229 |
+
else:
|
| 230 |
+
x = residual + self.dropout(x_att)
|
| 231 |
+
if not self.normalize_before:
|
| 232 |
+
x = self.norm_mha(x)
|
| 233 |
+
|
| 234 |
+
# convolution module
|
| 235 |
+
if self.conv_module is not None:
|
| 236 |
+
residual = x
|
| 237 |
+
if self.normalize_before:
|
| 238 |
+
x = self.norm_conv(x)
|
| 239 |
+
x = residual + self.dropout(self.conv_module(x))
|
| 240 |
+
if not self.normalize_before:
|
| 241 |
+
x = self.norm_conv(x)
|
| 242 |
+
|
| 243 |
+
# feed forward module
|
| 244 |
+
residual = x
|
| 245 |
+
if self.normalize_before:
|
| 246 |
+
x = self.norm_ff(x)
|
| 247 |
+
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
|
| 248 |
+
if not self.normalize_before:
|
| 249 |
+
x = self.norm_ff(x)
|
| 250 |
+
|
| 251 |
+
if self.conv_module is not None:
|
| 252 |
+
x = self.norm_final(x)
|
| 253 |
+
|
| 254 |
+
if cache is not None:
|
| 255 |
+
x = torch.cat([cache, x], dim=1)
|
| 256 |
+
|
| 257 |
+
if pos_emb is not None:
|
| 258 |
+
return (x, pos_emb), mask
|
| 259 |
+
|
| 260 |
+
return x, mask
|
preprocess/tools/note_transcription/modules/commons/conv.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
from .layers import LayerNorm, Embedding
|
| 7 |
+
|
| 8 |
+
class LambdaLayer(nn.Module):
|
| 9 |
+
def __init__(self, lambd):
|
| 10 |
+
super(LambdaLayer, self).__init__()
|
| 11 |
+
self.lambd = lambd
|
| 12 |
+
|
| 13 |
+
def forward(self, x):
|
| 14 |
+
return self.lambd(x)
|
| 15 |
+
|
| 16 |
+
def init_weights_func(m):
|
| 17 |
+
classname = m.__class__.__name__
|
| 18 |
+
if classname.find("Conv1d") != -1:
|
| 19 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
| 20 |
+
|
| 21 |
+
def get_norm_builder(norm_type, channels, ln_eps=1e-6):
|
| 22 |
+
if norm_type == 'bn':
|
| 23 |
+
norm_builder = lambda: nn.BatchNorm1d(channels)
|
| 24 |
+
elif norm_type == 'in':
|
| 25 |
+
norm_builder = lambda: nn.InstanceNorm1d(channels, affine=True)
|
| 26 |
+
elif norm_type == 'gn':
|
| 27 |
+
norm_builder = lambda: nn.GroupNorm(8, channels)
|
| 28 |
+
elif norm_type == 'ln':
|
| 29 |
+
norm_builder = lambda: LayerNorm(channels, dim=1, eps=ln_eps)
|
| 30 |
+
else:
|
| 31 |
+
norm_builder = lambda: nn.Identity()
|
| 32 |
+
return norm_builder
|
| 33 |
+
|
| 34 |
+
def get_act_builder(act_type):
|
| 35 |
+
if act_type == 'gelu':
|
| 36 |
+
act_builder = lambda: nn.GELU()
|
| 37 |
+
elif act_type == 'relu':
|
| 38 |
+
act_builder = lambda: nn.ReLU(inplace=True)
|
| 39 |
+
elif act_type == 'leakyrelu':
|
| 40 |
+
act_builder = lambda: nn.LeakyReLU(negative_slope=0.01, inplace=True)
|
| 41 |
+
elif act_type == 'swish':
|
| 42 |
+
act_builder = lambda: nn.SiLU(inplace=True)
|
| 43 |
+
else:
|
| 44 |
+
act_builder = lambda: nn.Identity()
|
| 45 |
+
return act_builder
|
| 46 |
+
|
| 47 |
+
class ResidualBlock(nn.Module):
|
| 48 |
+
"""Implements conv->PReLU->norm n-times"""
|
| 49 |
+
|
| 50 |
+
def __init__(self, channels, kernel_size, dilation, n=2, norm_type='bn', dropout=0.0,
|
| 51 |
+
c_multiple=2, ln_eps=1e-12, act_type='gelu'):
|
| 52 |
+
super(ResidualBlock, self).__init__()
|
| 53 |
+
|
| 54 |
+
norm_builder = get_norm_builder(norm_type, channels, ln_eps)
|
| 55 |
+
act_builder = get_act_builder(act_type)
|
| 56 |
+
|
| 57 |
+
self.blocks = [
|
| 58 |
+
nn.Sequential(
|
| 59 |
+
norm_builder(),
|
| 60 |
+
nn.Conv1d(channels, c_multiple * channels, kernel_size, dilation=dilation,
|
| 61 |
+
padding=(dilation * (kernel_size - 1)) // 2),
|
| 62 |
+
LambdaLayer(lambda x: x * kernel_size ** -0.5),
|
| 63 |
+
act_builder(),
|
| 64 |
+
nn.Conv1d(c_multiple * channels, channels, 1, dilation=dilation),
|
| 65 |
+
)
|
| 66 |
+
for i in range(n)
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
self.blocks = nn.ModuleList(self.blocks)
|
| 70 |
+
self.dropout = dropout
|
| 71 |
+
|
| 72 |
+
def forward(self, x):
|
| 73 |
+
nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
|
| 74 |
+
for b in self.blocks:
|
| 75 |
+
x_ = b(x)
|
| 76 |
+
if self.dropout > 0 and self.training:
|
| 77 |
+
x_ = F.dropout(x_, self.dropout, training=self.training)
|
| 78 |
+
x = x + x_
|
| 79 |
+
x = x * nonpadding
|
| 80 |
+
return x
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class ConvBlocks(nn.Module):
|
| 84 |
+
"""Decodes the expanded phoneme encoding into spectrograms"""
|
| 85 |
+
|
| 86 |
+
def __init__(self, hidden_size, out_dims, dilations, kernel_size,
|
| 87 |
+
norm_type='ln', layers_in_block=2, c_multiple=2,
|
| 88 |
+
dropout=0.0, ln_eps=1e-5,
|
| 89 |
+
init_weights=True, is_BTC=True, num_layers=None, post_net_kernel=3, act_type='gelu'):
|
| 90 |
+
super(ConvBlocks, self).__init__()
|
| 91 |
+
self.is_BTC = is_BTC
|
| 92 |
+
if num_layers is not None:
|
| 93 |
+
dilations = [1] * num_layers
|
| 94 |
+
self.res_blocks = nn.Sequential(
|
| 95 |
+
*[ResidualBlock(hidden_size, kernel_size, d,
|
| 96 |
+
n=layers_in_block, norm_type=norm_type, c_multiple=c_multiple,
|
| 97 |
+
dropout=dropout, ln_eps=ln_eps, act_type=act_type)
|
| 98 |
+
for d in dilations],
|
| 99 |
+
)
|
| 100 |
+
norm = get_norm_builder(norm_type, hidden_size, ln_eps)()
|
| 101 |
+
self.last_norm = norm
|
| 102 |
+
self.post_net1 = nn.Conv1d(hidden_size, out_dims, kernel_size=post_net_kernel,
|
| 103 |
+
padding=post_net_kernel // 2)
|
| 104 |
+
if init_weights:
|
| 105 |
+
self.apply(init_weights_func)
|
| 106 |
+
|
| 107 |
+
def forward(self, x, nonpadding=None):
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
:param x: [B, T, H]
|
| 111 |
+
:return: [B, T, H]
|
| 112 |
+
"""
|
| 113 |
+
if self.is_BTC:
|
| 114 |
+
x = x.transpose(1, 2)
|
| 115 |
+
if nonpadding is None:
|
| 116 |
+
nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
|
| 117 |
+
elif self.is_BTC:
|
| 118 |
+
nonpadding = nonpadding.transpose(1, 2)
|
| 119 |
+
x = self.res_blocks(x) * nonpadding
|
| 120 |
+
x = self.last_norm(x) * nonpadding
|
| 121 |
+
x = self.post_net1(x) * nonpadding
|
| 122 |
+
if self.is_BTC:
|
| 123 |
+
x = x.transpose(1, 2)
|
| 124 |
+
return x
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class TextConvEncoder(ConvBlocks):
|
| 128 |
+
def __init__(self, dict_size, hidden_size, out_dims, dilations, kernel_size,
|
| 129 |
+
norm_type='ln', layers_in_block=2, c_multiple=2,
|
| 130 |
+
dropout=0.0, ln_eps=1e-5, init_weights=True, num_layers=None, post_net_kernel=3):
|
| 131 |
+
super().__init__(hidden_size, out_dims, dilations, kernel_size,
|
| 132 |
+
norm_type, layers_in_block, c_multiple,
|
| 133 |
+
dropout, ln_eps, init_weights, num_layers=num_layers,
|
| 134 |
+
post_net_kernel=post_net_kernel)
|
| 135 |
+
self.embed_tokens = Embedding(dict_size, hidden_size, 0)
|
| 136 |
+
self.embed_scale = math.sqrt(hidden_size)
|
| 137 |
+
|
| 138 |
+
def forward(self, txt_tokens):
|
| 139 |
+
"""
|
| 140 |
+
|
| 141 |
+
:param txt_tokens: [B, T]
|
| 142 |
+
:return: {
|
| 143 |
+
'encoder_out': [B x T x C]
|
| 144 |
+
}
|
| 145 |
+
"""
|
| 146 |
+
x = self.embed_scale * self.embed_tokens(txt_tokens)
|
| 147 |
+
return super().forward(x)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class ConditionalConvBlocks(ConvBlocks):
|
| 151 |
+
def __init__(self, hidden_size, c_cond, c_out, dilations, kernel_size,
|
| 152 |
+
norm_type='ln', layers_in_block=2, c_multiple=2,
|
| 153 |
+
dropout=0.0, ln_eps=1e-5, init_weights=True, is_BTC=True, num_layers=None):
|
| 154 |
+
super().__init__(hidden_size, c_out, dilations, kernel_size,
|
| 155 |
+
norm_type, layers_in_block, c_multiple,
|
| 156 |
+
dropout, ln_eps, init_weights, is_BTC=False, num_layers=num_layers)
|
| 157 |
+
self.g_prenet = nn.Conv1d(c_cond, hidden_size, 3, padding=1)
|
| 158 |
+
self.is_BTC_ = is_BTC
|
| 159 |
+
if init_weights:
|
| 160 |
+
self.g_prenet.apply(init_weights_func)
|
| 161 |
+
|
| 162 |
+
def forward(self, x, cond, nonpadding=None):
|
| 163 |
+
if self.is_BTC_:
|
| 164 |
+
x = x.transpose(1, 2)
|
| 165 |
+
cond = cond.transpose(1, 2)
|
| 166 |
+
if nonpadding is not None:
|
| 167 |
+
nonpadding = nonpadding.transpose(1, 2)
|
| 168 |
+
if nonpadding is None:
|
| 169 |
+
nonpadding = x.abs().sum(1)[:, None]
|
| 170 |
+
x = x + self.g_prenet(cond)
|
| 171 |
+
x = x * nonpadding
|
| 172 |
+
x = super(ConditionalConvBlocks, self).forward(x) # input needs to be BTC
|
| 173 |
+
if self.is_BTC_:
|
| 174 |
+
x = x.transpose(1, 2)
|
| 175 |
+
return x
|
preprocess/tools/note_transcription/modules/commons/layers.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.autograd import Function
|
| 4 |
+
|
| 5 |
+
class LayerNorm(torch.nn.LayerNorm):
|
| 6 |
+
"""Layer normalization module.
|
| 7 |
+
:param int nout: output dim size
|
| 8 |
+
:param int dim: dimension to be normalized
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
def __init__(self, nout, dim=-1, eps=1e-5):
|
| 12 |
+
"""Construct an LayerNorm object."""
|
| 13 |
+
super(LayerNorm, self).__init__(nout, eps=eps)
|
| 14 |
+
self.dim = dim
|
| 15 |
+
|
| 16 |
+
def forward(self, x):
|
| 17 |
+
"""Apply layer normalization.
|
| 18 |
+
:param torch.Tensor x: input tensor
|
| 19 |
+
:return: layer normalized tensor
|
| 20 |
+
:rtype torch.Tensor
|
| 21 |
+
"""
|
| 22 |
+
if self.dim == -1:
|
| 23 |
+
return super(LayerNorm, self).forward(x)
|
| 24 |
+
return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Reshape(nn.Module):
|
| 28 |
+
def __init__(self, *args):
|
| 29 |
+
super(Reshape, self).__init__()
|
| 30 |
+
self.shape = args
|
| 31 |
+
|
| 32 |
+
def forward(self, x):
|
| 33 |
+
return x.view(self.shape)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Permute(nn.Module):
|
| 37 |
+
def __init__(self, *args):
|
| 38 |
+
super(Permute, self).__init__()
|
| 39 |
+
self.args = args
|
| 40 |
+
|
| 41 |
+
def forward(self, x):
|
| 42 |
+
return x.permute(self.args)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def Linear(in_features, out_features, bias=True, init_type='xavier'):
|
| 46 |
+
m = nn.Linear(in_features, out_features, bias)
|
| 47 |
+
if init_type == 'xavier':
|
| 48 |
+
nn.init.xavier_uniform_(m.weight)
|
| 49 |
+
elif init_type == 'kaiming':
|
| 50 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_in')
|
| 51 |
+
if bias:
|
| 52 |
+
nn.init.constant_(m.bias, 0.)
|
| 53 |
+
return m
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def Embedding(num_embeddings, embedding_dim, padding_idx=None, init_type='normal'):
|
| 57 |
+
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
| 58 |
+
if init_type == 'normal':
|
| 59 |
+
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
|
| 60 |
+
elif init_type == 'kaiming':
|
| 61 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_in')
|
| 62 |
+
if padding_idx is not None:
|
| 63 |
+
nn.init.constant_(m.weight[padding_idx], 0)
|
| 64 |
+
return m
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class GradientReverseFunction(Function):
|
| 68 |
+
@staticmethod
|
| 69 |
+
def forward(ctx, input, coeff=1.):
|
| 70 |
+
ctx.coeff = coeff
|
| 71 |
+
output = input * 1.0
|
| 72 |
+
return output
|
| 73 |
+
|
| 74 |
+
@staticmethod
|
| 75 |
+
def backward(ctx, grad_output):
|
| 76 |
+
return grad_output.neg() * ctx.coeff, None
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class GRL(nn.Module):
|
| 80 |
+
def __init__(self):
|
| 81 |
+
super(GRL, self).__init__()
|
| 82 |
+
|
| 83 |
+
def forward(self, *input):
|
| 84 |
+
return GradientReverseFunction.apply(*input)
|
| 85 |
+
|
preprocess/tools/note_transcription/modules/commons/rel_transformer.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torch.nn import functional as F
|
| 5 |
+
|
| 6 |
+
from .layers import Embedding
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def convert_pad_shape(pad_shape):
|
| 10 |
+
l = pad_shape[::-1]
|
| 11 |
+
pad_shape = [item for sublist in l for item in sublist]
|
| 12 |
+
return pad_shape
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def shift_1d(x):
|
| 16 |
+
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
|
| 17 |
+
return x
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def sequence_mask(length, max_length=None):
|
| 21 |
+
if max_length is None:
|
| 22 |
+
max_length = length.max()
|
| 23 |
+
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
| 24 |
+
return x.unsqueeze(0) < length.unsqueeze(1)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Encoder(nn.Module):
|
| 28 |
+
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.,
|
| 29 |
+
window_size=None, block_length=None, pre_ln=False, **kwargs):
|
| 30 |
+
super().__init__()
|
| 31 |
+
self.hidden_channels = hidden_channels
|
| 32 |
+
self.filter_channels = filter_channels
|
| 33 |
+
self.n_heads = n_heads
|
| 34 |
+
self.n_layers = n_layers
|
| 35 |
+
self.kernel_size = kernel_size
|
| 36 |
+
self.p_dropout = p_dropout
|
| 37 |
+
self.window_size = window_size
|
| 38 |
+
self.block_length = block_length
|
| 39 |
+
self.pre_ln = pre_ln
|
| 40 |
+
|
| 41 |
+
self.drop = nn.Dropout(p_dropout)
|
| 42 |
+
self.attn_layers = nn.ModuleList()
|
| 43 |
+
self.norm_layers_1 = nn.ModuleList()
|
| 44 |
+
self.ffn_layers = nn.ModuleList()
|
| 45 |
+
self.norm_layers_2 = nn.ModuleList()
|
| 46 |
+
for i in range(self.n_layers):
|
| 47 |
+
self.attn_layers.append(
|
| 48 |
+
MultiHeadAttention(hidden_channels, hidden_channels, n_heads, window_size=window_size,
|
| 49 |
+
p_dropout=p_dropout, block_length=block_length))
|
| 50 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
| 51 |
+
self.ffn_layers.append(
|
| 52 |
+
FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
|
| 53 |
+
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
| 54 |
+
if pre_ln:
|
| 55 |
+
self.last_ln = LayerNorm(hidden_channels)
|
| 56 |
+
|
| 57 |
+
def forward(self, x, x_mask):
|
| 58 |
+
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
| 59 |
+
for i in range(self.n_layers):
|
| 60 |
+
x = x * x_mask
|
| 61 |
+
x_ = x
|
| 62 |
+
if self.pre_ln:
|
| 63 |
+
x = self.norm_layers_1[i](x)
|
| 64 |
+
y = self.attn_layers[i](x, x, attn_mask)
|
| 65 |
+
y = self.drop(y)
|
| 66 |
+
x = x_ + y
|
| 67 |
+
if not self.pre_ln:
|
| 68 |
+
x = self.norm_layers_1[i](x)
|
| 69 |
+
|
| 70 |
+
x_ = x
|
| 71 |
+
if self.pre_ln:
|
| 72 |
+
x = self.norm_layers_2[i](x)
|
| 73 |
+
y = self.ffn_layers[i](x, x_mask)
|
| 74 |
+
y = self.drop(y)
|
| 75 |
+
x = x_ + y
|
| 76 |
+
if not self.pre_ln:
|
| 77 |
+
x = self.norm_layers_2[i](x)
|
| 78 |
+
if self.pre_ln:
|
| 79 |
+
x = self.last_ln(x)
|
| 80 |
+
x = x * x_mask
|
| 81 |
+
return x
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class MultiHeadAttention(nn.Module):
|
| 85 |
+
def __init__(self, channels, out_channels, n_heads, window_size=None, heads_share=True, p_dropout=0.,
|
| 86 |
+
block_length=None, proximal_bias=False, proximal_init=False):
|
| 87 |
+
super().__init__()
|
| 88 |
+
assert channels % n_heads == 0
|
| 89 |
+
|
| 90 |
+
self.channels = channels
|
| 91 |
+
self.out_channels = out_channels
|
| 92 |
+
self.n_heads = n_heads
|
| 93 |
+
self.window_size = window_size
|
| 94 |
+
self.heads_share = heads_share
|
| 95 |
+
self.block_length = block_length
|
| 96 |
+
self.proximal_bias = proximal_bias
|
| 97 |
+
self.p_dropout = p_dropout
|
| 98 |
+
self.attn = None
|
| 99 |
+
|
| 100 |
+
self.k_channels = channels // n_heads
|
| 101 |
+
self.conv_q = nn.Conv1d(channels, channels, 1)
|
| 102 |
+
self.conv_k = nn.Conv1d(channels, channels, 1)
|
| 103 |
+
self.conv_v = nn.Conv1d(channels, channels, 1)
|
| 104 |
+
if window_size is not None:
|
| 105 |
+
n_heads_rel = 1 if heads_share else n_heads
|
| 106 |
+
rel_stddev = self.k_channels ** -0.5
|
| 107 |
+
self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
|
| 108 |
+
self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
|
| 109 |
+
self.conv_o = nn.Conv1d(channels, out_channels, 1)
|
| 110 |
+
self.drop = nn.Dropout(p_dropout)
|
| 111 |
+
|
| 112 |
+
nn.init.xavier_uniform_(self.conv_q.weight)
|
| 113 |
+
nn.init.xavier_uniform_(self.conv_k.weight)
|
| 114 |
+
if proximal_init:
|
| 115 |
+
self.conv_k.weight.data.copy_(self.conv_q.weight.data)
|
| 116 |
+
self.conv_k.bias.data.copy_(self.conv_q.bias.data)
|
| 117 |
+
nn.init.xavier_uniform_(self.conv_v.weight)
|
| 118 |
+
|
| 119 |
+
def forward(self, x, c, attn_mask=None):
|
| 120 |
+
q = self.conv_q(x)
|
| 121 |
+
k = self.conv_k(c)
|
| 122 |
+
v = self.conv_v(c)
|
| 123 |
+
|
| 124 |
+
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
| 125 |
+
|
| 126 |
+
x = self.conv_o(x)
|
| 127 |
+
return x
|
| 128 |
+
|
| 129 |
+
def attention(self, query, key, value, mask=None):
|
| 130 |
+
# reshape [b, d, t] -> [b, n_h, t, d_k]
|
| 131 |
+
b, d, t_s, t_t = (*key.size(), query.size(2))
|
| 132 |
+
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
|
| 133 |
+
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
| 134 |
+
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
| 135 |
+
|
| 136 |
+
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
|
| 137 |
+
if self.window_size is not None:
|
| 138 |
+
assert t_s == t_t, "Relative attention is only available for self-attention."
|
| 139 |
+
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
|
| 140 |
+
rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
|
| 141 |
+
rel_logits = self._relative_position_to_absolute_position(rel_logits)
|
| 142 |
+
scores_local = rel_logits / math.sqrt(self.k_channels)
|
| 143 |
+
scores = scores + scores_local
|
| 144 |
+
if self.proximal_bias:
|
| 145 |
+
assert t_s == t_t, "Proximal bias is only available for self-attention."
|
| 146 |
+
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
|
| 147 |
+
if mask is not None:
|
| 148 |
+
scores = scores.masked_fill(mask == 0, -1e4)
|
| 149 |
+
if self.block_length is not None:
|
| 150 |
+
block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
|
| 151 |
+
scores = scores * block_mask + -1e4 * (1 - block_mask)
|
| 152 |
+
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
|
| 153 |
+
p_attn = self.drop(p_attn)
|
| 154 |
+
output = torch.matmul(p_attn, value)
|
| 155 |
+
if self.window_size is not None:
|
| 156 |
+
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
| 157 |
+
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
|
| 158 |
+
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
|
| 159 |
+
output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
|
| 160 |
+
return output, p_attn
|
| 161 |
+
|
| 162 |
+
def _matmul_with_relative_values(self, x, y):
|
| 163 |
+
"""
|
| 164 |
+
x: [b, h, l, m]
|
| 165 |
+
y: [h or 1, m, d]
|
| 166 |
+
ret: [b, h, l, d]
|
| 167 |
+
"""
|
| 168 |
+
ret = torch.matmul(x, y.unsqueeze(0))
|
| 169 |
+
return ret
|
| 170 |
+
|
| 171 |
+
def _matmul_with_relative_keys(self, x, y):
|
| 172 |
+
"""
|
| 173 |
+
x: [b, h, l, d]
|
| 174 |
+
y: [h or 1, m, d]
|
| 175 |
+
ret: [b, h, l, m]
|
| 176 |
+
"""
|
| 177 |
+
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
|
| 178 |
+
return ret
|
| 179 |
+
|
| 180 |
+
def _get_relative_embeddings(self, relative_embeddings, length):
|
| 181 |
+
max_relative_position = 2 * self.window_size + 1
|
| 182 |
+
# Pad first before slice to avoid using cond ops.
|
| 183 |
+
pad_length = max(length - (self.window_size + 1), 0)
|
| 184 |
+
slice_start_position = max((self.window_size + 1) - length, 0)
|
| 185 |
+
slice_end_position = slice_start_position + 2 * length - 1
|
| 186 |
+
if pad_length > 0:
|
| 187 |
+
padded_relative_embeddings = F.pad(
|
| 188 |
+
relative_embeddings,
|
| 189 |
+
convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
|
| 190 |
+
else:
|
| 191 |
+
padded_relative_embeddings = relative_embeddings
|
| 192 |
+
used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
|
| 193 |
+
return used_relative_embeddings
|
| 194 |
+
|
| 195 |
+
def _relative_position_to_absolute_position(self, x):
|
| 196 |
+
"""
|
| 197 |
+
x: [b, h, l, 2*l-1]
|
| 198 |
+
ret: [b, h, l, l]
|
| 199 |
+
"""
|
| 200 |
+
batch, heads, length, _ = x.size()
|
| 201 |
+
# Concat columns of pad to shift from relative to absolute indexing.
|
| 202 |
+
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
|
| 203 |
+
|
| 204 |
+
# Concat extra elements so to add up to shape (len+1, 2*len-1).
|
| 205 |
+
x_flat = x.view([batch, heads, length * 2 * length])
|
| 206 |
+
x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
|
| 207 |
+
|
| 208 |
+
# Reshape and slice out the padded elements.
|
| 209 |
+
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
|
| 210 |
+
return x_final
|
| 211 |
+
|
| 212 |
+
def _absolute_position_to_relative_position(self, x):
|
| 213 |
+
"""
|
| 214 |
+
x: [b, h, l, l]
|
| 215 |
+
ret: [b, h, l, 2*l-1]
|
| 216 |
+
"""
|
| 217 |
+
batch, heads, length, _ = x.size()
|
| 218 |
+
# padd along column
|
| 219 |
+
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
|
| 220 |
+
x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
|
| 221 |
+
# add 0's in the beginning that will skew the elements after reshape
|
| 222 |
+
x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
|
| 223 |
+
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
|
| 224 |
+
return x_final
|
| 225 |
+
|
| 226 |
+
def _attention_bias_proximal(self, length):
|
| 227 |
+
"""Bias for self-attention to encourage attention to close positions.
|
| 228 |
+
Args:
|
| 229 |
+
length: an integer scalar.
|
| 230 |
+
Returns:
|
| 231 |
+
a Tensor with shape [1, 1, length, length]
|
| 232 |
+
"""
|
| 233 |
+
r = torch.arange(length, dtype=torch.float32)
|
| 234 |
+
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
|
| 235 |
+
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class FFN(nn.Module):
|
| 239 |
+
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None):
|
| 240 |
+
super().__init__()
|
| 241 |
+
self.in_channels = in_channels
|
| 242 |
+
self.out_channels = out_channels
|
| 243 |
+
self.filter_channels = filter_channels
|
| 244 |
+
self.kernel_size = kernel_size
|
| 245 |
+
self.p_dropout = p_dropout
|
| 246 |
+
self.activation = activation
|
| 247 |
+
|
| 248 |
+
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
|
| 249 |
+
self.conv_2 = nn.Conv1d(filter_channels, out_channels, 1)
|
| 250 |
+
self.drop = nn.Dropout(p_dropout)
|
| 251 |
+
|
| 252 |
+
def forward(self, x, x_mask):
|
| 253 |
+
x = self.conv_1(x * x_mask)
|
| 254 |
+
if self.activation == "gelu":
|
| 255 |
+
x = x * torch.sigmoid(1.702 * x)
|
| 256 |
+
else:
|
| 257 |
+
x = torch.relu(x)
|
| 258 |
+
x = self.drop(x)
|
| 259 |
+
x = self.conv_2(x * x_mask)
|
| 260 |
+
return x * x_mask
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class LayerNorm(nn.Module):
|
| 264 |
+
def __init__(self, channels, eps=1e-4):
|
| 265 |
+
super().__init__()
|
| 266 |
+
self.channels = channels
|
| 267 |
+
self.eps = eps
|
| 268 |
+
|
| 269 |
+
self.gamma = nn.Parameter(torch.ones(channels))
|
| 270 |
+
self.beta = nn.Parameter(torch.zeros(channels))
|
| 271 |
+
|
| 272 |
+
def forward(self, x):
|
| 273 |
+
n_dims = len(x.shape)
|
| 274 |
+
mean = torch.mean(x, 1, keepdim=True)
|
| 275 |
+
variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
|
| 276 |
+
|
| 277 |
+
x = (x - mean) * torch.rsqrt(variance + self.eps)
|
| 278 |
+
|
| 279 |
+
shape = [1, -1] + [1] * (n_dims - 2)
|
| 280 |
+
x = x * self.gamma.view(*shape) + self.beta.view(*shape)
|
| 281 |
+
return x
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class ConvReluNorm(nn.Module):
|
| 285 |
+
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
|
| 286 |
+
super().__init__()
|
| 287 |
+
self.in_channels = in_channels
|
| 288 |
+
self.hidden_channels = hidden_channels
|
| 289 |
+
self.out_channels = out_channels
|
| 290 |
+
self.kernel_size = kernel_size
|
| 291 |
+
self.n_layers = n_layers
|
| 292 |
+
self.p_dropout = p_dropout
|
| 293 |
+
assert n_layers > 1, "Number of layers should be larger than 0."
|
| 294 |
+
|
| 295 |
+
self.conv_layers = nn.ModuleList()
|
| 296 |
+
self.norm_layers = nn.ModuleList()
|
| 297 |
+
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
|
| 298 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
| 299 |
+
self.relu_drop = nn.Sequential(
|
| 300 |
+
nn.ReLU(),
|
| 301 |
+
nn.Dropout(p_dropout))
|
| 302 |
+
for _ in range(n_layers - 1):
|
| 303 |
+
self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
|
| 304 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
| 305 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
| 306 |
+
self.proj.weight.data.zero_()
|
| 307 |
+
self.proj.bias.data.zero_()
|
| 308 |
+
|
| 309 |
+
def forward(self, x, x_mask):
|
| 310 |
+
x_org = x
|
| 311 |
+
for i in range(self.n_layers):
|
| 312 |
+
x = self.conv_layers[i](x * x_mask)
|
| 313 |
+
x = self.norm_layers[i](x)
|
| 314 |
+
x = self.relu_drop(x)
|
| 315 |
+
x = x_org + self.proj(x)
|
| 316 |
+
return x * x_mask
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
class RelTransformerEncoder(nn.Module):
|
| 320 |
+
def __init__(self,
|
| 321 |
+
n_vocab,
|
| 322 |
+
out_channels,
|
| 323 |
+
hidden_channels,
|
| 324 |
+
filter_channels,
|
| 325 |
+
n_heads,
|
| 326 |
+
n_layers,
|
| 327 |
+
kernel_size,
|
| 328 |
+
p_dropout=0.0,
|
| 329 |
+
window_size=4,
|
| 330 |
+
block_length=None,
|
| 331 |
+
prenet=True,
|
| 332 |
+
pre_ln=True,
|
| 333 |
+
):
|
| 334 |
+
|
| 335 |
+
super().__init__()
|
| 336 |
+
|
| 337 |
+
self.n_vocab = n_vocab
|
| 338 |
+
self.out_channels = out_channels
|
| 339 |
+
self.hidden_channels = hidden_channels
|
| 340 |
+
self.filter_channels = filter_channels
|
| 341 |
+
self.n_heads = n_heads
|
| 342 |
+
self.n_layers = n_layers
|
| 343 |
+
self.kernel_size = kernel_size
|
| 344 |
+
self.p_dropout = p_dropout
|
| 345 |
+
self.window_size = window_size
|
| 346 |
+
self.block_length = block_length
|
| 347 |
+
self.prenet = prenet
|
| 348 |
+
if n_vocab > 0:
|
| 349 |
+
self.emb = Embedding(n_vocab, hidden_channels, padding_idx=0)
|
| 350 |
+
|
| 351 |
+
if prenet:
|
| 352 |
+
self.pre = ConvReluNorm(hidden_channels, hidden_channels, hidden_channels,
|
| 353 |
+
kernel_size=5, n_layers=3, p_dropout=0)
|
| 354 |
+
self.encoder = Encoder(
|
| 355 |
+
hidden_channels,
|
| 356 |
+
filter_channels,
|
| 357 |
+
n_heads,
|
| 358 |
+
n_layers,
|
| 359 |
+
kernel_size,
|
| 360 |
+
p_dropout,
|
| 361 |
+
window_size=window_size,
|
| 362 |
+
block_length=block_length,
|
| 363 |
+
pre_ln=pre_ln,
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
def forward(self, x, x_mask=None):
|
| 367 |
+
if self.n_vocab > 0:
|
| 368 |
+
x_lengths = (x > 0).long().sum(-1)
|
| 369 |
+
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
|
| 370 |
+
else:
|
| 371 |
+
x_lengths = (x.abs().sum(-1) > 0).long().sum(-1)
|
| 372 |
+
x = torch.transpose(x, 1, -1) # [b, h, t]
|
| 373 |
+
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
| 374 |
+
|
| 375 |
+
if self.prenet:
|
| 376 |
+
x = self.pre(x, x_mask)
|
| 377 |
+
x = self.encoder(x, x_mask)
|
| 378 |
+
return x.transpose(1, 2)
|
preprocess/tools/note_transcription/modules/commons/rnn.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class PreNet(nn.Module):
|
| 7 |
+
def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
|
| 8 |
+
super().__init__()
|
| 9 |
+
self.fc1 = nn.Linear(in_dims, fc1_dims)
|
| 10 |
+
self.fc2 = nn.Linear(fc1_dims, fc2_dims)
|
| 11 |
+
self.p = dropout
|
| 12 |
+
|
| 13 |
+
def forward(self, x):
|
| 14 |
+
x = self.fc1(x)
|
| 15 |
+
x = F.relu(x)
|
| 16 |
+
x = F.dropout(x, self.p, training=self.training)
|
| 17 |
+
x = self.fc2(x)
|
| 18 |
+
x = F.relu(x)
|
| 19 |
+
x = F.dropout(x, self.p, training=self.training)
|
| 20 |
+
return x
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class HighwayNetwork(nn.Module):
|
| 24 |
+
def __init__(self, size):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.W1 = nn.Linear(size, size)
|
| 27 |
+
self.W2 = nn.Linear(size, size)
|
| 28 |
+
self.W1.bias.data.fill_(0.)
|
| 29 |
+
|
| 30 |
+
def forward(self, x):
|
| 31 |
+
x1 = self.W1(x)
|
| 32 |
+
x2 = self.W2(x)
|
| 33 |
+
g = torch.sigmoid(x2)
|
| 34 |
+
y = g * F.relu(x1) + (1. - g) * x
|
| 35 |
+
return y
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class BatchNormConv(nn.Module):
|
| 39 |
+
def __init__(self, in_channels, out_channels, kernel, relu=True):
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
|
| 42 |
+
self.bnorm = nn.BatchNorm1d(out_channels)
|
| 43 |
+
self.relu = relu
|
| 44 |
+
|
| 45 |
+
def forward(self, x):
|
| 46 |
+
x = self.conv(x)
|
| 47 |
+
x = F.relu(x) if self.relu is True else x
|
| 48 |
+
return self.bnorm(x)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class ConvNorm(torch.nn.Module):
|
| 52 |
+
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
|
| 53 |
+
padding=None, dilation=1, bias=True, w_init_gain='linear'):
|
| 54 |
+
super(ConvNorm, self).__init__()
|
| 55 |
+
if padding is None:
|
| 56 |
+
assert (kernel_size % 2 == 1)
|
| 57 |
+
padding = int(dilation * (kernel_size - 1) / 2)
|
| 58 |
+
|
| 59 |
+
self.conv = torch.nn.Conv1d(in_channels, out_channels,
|
| 60 |
+
kernel_size=kernel_size, stride=stride,
|
| 61 |
+
padding=padding, dilation=dilation,
|
| 62 |
+
bias=bias)
|
| 63 |
+
|
| 64 |
+
torch.nn.init.xavier_uniform_(
|
| 65 |
+
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
|
| 66 |
+
|
| 67 |
+
def forward(self, signal):
|
| 68 |
+
conv_signal = self.conv(signal)
|
| 69 |
+
return conv_signal
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class CBHG(nn.Module):
|
| 73 |
+
def __init__(self, K, in_channels, channels, proj_channels, num_highways):
|
| 74 |
+
super().__init__()
|
| 75 |
+
|
| 76 |
+
# List of all rnns to call `flatten_parameters()` on
|
| 77 |
+
self._to_flatten = []
|
| 78 |
+
|
| 79 |
+
self.bank_kernels = [i for i in range(1, K + 1)]
|
| 80 |
+
self.conv1d_bank = nn.ModuleList()
|
| 81 |
+
for k in self.bank_kernels:
|
| 82 |
+
conv = BatchNormConv(in_channels, channels, k)
|
| 83 |
+
self.conv1d_bank.append(conv)
|
| 84 |
+
|
| 85 |
+
self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
|
| 86 |
+
|
| 87 |
+
self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
|
| 88 |
+
self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
|
| 89 |
+
|
| 90 |
+
# Fix the highway input if necessary
|
| 91 |
+
if proj_channels[-1] != channels:
|
| 92 |
+
self.highway_mismatch = True
|
| 93 |
+
self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
|
| 94 |
+
else:
|
| 95 |
+
self.highway_mismatch = False
|
| 96 |
+
|
| 97 |
+
self.highways = nn.ModuleList()
|
| 98 |
+
for i in range(num_highways):
|
| 99 |
+
hn = HighwayNetwork(channels)
|
| 100 |
+
self.highways.append(hn)
|
| 101 |
+
|
| 102 |
+
self.rnn = nn.GRU(channels, channels, batch_first=True, bidirectional=True)
|
| 103 |
+
self._to_flatten.append(self.rnn)
|
| 104 |
+
|
| 105 |
+
# Avoid fragmentation of RNN parameters and associated warning
|
| 106 |
+
self._flatten_parameters()
|
| 107 |
+
|
| 108 |
+
def forward(self, x):
|
| 109 |
+
# Although we `_flatten_parameters()` on init, when using DataParallel
|
| 110 |
+
# the model gets replicated, making it no longer guaranteed that the
|
| 111 |
+
# weights are contiguous in GPU memory. Hence, we must call it again
|
| 112 |
+
self._flatten_parameters()
|
| 113 |
+
|
| 114 |
+
# Save these for later
|
| 115 |
+
residual = x
|
| 116 |
+
seq_len = x.size(-1)
|
| 117 |
+
conv_bank = []
|
| 118 |
+
|
| 119 |
+
# Convolution Bank
|
| 120 |
+
for conv in self.conv1d_bank:
|
| 121 |
+
c = conv(x) # Convolution
|
| 122 |
+
conv_bank.append(c[:, :, :seq_len])
|
| 123 |
+
|
| 124 |
+
# Stack along the channel axis
|
| 125 |
+
conv_bank = torch.cat(conv_bank, dim=1)
|
| 126 |
+
|
| 127 |
+
# dump the last padding to fit residual
|
| 128 |
+
x = self.maxpool(conv_bank)[:, :, :seq_len]
|
| 129 |
+
|
| 130 |
+
# Conv1d projections
|
| 131 |
+
x = self.conv_project1(x)
|
| 132 |
+
x = self.conv_project2(x)
|
| 133 |
+
|
| 134 |
+
# Residual Connect
|
| 135 |
+
x = x + residual
|
| 136 |
+
|
| 137 |
+
# Through the highways
|
| 138 |
+
x = x.transpose(1, 2)
|
| 139 |
+
if self.highway_mismatch is True:
|
| 140 |
+
x = self.pre_highway(x)
|
| 141 |
+
for h in self.highways:
|
| 142 |
+
x = h(x)
|
| 143 |
+
|
| 144 |
+
# And then the RNN
|
| 145 |
+
x, _ = self.rnn(x)
|
| 146 |
+
return x
|
| 147 |
+
|
| 148 |
+
def _flatten_parameters(self):
|
| 149 |
+
"""Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
|
| 150 |
+
to improve efficiency and avoid PyTorch yelling at us."""
|
| 151 |
+
[m.flatten_parameters() for m in self._to_flatten]
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class TacotronEncoder(nn.Module):
|
| 155 |
+
def __init__(self, embed_dims, num_chars, cbhg_channels, K, num_highways, dropout):
|
| 156 |
+
super().__init__()
|
| 157 |
+
self.embedding = nn.Embedding(num_chars, embed_dims)
|
| 158 |
+
self.pre_net = PreNet(embed_dims, embed_dims, embed_dims, dropout=dropout)
|
| 159 |
+
self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
|
| 160 |
+
proj_channels=[cbhg_channels, cbhg_channels],
|
| 161 |
+
num_highways=num_highways)
|
| 162 |
+
self.proj_out = nn.Linear(cbhg_channels * 2, cbhg_channels)
|
| 163 |
+
|
| 164 |
+
def forward(self, x):
|
| 165 |
+
x = self.embedding(x)
|
| 166 |
+
x = self.pre_net(x)
|
| 167 |
+
x.transpose_(1, 2)
|
| 168 |
+
x = self.cbhg(x)
|
| 169 |
+
x = self.proj_out(x)
|
| 170 |
+
return x
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class RNNEncoder(nn.Module):
|
| 174 |
+
def __init__(self, num_chars, embedding_dim, n_convolutions=3, kernel_size=5):
|
| 175 |
+
super(RNNEncoder, self).__init__()
|
| 176 |
+
self.embedding = nn.Embedding(num_chars, embedding_dim, padding_idx=0)
|
| 177 |
+
convolutions = []
|
| 178 |
+
for _ in range(n_convolutions):
|
| 179 |
+
conv_layer = nn.Sequential(
|
| 180 |
+
ConvNorm(embedding_dim,
|
| 181 |
+
embedding_dim,
|
| 182 |
+
kernel_size=kernel_size, stride=1,
|
| 183 |
+
padding=int((kernel_size - 1) / 2),
|
| 184 |
+
dilation=1, w_init_gain='relu'),
|
| 185 |
+
nn.BatchNorm1d(embedding_dim))
|
| 186 |
+
convolutions.append(conv_layer)
|
| 187 |
+
self.convolutions = nn.ModuleList(convolutions)
|
| 188 |
+
|
| 189 |
+
self.lstm = nn.LSTM(embedding_dim, int(embedding_dim / 2), 1,
|
| 190 |
+
batch_first=True, bidirectional=True)
|
| 191 |
+
|
| 192 |
+
def forward(self, x):
|
| 193 |
+
input_lengths = (x > 0).sum(-1)
|
| 194 |
+
input_lengths = input_lengths.cpu().numpy()
|
| 195 |
+
|
| 196 |
+
x = self.embedding(x)
|
| 197 |
+
x = x.transpose(1, 2) # [B, H, T]
|
| 198 |
+
for conv in self.convolutions:
|
| 199 |
+
x = F.dropout(F.relu(conv(x)), 0.5, self.training) + x
|
| 200 |
+
x = x.transpose(1, 2) # [B, T, H]
|
| 201 |
+
|
| 202 |
+
# pytorch tensor are not reversible, hence the conversion
|
| 203 |
+
x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False)
|
| 204 |
+
|
| 205 |
+
self.lstm.flatten_parameters()
|
| 206 |
+
outputs, _ = self.lstm(x)
|
| 207 |
+
outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
|
| 208 |
+
|
| 209 |
+
return outputs
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class DecoderRNN(torch.nn.Module):
|
| 213 |
+
def __init__(self, hidden_size, decoder_rnn_dim, dropout):
|
| 214 |
+
super(DecoderRNN, self).__init__()
|
| 215 |
+
self.in_conv1d = nn.Sequential(
|
| 216 |
+
torch.nn.Conv1d(
|
| 217 |
+
in_channels=hidden_size,
|
| 218 |
+
out_channels=hidden_size,
|
| 219 |
+
kernel_size=9, padding=4,
|
| 220 |
+
),
|
| 221 |
+
torch.nn.ReLU(),
|
| 222 |
+
torch.nn.Conv1d(
|
| 223 |
+
in_channels=hidden_size,
|
| 224 |
+
out_channels=hidden_size,
|
| 225 |
+
kernel_size=9, padding=4,
|
| 226 |
+
),
|
| 227 |
+
)
|
| 228 |
+
self.ln = nn.LayerNorm(hidden_size)
|
| 229 |
+
if decoder_rnn_dim == 0:
|
| 230 |
+
decoder_rnn_dim = hidden_size * 2
|
| 231 |
+
self.rnn = torch.nn.LSTM(
|
| 232 |
+
input_size=hidden_size,
|
| 233 |
+
hidden_size=decoder_rnn_dim,
|
| 234 |
+
num_layers=1,
|
| 235 |
+
batch_first=True,
|
| 236 |
+
bidirectional=True,
|
| 237 |
+
dropout=dropout
|
| 238 |
+
)
|
| 239 |
+
self.rnn.flatten_parameters()
|
| 240 |
+
self.conv1d = torch.nn.Conv1d(
|
| 241 |
+
in_channels=decoder_rnn_dim * 2,
|
| 242 |
+
out_channels=hidden_size,
|
| 243 |
+
kernel_size=3,
|
| 244 |
+
padding=1,
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
def forward(self, x):
|
| 248 |
+
input_masks = x.abs().sum(-1).ne(0).data[:, :, None]
|
| 249 |
+
input_lengths = input_masks.sum([-1, -2])
|
| 250 |
+
input_lengths = input_lengths.cpu().numpy()
|
| 251 |
+
|
| 252 |
+
x = self.in_conv1d(x.transpose(1, 2)).transpose(1, 2)
|
| 253 |
+
x = self.ln(x)
|
| 254 |
+
x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False)
|
| 255 |
+
self.rnn.flatten_parameters()
|
| 256 |
+
x, _ = self.rnn(x) # [B, T, C]
|
| 257 |
+
x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
|
| 258 |
+
x = x * input_masks
|
| 259 |
+
pre_mel = self.conv1d(x.transpose(1, 2)).transpose(1, 2) # [B, T, C]
|
| 260 |
+
pre_mel = pre_mel * input_masks
|
| 261 |
+
return pre_mel
|
preprocess/tools/note_transcription/modules/commons/transformer.py
ADDED
|
@@ -0,0 +1,751 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torch.nn import Parameter, Linear
|
| 5 |
+
from .layers import LayerNorm, Embedding
|
| 6 |
+
from ...utils.nn.seq_utils import (
|
| 7 |
+
get_incremental_state,
|
| 8 |
+
set_incremental_state,
|
| 9 |
+
softmax,
|
| 10 |
+
make_positions,
|
| 11 |
+
)
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
DEFAULT_MAX_SOURCE_POSITIONS = 2000
|
| 15 |
+
DEFAULT_MAX_TARGET_POSITIONS = 2000
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SinusoidalPositionalEmbedding(nn.Module):
|
| 19 |
+
"""This module produces sinusoidal positional embeddings of any length.
|
| 20 |
+
|
| 21 |
+
Padding symbols are ignored.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, embedding_dim, padding_idx, init_size=1024):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.embedding_dim = embedding_dim
|
| 27 |
+
self.padding_idx = padding_idx
|
| 28 |
+
self.weights = SinusoidalPositionalEmbedding.get_embedding(
|
| 29 |
+
init_size,
|
| 30 |
+
embedding_dim,
|
| 31 |
+
padding_idx,
|
| 32 |
+
)
|
| 33 |
+
self.register_buffer('_float_tensor', torch.FloatTensor(1))
|
| 34 |
+
|
| 35 |
+
@staticmethod
|
| 36 |
+
def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
|
| 37 |
+
"""Build sinusoidal embeddings.
|
| 38 |
+
|
| 39 |
+
This matches the implementation in tensor2tensor, but differs slightly
|
| 40 |
+
from the description in Section 3.5 of "Attention Is All You Need".
|
| 41 |
+
"""
|
| 42 |
+
half_dim = embedding_dim // 2
|
| 43 |
+
emb = math.log(10000) / (half_dim - 1)
|
| 44 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
|
| 45 |
+
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
|
| 46 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
|
| 47 |
+
if embedding_dim % 2 == 1:
|
| 48 |
+
# zero pad
|
| 49 |
+
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
|
| 50 |
+
if padding_idx is not None:
|
| 51 |
+
emb[padding_idx, :] = 0
|
| 52 |
+
return emb
|
| 53 |
+
|
| 54 |
+
def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
|
| 55 |
+
"""Input is expected to be of size [bsz x seqlen]."""
|
| 56 |
+
bsz, seq_len = input.shape[:2]
|
| 57 |
+
max_pos = self.padding_idx + 1 + seq_len
|
| 58 |
+
if self.weights is None or max_pos > self.weights.size(0):
|
| 59 |
+
# recompute/expand embeddings if needed
|
| 60 |
+
self.weights = SinusoidalPositionalEmbedding.get_embedding(
|
| 61 |
+
max_pos,
|
| 62 |
+
self.embedding_dim,
|
| 63 |
+
self.padding_idx,
|
| 64 |
+
)
|
| 65 |
+
self.weights = self.weights.to(self._float_tensor)
|
| 66 |
+
|
| 67 |
+
if incremental_state is not None:
|
| 68 |
+
# positions is the same for every token when decoding a single step
|
| 69 |
+
pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
|
| 70 |
+
return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
|
| 71 |
+
|
| 72 |
+
positions = make_positions(input, self.padding_idx) if positions is None else positions
|
| 73 |
+
return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
|
| 74 |
+
|
| 75 |
+
def max_positions(self):
|
| 76 |
+
"""Maximum number of supported positions."""
|
| 77 |
+
return int(1e5) # an arbitrary large number
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class TransformerFFNLayer(nn.Module):
|
| 81 |
+
def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
|
| 82 |
+
super().__init__()
|
| 83 |
+
self.kernel_size = kernel_size
|
| 84 |
+
self.dropout = dropout
|
| 85 |
+
self.act = act
|
| 86 |
+
if padding == 'SAME':
|
| 87 |
+
self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
|
| 88 |
+
elif padding == 'LEFT':
|
| 89 |
+
self.ffn_1 = nn.Sequential(
|
| 90 |
+
nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
|
| 91 |
+
nn.Conv1d(hidden_size, filter_size, kernel_size)
|
| 92 |
+
)
|
| 93 |
+
self.ffn_2 = Linear(filter_size, hidden_size)
|
| 94 |
+
|
| 95 |
+
def forward(self, x, incremental_state=None):
|
| 96 |
+
# x: T x B x C
|
| 97 |
+
if incremental_state is not None:
|
| 98 |
+
saved_state = self._get_input_buffer(incremental_state)
|
| 99 |
+
if 'prev_input' in saved_state:
|
| 100 |
+
prev_input = saved_state['prev_input']
|
| 101 |
+
x = torch.cat((prev_input, x), dim=0)
|
| 102 |
+
x = x[-self.kernel_size:]
|
| 103 |
+
saved_state['prev_input'] = x
|
| 104 |
+
self._set_input_buffer(incremental_state, saved_state)
|
| 105 |
+
|
| 106 |
+
x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
|
| 107 |
+
x = x * self.kernel_size ** -0.5
|
| 108 |
+
|
| 109 |
+
if incremental_state is not None:
|
| 110 |
+
x = x[-1:]
|
| 111 |
+
if self.act == 'gelu':
|
| 112 |
+
x = F.gelu(x)
|
| 113 |
+
if self.act == 'relu':
|
| 114 |
+
x = F.relu(x)
|
| 115 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
| 116 |
+
x = self.ffn_2(x)
|
| 117 |
+
return x
|
| 118 |
+
|
| 119 |
+
def _get_input_buffer(self, incremental_state):
|
| 120 |
+
return get_incremental_state(
|
| 121 |
+
self,
|
| 122 |
+
incremental_state,
|
| 123 |
+
'f',
|
| 124 |
+
) or {}
|
| 125 |
+
|
| 126 |
+
def _set_input_buffer(self, incremental_state, buffer):
|
| 127 |
+
set_incremental_state(
|
| 128 |
+
self,
|
| 129 |
+
incremental_state,
|
| 130 |
+
'f',
|
| 131 |
+
buffer,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
def clear_buffer(self, incremental_state):
|
| 135 |
+
if incremental_state is not None:
|
| 136 |
+
saved_state = self._get_input_buffer(incremental_state)
|
| 137 |
+
if 'prev_input' in saved_state:
|
| 138 |
+
del saved_state['prev_input']
|
| 139 |
+
self._set_input_buffer(incremental_state, saved_state)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class MultiheadAttention(nn.Module):
|
| 143 |
+
def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
|
| 144 |
+
add_bias_kv=False, add_zero_attn=False, self_attention=False,
|
| 145 |
+
encoder_decoder_attention=False):
|
| 146 |
+
super().__init__()
|
| 147 |
+
self.embed_dim = embed_dim
|
| 148 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
| 149 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
| 150 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
| 151 |
+
|
| 152 |
+
self.num_heads = num_heads
|
| 153 |
+
self.dropout = dropout
|
| 154 |
+
self.head_dim = embed_dim // num_heads
|
| 155 |
+
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
| 156 |
+
self.scaling = self.head_dim ** -0.5
|
| 157 |
+
|
| 158 |
+
self.self_attention = self_attention
|
| 159 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
| 160 |
+
|
| 161 |
+
assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
|
| 162 |
+
'value to be of the same size'
|
| 163 |
+
|
| 164 |
+
if self.qkv_same_dim:
|
| 165 |
+
self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
|
| 166 |
+
else:
|
| 167 |
+
self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
|
| 168 |
+
self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
|
| 169 |
+
self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
|
| 170 |
+
|
| 171 |
+
if bias:
|
| 172 |
+
self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
|
| 173 |
+
else:
|
| 174 |
+
self.register_parameter('in_proj_bias', None)
|
| 175 |
+
|
| 176 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 177 |
+
|
| 178 |
+
if add_bias_kv:
|
| 179 |
+
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
| 180 |
+
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
| 181 |
+
else:
|
| 182 |
+
self.bias_k = self.bias_v = None
|
| 183 |
+
|
| 184 |
+
self.add_zero_attn = add_zero_attn
|
| 185 |
+
|
| 186 |
+
self.reset_parameters()
|
| 187 |
+
|
| 188 |
+
self.enable_torch_version = False
|
| 189 |
+
if hasattr(F, "multi_head_attention_forward"):
|
| 190 |
+
self.enable_torch_version = True
|
| 191 |
+
else:
|
| 192 |
+
self.enable_torch_version = False
|
| 193 |
+
self.last_attn_probs = None
|
| 194 |
+
|
| 195 |
+
def reset_parameters(self):
|
| 196 |
+
if self.qkv_same_dim:
|
| 197 |
+
nn.init.xavier_uniform_(self.in_proj_weight)
|
| 198 |
+
else:
|
| 199 |
+
nn.init.xavier_uniform_(self.k_proj_weight)
|
| 200 |
+
nn.init.xavier_uniform_(self.v_proj_weight)
|
| 201 |
+
nn.init.xavier_uniform_(self.q_proj_weight)
|
| 202 |
+
|
| 203 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
| 204 |
+
if self.in_proj_bias is not None:
|
| 205 |
+
nn.init.constant_(self.in_proj_bias, 0.)
|
| 206 |
+
nn.init.constant_(self.out_proj.bias, 0.)
|
| 207 |
+
if self.bias_k is not None:
|
| 208 |
+
nn.init.xavier_normal_(self.bias_k)
|
| 209 |
+
if self.bias_v is not None:
|
| 210 |
+
nn.init.xavier_normal_(self.bias_v)
|
| 211 |
+
|
| 212 |
+
def forward(
|
| 213 |
+
self,
|
| 214 |
+
query, key, value,
|
| 215 |
+
key_padding_mask=None,
|
| 216 |
+
incremental_state=None,
|
| 217 |
+
need_weights=True,
|
| 218 |
+
static_kv=False,
|
| 219 |
+
attn_mask=None,
|
| 220 |
+
before_softmax=False,
|
| 221 |
+
need_head_weights=False,
|
| 222 |
+
enc_dec_attn_constraint_mask=None,
|
| 223 |
+
reset_attn_weight=None
|
| 224 |
+
):
|
| 225 |
+
"""Input shape: Time x Batch x Channel
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
| 229 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
| 230 |
+
padding elements are indicated by 1s.
|
| 231 |
+
need_weights (bool, optional): return the attention weights,
|
| 232 |
+
averaged over heads (default: False).
|
| 233 |
+
attn_mask (ByteTensor, optional): typically used to
|
| 234 |
+
implement causal attention, where the mask prevents the
|
| 235 |
+
attention from looking forward in time (default: None).
|
| 236 |
+
before_softmax (bool, optional): return the raw attention
|
| 237 |
+
weights and values before the attention softmax.
|
| 238 |
+
need_head_weights (bool, optional): return the attention
|
| 239 |
+
weights for each head. Implies *need_weights*. Default:
|
| 240 |
+
return the average attention weights over all heads.
|
| 241 |
+
"""
|
| 242 |
+
if need_head_weights:
|
| 243 |
+
need_weights = True
|
| 244 |
+
|
| 245 |
+
tgt_len, bsz, embed_dim = query.size()
|
| 246 |
+
assert embed_dim == self.embed_dim
|
| 247 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
| 248 |
+
if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
|
| 249 |
+
if self.qkv_same_dim:
|
| 250 |
+
return F.multi_head_attention_forward(query, key, value,
|
| 251 |
+
self.embed_dim, self.num_heads,
|
| 252 |
+
self.in_proj_weight,
|
| 253 |
+
self.in_proj_bias, self.bias_k, self.bias_v,
|
| 254 |
+
self.add_zero_attn, self.dropout,
|
| 255 |
+
self.out_proj.weight, self.out_proj.bias,
|
| 256 |
+
self.training, key_padding_mask, need_weights,
|
| 257 |
+
attn_mask)
|
| 258 |
+
else:
|
| 259 |
+
return F.multi_head_attention_forward(query, key, value,
|
| 260 |
+
self.embed_dim, self.num_heads,
|
| 261 |
+
torch.empty([0]),
|
| 262 |
+
self.in_proj_bias, self.bias_k, self.bias_v,
|
| 263 |
+
self.add_zero_attn, self.dropout,
|
| 264 |
+
self.out_proj.weight, self.out_proj.bias,
|
| 265 |
+
self.training, key_padding_mask, need_weights,
|
| 266 |
+
attn_mask, use_separate_proj_weight=True,
|
| 267 |
+
q_proj_weight=self.q_proj_weight,
|
| 268 |
+
k_proj_weight=self.k_proj_weight,
|
| 269 |
+
v_proj_weight=self.v_proj_weight)
|
| 270 |
+
|
| 271 |
+
if incremental_state is not None:
|
| 272 |
+
saved_state = self._get_input_buffer(incremental_state)
|
| 273 |
+
if 'prev_key' in saved_state:
|
| 274 |
+
# previous time steps are cached - no need to recompute
|
| 275 |
+
# key and value if they are static
|
| 276 |
+
if static_kv:
|
| 277 |
+
assert self.encoder_decoder_attention and not self.self_attention
|
| 278 |
+
key = value = None
|
| 279 |
+
else:
|
| 280 |
+
saved_state = None
|
| 281 |
+
|
| 282 |
+
if self.self_attention:
|
| 283 |
+
# self-attention
|
| 284 |
+
q, k, v = self.in_proj_qkv(query)
|
| 285 |
+
elif self.encoder_decoder_attention:
|
| 286 |
+
# encoder-decoder attention
|
| 287 |
+
q = self.in_proj_q(query)
|
| 288 |
+
if key is None:
|
| 289 |
+
assert value is None
|
| 290 |
+
k = v = None
|
| 291 |
+
else:
|
| 292 |
+
k = self.in_proj_k(key)
|
| 293 |
+
v = self.in_proj_v(key)
|
| 294 |
+
|
| 295 |
+
else:
|
| 296 |
+
q = self.in_proj_q(query)
|
| 297 |
+
k = self.in_proj_k(key)
|
| 298 |
+
v = self.in_proj_v(value)
|
| 299 |
+
q *= self.scaling
|
| 300 |
+
|
| 301 |
+
if self.bias_k is not None:
|
| 302 |
+
assert self.bias_v is not None
|
| 303 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
| 304 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
| 305 |
+
if attn_mask is not None:
|
| 306 |
+
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
| 307 |
+
if key_padding_mask is not None:
|
| 308 |
+
key_padding_mask = torch.cat(
|
| 309 |
+
[key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
|
| 310 |
+
|
| 311 |
+
q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
| 312 |
+
if k is not None:
|
| 313 |
+
k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
| 314 |
+
if v is not None:
|
| 315 |
+
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
| 316 |
+
|
| 317 |
+
if saved_state is not None:
|
| 318 |
+
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
| 319 |
+
if 'prev_key' in saved_state:
|
| 320 |
+
prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
|
| 321 |
+
if static_kv:
|
| 322 |
+
k = prev_key
|
| 323 |
+
else:
|
| 324 |
+
k = torch.cat((prev_key, k), dim=1)
|
| 325 |
+
if 'prev_value' in saved_state:
|
| 326 |
+
prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
|
| 327 |
+
if static_kv:
|
| 328 |
+
v = prev_value
|
| 329 |
+
else:
|
| 330 |
+
v = torch.cat((prev_value, v), dim=1)
|
| 331 |
+
if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
|
| 332 |
+
prev_key_padding_mask = saved_state['prev_key_padding_mask']
|
| 333 |
+
if static_kv:
|
| 334 |
+
key_padding_mask = prev_key_padding_mask
|
| 335 |
+
else:
|
| 336 |
+
key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
|
| 337 |
+
|
| 338 |
+
saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
| 339 |
+
saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
| 340 |
+
saved_state['prev_key_padding_mask'] = key_padding_mask
|
| 341 |
+
|
| 342 |
+
self._set_input_buffer(incremental_state, saved_state)
|
| 343 |
+
|
| 344 |
+
src_len = k.size(1)
|
| 345 |
+
|
| 346 |
+
# This is part of a workaround to get around fork/join parallelism
|
| 347 |
+
# not supporting Optional types.
|
| 348 |
+
if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
|
| 349 |
+
key_padding_mask = None
|
| 350 |
+
|
| 351 |
+
if key_padding_mask is not None:
|
| 352 |
+
assert key_padding_mask.size(0) == bsz
|
| 353 |
+
assert key_padding_mask.size(1) == src_len
|
| 354 |
+
|
| 355 |
+
if self.add_zero_attn:
|
| 356 |
+
src_len += 1
|
| 357 |
+
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
| 358 |
+
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
| 359 |
+
if attn_mask is not None:
|
| 360 |
+
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
| 361 |
+
if key_padding_mask is not None:
|
| 362 |
+
key_padding_mask = torch.cat(
|
| 363 |
+
[key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
|
| 364 |
+
|
| 365 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
| 366 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
| 367 |
+
|
| 368 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
| 369 |
+
|
| 370 |
+
if attn_mask is not None:
|
| 371 |
+
if len(attn_mask.shape) == 2:
|
| 372 |
+
attn_mask = attn_mask.unsqueeze(0)
|
| 373 |
+
elif len(attn_mask.shape) == 3:
|
| 374 |
+
attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
|
| 375 |
+
bsz * self.num_heads, tgt_len, src_len)
|
| 376 |
+
attn_weights = attn_weights + attn_mask
|
| 377 |
+
|
| 378 |
+
if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
|
| 379 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
| 380 |
+
attn_weights = attn_weights.masked_fill(
|
| 381 |
+
enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
|
| 382 |
+
-1e8,
|
| 383 |
+
)
|
| 384 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
| 385 |
+
|
| 386 |
+
if key_padding_mask is not None:
|
| 387 |
+
# don't attend to padding symbols
|
| 388 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
| 389 |
+
attn_weights = attn_weights.masked_fill(
|
| 390 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
| 391 |
+
-1e8,
|
| 392 |
+
)
|
| 393 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
| 394 |
+
|
| 395 |
+
attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
| 396 |
+
|
| 397 |
+
if before_softmax:
|
| 398 |
+
return attn_weights, v
|
| 399 |
+
|
| 400 |
+
attn_weights_float = softmax(attn_weights, dim=-1)
|
| 401 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
| 402 |
+
attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
|
| 403 |
+
|
| 404 |
+
if reset_attn_weight is not None:
|
| 405 |
+
if reset_attn_weight:
|
| 406 |
+
self.last_attn_probs = attn_probs.detach()
|
| 407 |
+
else:
|
| 408 |
+
assert self.last_attn_probs is not None
|
| 409 |
+
attn_probs = self.last_attn_probs
|
| 410 |
+
attn = torch.bmm(attn_probs, v)
|
| 411 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
| 412 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
| 413 |
+
attn = self.out_proj(attn)
|
| 414 |
+
|
| 415 |
+
if need_weights:
|
| 416 |
+
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
|
| 417 |
+
if not need_head_weights:
|
| 418 |
+
# average attention weights over heads
|
| 419 |
+
attn_weights = attn_weights.mean(dim=0)
|
| 420 |
+
else:
|
| 421 |
+
attn_weights = None
|
| 422 |
+
|
| 423 |
+
return attn, (attn_weights, attn_logits)
|
| 424 |
+
|
| 425 |
+
def in_proj_qkv(self, query):
|
| 426 |
+
return self._in_proj(query).chunk(3, dim=-1)
|
| 427 |
+
|
| 428 |
+
def in_proj_q(self, query):
|
| 429 |
+
if self.qkv_same_dim:
|
| 430 |
+
return self._in_proj(query, end=self.embed_dim)
|
| 431 |
+
else:
|
| 432 |
+
bias = self.in_proj_bias
|
| 433 |
+
if bias is not None:
|
| 434 |
+
bias = bias[:self.embed_dim]
|
| 435 |
+
return F.linear(query, self.q_proj_weight, bias)
|
| 436 |
+
|
| 437 |
+
def in_proj_k(self, key):
|
| 438 |
+
if self.qkv_same_dim:
|
| 439 |
+
return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
|
| 440 |
+
else:
|
| 441 |
+
weight = self.k_proj_weight
|
| 442 |
+
bias = self.in_proj_bias
|
| 443 |
+
if bias is not None:
|
| 444 |
+
bias = bias[self.embed_dim:2 * self.embed_dim]
|
| 445 |
+
return F.linear(key, weight, bias)
|
| 446 |
+
|
| 447 |
+
def in_proj_v(self, value):
|
| 448 |
+
if self.qkv_same_dim:
|
| 449 |
+
return self._in_proj(value, start=2 * self.embed_dim)
|
| 450 |
+
else:
|
| 451 |
+
weight = self.v_proj_weight
|
| 452 |
+
bias = self.in_proj_bias
|
| 453 |
+
if bias is not None:
|
| 454 |
+
bias = bias[2 * self.embed_dim:]
|
| 455 |
+
return F.linear(value, weight, bias)
|
| 456 |
+
|
| 457 |
+
def _in_proj(self, input, start=0, end=None):
|
| 458 |
+
weight = self.in_proj_weight
|
| 459 |
+
bias = self.in_proj_bias
|
| 460 |
+
weight = weight[start:end, :]
|
| 461 |
+
if bias is not None:
|
| 462 |
+
bias = bias[start:end]
|
| 463 |
+
return F.linear(input, weight, bias)
|
| 464 |
+
|
| 465 |
+
def _get_input_buffer(self, incremental_state):
|
| 466 |
+
return get_incremental_state(
|
| 467 |
+
self,
|
| 468 |
+
incremental_state,
|
| 469 |
+
'attn_state',
|
| 470 |
+
) or {}
|
| 471 |
+
|
| 472 |
+
def _set_input_buffer(self, incremental_state, buffer):
|
| 473 |
+
set_incremental_state(
|
| 474 |
+
self,
|
| 475 |
+
incremental_state,
|
| 476 |
+
'attn_state',
|
| 477 |
+
buffer,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
|
| 481 |
+
return attn_weights
|
| 482 |
+
|
| 483 |
+
def clear_buffer(self, incremental_state=None):
|
| 484 |
+
if incremental_state is not None:
|
| 485 |
+
saved_state = self._get_input_buffer(incremental_state)
|
| 486 |
+
if 'prev_key' in saved_state:
|
| 487 |
+
del saved_state['prev_key']
|
| 488 |
+
if 'prev_value' in saved_state:
|
| 489 |
+
del saved_state['prev_value']
|
| 490 |
+
self._set_input_buffer(incremental_state, saved_state)
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
class EncSALayer(nn.Module):
|
| 494 |
+
def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
|
| 495 |
+
relu_dropout=0.1, kernel_size=9, padding='SAME', act='gelu'):
|
| 496 |
+
super().__init__()
|
| 497 |
+
self.c = c
|
| 498 |
+
self.dropout = dropout
|
| 499 |
+
self.num_heads = num_heads
|
| 500 |
+
if num_heads > 0:
|
| 501 |
+
self.layer_norm1 = LayerNorm(c)
|
| 502 |
+
self.self_attn = MultiheadAttention(
|
| 503 |
+
self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False)
|
| 504 |
+
self.layer_norm2 = LayerNorm(c)
|
| 505 |
+
self.ffn = TransformerFFNLayer(
|
| 506 |
+
c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
|
| 507 |
+
|
| 508 |
+
def forward(self, x, encoder_padding_mask=None, **kwargs):
|
| 509 |
+
layer_norm_training = kwargs.get('layer_norm_training', None)
|
| 510 |
+
if layer_norm_training is not None:
|
| 511 |
+
self.layer_norm1.training = layer_norm_training
|
| 512 |
+
self.layer_norm2.training = layer_norm_training
|
| 513 |
+
if self.num_heads > 0:
|
| 514 |
+
residual = x
|
| 515 |
+
x = self.layer_norm1(x)
|
| 516 |
+
x, _, = self.self_attn(
|
| 517 |
+
query=x,
|
| 518 |
+
key=x,
|
| 519 |
+
value=x,
|
| 520 |
+
key_padding_mask=encoder_padding_mask
|
| 521 |
+
)
|
| 522 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
| 523 |
+
x = residual + x
|
| 524 |
+
x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
|
| 525 |
+
|
| 526 |
+
residual = x
|
| 527 |
+
x = self.layer_norm2(x)
|
| 528 |
+
x = self.ffn(x)
|
| 529 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
| 530 |
+
x = residual + x
|
| 531 |
+
x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
|
| 532 |
+
return x
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
class DecSALayer(nn.Module):
|
| 536 |
+
def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
|
| 537 |
+
kernel_size=9, act='gelu'):
|
| 538 |
+
super().__init__()
|
| 539 |
+
self.c = c
|
| 540 |
+
self.dropout = dropout
|
| 541 |
+
self.layer_norm1 = LayerNorm(c)
|
| 542 |
+
self.self_attn = MultiheadAttention(
|
| 543 |
+
c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
|
| 544 |
+
)
|
| 545 |
+
self.layer_norm2 = LayerNorm(c)
|
| 546 |
+
self.encoder_attn = MultiheadAttention(
|
| 547 |
+
c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
|
| 548 |
+
)
|
| 549 |
+
self.layer_norm3 = LayerNorm(c)
|
| 550 |
+
self.ffn = TransformerFFNLayer(
|
| 551 |
+
c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
|
| 552 |
+
|
| 553 |
+
def forward(
|
| 554 |
+
self,
|
| 555 |
+
x,
|
| 556 |
+
encoder_out=None,
|
| 557 |
+
encoder_padding_mask=None,
|
| 558 |
+
incremental_state=None,
|
| 559 |
+
self_attn_mask=None,
|
| 560 |
+
self_attn_padding_mask=None,
|
| 561 |
+
attn_out=None,
|
| 562 |
+
reset_attn_weight=None,
|
| 563 |
+
**kwargs,
|
| 564 |
+
):
|
| 565 |
+
layer_norm_training = kwargs.get('layer_norm_training', None)
|
| 566 |
+
if layer_norm_training is not None:
|
| 567 |
+
self.layer_norm1.training = layer_norm_training
|
| 568 |
+
self.layer_norm2.training = layer_norm_training
|
| 569 |
+
self.layer_norm3.training = layer_norm_training
|
| 570 |
+
residual = x
|
| 571 |
+
x = self.layer_norm1(x)
|
| 572 |
+
x, _ = self.self_attn(
|
| 573 |
+
query=x,
|
| 574 |
+
key=x,
|
| 575 |
+
value=x,
|
| 576 |
+
key_padding_mask=self_attn_padding_mask,
|
| 577 |
+
incremental_state=incremental_state,
|
| 578 |
+
attn_mask=self_attn_mask
|
| 579 |
+
)
|
| 580 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
| 581 |
+
x = residual + x
|
| 582 |
+
|
| 583 |
+
attn_logits = None
|
| 584 |
+
if encoder_out is not None or attn_out is not None:
|
| 585 |
+
residual = x
|
| 586 |
+
x = self.layer_norm2(x)
|
| 587 |
+
if encoder_out is not None:
|
| 588 |
+
x, attn = self.encoder_attn(
|
| 589 |
+
query=x,
|
| 590 |
+
key=encoder_out,
|
| 591 |
+
value=encoder_out,
|
| 592 |
+
key_padding_mask=encoder_padding_mask,
|
| 593 |
+
incremental_state=incremental_state,
|
| 594 |
+
static_kv=True,
|
| 595 |
+
enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
|
| 596 |
+
'enc_dec_attn_constraint_mask'),
|
| 597 |
+
reset_attn_weight=reset_attn_weight
|
| 598 |
+
)
|
| 599 |
+
attn_logits = attn[1]
|
| 600 |
+
elif attn_out is not None:
|
| 601 |
+
x = self.encoder_attn.in_proj_v(attn_out)
|
| 602 |
+
if encoder_out is not None or attn_out is not None:
|
| 603 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
| 604 |
+
x = residual + x
|
| 605 |
+
|
| 606 |
+
residual = x
|
| 607 |
+
x = self.layer_norm3(x)
|
| 608 |
+
x = self.ffn(x, incremental_state=incremental_state)
|
| 609 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
| 610 |
+
x = residual + x
|
| 611 |
+
return x, attn_logits
|
| 612 |
+
|
| 613 |
+
def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
|
| 614 |
+
self.encoder_attn.clear_buffer(incremental_state)
|
| 615 |
+
self.ffn.clear_buffer(incremental_state)
|
| 616 |
+
|
| 617 |
+
def set_buffer(self, name, tensor, incremental_state):
|
| 618 |
+
return set_incremental_state(self, incremental_state, name, tensor)
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
class TransformerEncoderLayer(nn.Module):
|
| 622 |
+
def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2):
|
| 623 |
+
super().__init__()
|
| 624 |
+
self.hidden_size = hidden_size
|
| 625 |
+
self.dropout = dropout
|
| 626 |
+
self.num_heads = num_heads
|
| 627 |
+
self.op = EncSALayer(
|
| 628 |
+
hidden_size, num_heads, dropout=dropout,
|
| 629 |
+
attention_dropout=0.0, relu_dropout=dropout,
|
| 630 |
+
kernel_size=kernel_size)
|
| 631 |
+
|
| 632 |
+
def forward(self, x, **kwargs):
|
| 633 |
+
return self.op(x, **kwargs)
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
class TransformerDecoderLayer(nn.Module):
|
| 637 |
+
def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2):
|
| 638 |
+
super().__init__()
|
| 639 |
+
self.hidden_size = hidden_size
|
| 640 |
+
self.dropout = dropout
|
| 641 |
+
self.num_heads = num_heads
|
| 642 |
+
self.op = DecSALayer(
|
| 643 |
+
hidden_size, num_heads, dropout=dropout,
|
| 644 |
+
attention_dropout=0.0, relu_dropout=dropout,
|
| 645 |
+
kernel_size=kernel_size)
|
| 646 |
+
|
| 647 |
+
def forward(self, x, **kwargs):
|
| 648 |
+
return self.op(x, **kwargs)
|
| 649 |
+
|
| 650 |
+
def clear_buffer(self, *args):
|
| 651 |
+
return self.op.clear_buffer(*args)
|
| 652 |
+
|
| 653 |
+
def set_buffer(self, *args):
|
| 654 |
+
return self.op.set_buffer(*args)
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
class FFTBlocks(nn.Module):
|
| 658 |
+
def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=0.0,
|
| 659 |
+
num_heads=2, use_pos_embed=True, use_last_norm=True,
|
| 660 |
+
use_pos_embed_alpha=True):
|
| 661 |
+
super().__init__()
|
| 662 |
+
self.num_layers = num_layers
|
| 663 |
+
embed_dim = self.hidden_size = hidden_size
|
| 664 |
+
self.dropout = dropout
|
| 665 |
+
self.use_pos_embed = use_pos_embed
|
| 666 |
+
self.use_last_norm = use_last_norm
|
| 667 |
+
if use_pos_embed:
|
| 668 |
+
self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
|
| 669 |
+
self.padding_idx = 0
|
| 670 |
+
self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
|
| 671 |
+
self.embed_positions = SinusoidalPositionalEmbedding(
|
| 672 |
+
embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
+
self.layers = nn.ModuleList([])
|
| 676 |
+
self.layers.extend([
|
| 677 |
+
TransformerEncoderLayer(self.hidden_size, self.dropout,
|
| 678 |
+
kernel_size=ffn_kernel_size, num_heads=num_heads)
|
| 679 |
+
for _ in range(self.num_layers)
|
| 680 |
+
])
|
| 681 |
+
if self.use_last_norm:
|
| 682 |
+
self.layer_norm = nn.LayerNorm(embed_dim)
|
| 683 |
+
else:
|
| 684 |
+
self.layer_norm = None
|
| 685 |
+
|
| 686 |
+
def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
|
| 687 |
+
"""
|
| 688 |
+
:param x: [B, T, C]
|
| 689 |
+
:param padding_mask: [B, T]
|
| 690 |
+
:return: [B, T, C] or [L, B, T, C]
|
| 691 |
+
"""
|
| 692 |
+
padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
|
| 693 |
+
nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1]
|
| 694 |
+
if self.use_pos_embed:
|
| 695 |
+
positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
|
| 696 |
+
x = x + positions
|
| 697 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
| 698 |
+
# B x T x C -> T x B x C
|
| 699 |
+
x = x.transpose(0, 1) * nonpadding_mask_TB
|
| 700 |
+
hiddens = []
|
| 701 |
+
for layer in self.layers:
|
| 702 |
+
x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
|
| 703 |
+
hiddens.append(x)
|
| 704 |
+
if self.use_last_norm:
|
| 705 |
+
x = self.layer_norm(x) * nonpadding_mask_TB
|
| 706 |
+
if return_hiddens:
|
| 707 |
+
x = torch.stack(hiddens, 0) # [L, T, B, C]
|
| 708 |
+
x = x.transpose(1, 2) # [L, B, T, C]
|
| 709 |
+
else:
|
| 710 |
+
x = x.transpose(0, 1) # [B, T, C]
|
| 711 |
+
return x
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
class FastSpeechEncoder(FFTBlocks):
|
| 715 |
+
def __init__(self, dict_size, hidden_size=256, num_layers=4, kernel_size=9, num_heads=2,
|
| 716 |
+
dropout=0.0):
|
| 717 |
+
super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
|
| 718 |
+
use_pos_embed=False, dropout=dropout) # use_pos_embed_alpha for compatibility
|
| 719 |
+
self.embed_tokens = Embedding(dict_size, hidden_size, 0)
|
| 720 |
+
self.embed_scale = math.sqrt(hidden_size)
|
| 721 |
+
self.padding_idx = 0
|
| 722 |
+
self.embed_positions = SinusoidalPositionalEmbedding(
|
| 723 |
+
hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
|
| 724 |
+
)
|
| 725 |
+
|
| 726 |
+
def forward(self, txt_tokens, attn_mask=None):
|
| 727 |
+
"""
|
| 728 |
+
|
| 729 |
+
:param txt_tokens: [B, T]
|
| 730 |
+
:return: {
|
| 731 |
+
'encoder_out': [B x T x C]
|
| 732 |
+
}
|
| 733 |
+
"""
|
| 734 |
+
encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
|
| 735 |
+
x = self.forward_embedding(txt_tokens) # [B, T, H]
|
| 736 |
+
if self.num_layers > 0:
|
| 737 |
+
x = super(FastSpeechEncoder, self).forward(x, encoder_padding_mask, attn_mask=attn_mask)
|
| 738 |
+
return x
|
| 739 |
+
|
| 740 |
+
def forward_embedding(self, txt_tokens):
|
| 741 |
+
# embed tokens and positions
|
| 742 |
+
x = self.embed_scale * self.embed_tokens(txt_tokens)
|
| 743 |
+
positions = self.embed_positions(txt_tokens)
|
| 744 |
+
x = x + positions
|
| 745 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
| 746 |
+
return x
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
class FastSpeechDecoder(FFTBlocks):
|
| 750 |
+
def __init__(self, hidden_size=256, num_layers=4, kernel_size=9, num_heads=2):
|
| 751 |
+
super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads)
|
preprocess/tools/note_transcription/modules/commons/wavenet.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from packaging import version
|
| 4 |
+
|
| 5 |
+
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
| 6 |
+
n_channels_int = n_channels[0]
|
| 7 |
+
in_act = input_a + input_b
|
| 8 |
+
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
| 9 |
+
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
| 10 |
+
acts = t_act * s_act
|
| 11 |
+
return acts
|
| 12 |
+
|
| 13 |
+
jit_fused_add_tanh_sigmoid_multiply = fused_add_tanh_sigmoid_multiply
|
| 14 |
+
|
| 15 |
+
def script_function():
|
| 16 |
+
if version.parse(torch.__version__) >= version.parse('2.0'):
|
| 17 |
+
global jit_fused_add_tanh_sigmoid_multiply
|
| 18 |
+
jit_fused_add_tanh_sigmoid_multiply = torch.jit.script(fused_add_tanh_sigmoid_multiply)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class WN(torch.nn.Module):
|
| 22 |
+
def __init__(self, hidden_size, kernel_size, dilation_rate, n_layers, c_cond=0,
|
| 23 |
+
p_dropout=0, share_cond_layers=False, is_BTC=False):
|
| 24 |
+
super(WN, self).__init__()
|
| 25 |
+
assert (kernel_size % 2 == 1)
|
| 26 |
+
assert (hidden_size % 2 == 0)
|
| 27 |
+
self.is_BTC = is_BTC
|
| 28 |
+
self.hidden_size = hidden_size
|
| 29 |
+
self.kernel_size = kernel_size
|
| 30 |
+
self.dilation_rate = dilation_rate
|
| 31 |
+
self.n_layers = n_layers
|
| 32 |
+
self.gin_channels = c_cond
|
| 33 |
+
self.p_dropout = p_dropout
|
| 34 |
+
self.share_cond_layers = share_cond_layers
|
| 35 |
+
|
| 36 |
+
self.in_layers = torch.nn.ModuleList()
|
| 37 |
+
self.res_skip_layers = torch.nn.ModuleList()
|
| 38 |
+
self.drop = nn.Dropout(p_dropout)
|
| 39 |
+
|
| 40 |
+
if c_cond != 0 and not share_cond_layers:
|
| 41 |
+
cond_layer = torch.nn.Conv1d(c_cond, 2 * hidden_size * n_layers, 1)
|
| 42 |
+
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
|
| 43 |
+
|
| 44 |
+
for i in range(n_layers):
|
| 45 |
+
dilation = dilation_rate ** i
|
| 46 |
+
padding = int((kernel_size * dilation - dilation) / 2)
|
| 47 |
+
in_layer = torch.nn.Conv1d(hidden_size, 2 * hidden_size, kernel_size,
|
| 48 |
+
dilation=dilation, padding=padding)
|
| 49 |
+
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
|
| 50 |
+
self.in_layers.append(in_layer)
|
| 51 |
+
|
| 52 |
+
# last one is not necessary
|
| 53 |
+
if i < n_layers - 1:
|
| 54 |
+
res_skip_channels = 2 * hidden_size
|
| 55 |
+
else:
|
| 56 |
+
res_skip_channels = hidden_size
|
| 57 |
+
|
| 58 |
+
res_skip_layer = torch.nn.Conv1d(hidden_size, res_skip_channels, 1)
|
| 59 |
+
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
|
| 60 |
+
self.res_skip_layers.append(res_skip_layer)
|
| 61 |
+
|
| 62 |
+
script_function()
|
| 63 |
+
|
| 64 |
+
def forward(self, x, nonpadding=None, cond=None):
|
| 65 |
+
if self.is_BTC:
|
| 66 |
+
x = x.transpose(1, 2)
|
| 67 |
+
cond = cond.transpose(1, 2) if cond is not None else None
|
| 68 |
+
nonpadding = nonpadding.transpose(1, 2) if nonpadding is not None else None
|
| 69 |
+
if nonpadding is None:
|
| 70 |
+
nonpadding = 1
|
| 71 |
+
output = torch.zeros_like(x)
|
| 72 |
+
n_channels_tensor = torch.IntTensor([self.hidden_size])
|
| 73 |
+
|
| 74 |
+
if cond is not None and not self.share_cond_layers:
|
| 75 |
+
cond = self.cond_layer(cond)
|
| 76 |
+
|
| 77 |
+
for i in range(self.n_layers):
|
| 78 |
+
x_in = self.in_layers[i](x)
|
| 79 |
+
x_in = self.drop(x_in)
|
| 80 |
+
if cond is not None:
|
| 81 |
+
cond_offset = i * 2 * self.hidden_size
|
| 82 |
+
cond_l = cond[:, cond_offset:cond_offset + 2 * self.hidden_size, :]
|
| 83 |
+
else:
|
| 84 |
+
cond_l = torch.zeros_like(x_in)
|
| 85 |
+
|
| 86 |
+
if version.parse(torch.__version__) >= version.parse('2.0'):
|
| 87 |
+
acts = jit_fused_add_tanh_sigmoid_multiply(x_in, cond_l, n_channels_tensor)
|
| 88 |
+
else:
|
| 89 |
+
acts = fused_add_tanh_sigmoid_multiply(x_in, cond_l, n_channels_tensor)
|
| 90 |
+
|
| 91 |
+
res_skip_acts = self.res_skip_layers[i](acts)
|
| 92 |
+
if i < self.n_layers - 1:
|
| 93 |
+
x = (x + res_skip_acts[:, :self.hidden_size, :]) * nonpadding
|
| 94 |
+
output = output + res_skip_acts[:, self.hidden_size:, :]
|
| 95 |
+
else:
|
| 96 |
+
output = output + res_skip_acts
|
| 97 |
+
output = output * nonpadding
|
| 98 |
+
if self.is_BTC:
|
| 99 |
+
output = output.transpose(1, 2)
|
| 100 |
+
return output
|
| 101 |
+
|
| 102 |
+
def remove_weight_norm(self):
|
| 103 |
+
def remove_weight_norm(m):
|
| 104 |
+
try:
|
| 105 |
+
nn.utils.remove_weight_norm(m)
|
| 106 |
+
except ValueError: # this module didn't have weight norm
|
| 107 |
+
return
|
| 108 |
+
|
| 109 |
+
self.apply(remove_weight_norm)
|
preprocess/tools/note_transcription/modules/pe/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Pitch extractor modules for ROSVOT."""
|