Audio-to-Audio
audio
speech
voice-conversion
Project Beatrice committed on
Commit
f34836d
1 Parent(s): 5ddb63e

Add 2.0.0-rc.0 features

.gitignore CHANGED
@@ -1,3 +1,3 @@
1
- poetry.lock
2
  work/*
3
  __pycache__
 
1
+ *.lock
2
  work/*
3
  __pycache__
README.md CHANGED
@@ -22,15 +22,38 @@ Beatrice 2 is developed with the following goals.
22
* High naturalness and clarity of the converted voice
23
* A wide variety of target speakers
24
* About 50 ms of latency when converting with the official VST, measured with external recording equipment
25
- * A computational load light enough for RTF < 0.25 when run single-threaded on the developer's laptop (Intel Core i7-1165G7)
26
* A footprint of 30 MB or less in the minimal configuration
27
- * Runs in the VST and in [VC Client](https://github.com/w-okada/voice-changer)
28
* And more (secret)
29
 
30
  ## Release Notes
31
 
32
* **2024-10-20**: Released Beatrice Trainer 2.0.0-beta.2.
33
- * **Please update the [official VST](https://prj-beatrice.com) and [VC Client](https://github.com/w-okada/voice-changer) to the latest versions. Models generated with the new Trainer do not work with older versions of the official VST or VC Client.**
34
* Improved training stability by introducing [Scaled Weight Standardization](https://arxiv.org/abs/2101.08692).
35
* Fixed an issue where the loss computed on audio very close to silence became NaN, improving training stability.
36
* Changed how the periodic signal is generated, enabling high-quality converted audio in fewer training steps when no pretrained model is used.
@@ -53,7 +76,7 @@ Beatrice performs voice conversion using an existing trained model
53
However, a GPU is required to create new models efficiently.
54
 
55
Running the training script consumes about 9 GB of VRAM with the default settings.
56
- With a GeForce RTX 4090, training completes in about 30 minutes.
57
 
58
Even if you do not have a GPU at hand, you can train on Google Colab using the repository below.
59
 
@@ -73,14 +96,15 @@ cd beatrice-trainer
73
 
74
  ### 2. Environment Setup
75
 
76
- Install the dependent libraries using Poetry or a similar tool.
77
 
78
  ```sh
79
- poetry install
80
- poetry shell
81
  # Alternatively, you can use pip to install dependencies directly:
82
- # pip3 install -e .
83
  ```
 
84
 
85
If the installation succeeded, `python3 beatrice_trainer -h` prints a help message like the following.
86
 
@@ -153,8 +177,8 @@ tensorboard --logdir <output_dir>
153
  ### 5. After Training
154
 
155
When training completes successfully, a directory named `paraphernalia_(data_dir_name)_(step)` is created inside the output directory.
156
- Load this directory into the [official VST](https://prj-beatrice.com) or [VC Client](https://github.com/w-okada/voice-changer) to perform streaming (real-time) conversion.
157
- **If it cannot be loaded, your official VST or VC Client may be outdated; please update to the latest version.**
158
 
159
  ## Detailed Usage
160
 
@@ -183,11 +207,11 @@ python3 beatrice_trainer -d <your_training_data_dir> -o <output_dir> -r
183
* A directory containing all the files needed for streaming conversion.
184
* Intermediate versions may also be produced during training; it is safe to delete all but the step count you need.
185
* Outputs other than this directory are not used for streaming conversion and can be deleted if not needed.
186
- * `checkpoint_(data_dir_name)_(step)`
187
* A checkpoint for resuming training partway through.
188
- * Rename it to checkpoint_latest.pt and run the training script with the `-r` option to resume training from that step count.
189
- * `checkpoint_latest.pt`
190
- * A copy of the most recent checkpoint_(data_dir_name)_(step).
191
* `config.json`
192
* The config used for training.
193
  * `events.out.tfevents.*`
@@ -195,12 +219,12 @@ python3 beatrice_trainer -d <your_training_data_dir> -o <output_dir> -r
195
 
196
  ### Customize Paraphernalia
197
 
198
- You can change how models appear in the VST and VC Client by editing the `beatrice_paraphernalia_*.toml` file in the paraphernalia directory generated by the training script.
199
 
200
Do not change `model.version`; it indicates the format version of the generated model.
201
 
202
Each `description` may not be displayed in full if it is too long.
203
- Even if it displays correctly now, future changes to the VST or VC Client may break the display, so keep it comfortably within character and line limits.
204
 
205
The image set as `portrait` must be a square PNG.
206
 
@@ -232,16 +256,20 @@ python3 beatrice_trainer -d <your_training_data_dir> -o <output_dir> -r
232
* Used in the loss function implementation.
233
  * [UnivNet](https://arxiv.org/abs/2106.07889) ([Unofficial implementation by maum-ai](https://github.com/maum-ai/univnet), [BSD 3-Clause License](https://github.com/maum-ai/univnet/blob/master/LICENSE))
234
* Used in the DiscriminatorR implementation.
 
 
235
  * [NF-ResNets](https://arxiv.org/abs/2101.08692)
236
* Uses the Scaled Weight Standardization idea.
237
  * [Soft-VC](https://arxiv.org/abs/2111.02392)
238
* Used as the basic idea for the PhoneExtractor.
 
 
239
  * [Descript Audio Codec](https://arxiv.org/abs/2306.06546)
240
* Uses the multi-scale mel loss idea.
241
  * [StreamVC](https://arxiv.org/abs/2401.03078)
242
* Used as the basic idea for the voice conversion scheme.
243
  * [FIRNet](https://ast-astrec.nict.go.jp/release/preprints/preprint_icassp_2024_ohtani.pdf)
244
- * Uses the idea of applying FIR filters to the vocoder.
245
  * [EVA-GAN](https://arxiv.org/abs/2402.00892)
246
* Uses the idea of applying SiLU in the vocoder.
247
  * [Subramani et al., 2024](https://arxiv.org/abs/2309.14507)
 
22
* High naturalness and clarity of the converted voice
23
* A wide variety of target speakers
24
* About 50 ms of latency when converting with the official VST, measured with external recording equipment
25
+ * A computational load light enough for RTF < 0.2 (real-time factor: processing time divided by audio duration) when run single-threaded on the developer's laptop (Intel Core i7-1165G7)
26
* A footprint of 30 MB or less in the minimal configuration
27
+ * Runs in the VST and in [VCClient](https://github.com/w-okada/voice-changer)
28
* And more (secret)
29
 
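For reference, RTF can be measured with a generic sketch like the one below; `convert` is a placeholder for one streaming conversion pass, not an API from this repository.

```python
import time

def measure_rtf(convert, audio, sample_rate: int) -> float:
    """Return processing time divided by audio duration for one pass."""
    start = time.perf_counter()
    convert(audio)  # placeholder for a conversion step
    elapsed = time.perf_counter() - start
    return elapsed / (len(audio) / sample_rate)

# RTF < 0.2 means 1 s of audio is converted in under 200 ms.
```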
30
  ## Release Notes
31
 
32
+ * **2025-08-31**: Released Beatrice Trainer 2.0.0-rc.0.
33
+ * **Please update the [official VST](https://prj-beatrice.com), [VCClient](https://github.com/w-okada/voice-changer), and [beatrice-client](https://github.com/aq2r/beatrice-client) to the latest versions. Models generated with the new Trainer do not work with older versions of the official VST, VCClient, or beatrice-client.**
34
+ * Changed the RTF target from 0.25 to 0.2.
35
+ * Switched the package manager from Poetry to uv.
36
+ * Added VocalSet to the PitchEstimator training data.
37
+ * Raised the upper limit of the PitchEstimator output from around A5 to around F6.
38
+ * Changed the PitchEstimator so it no longer predicts voiced/unvoiced.
39
+ * Fixed a spot in the PitchEstimator architecture where an activation function was missing.
40
+ * Changed the PhoneExtractor architecture, adding self-attention and removing the GRU among other changes, which improved processing efficiency.
41
+ * Added a structure to the WaveGenerator architecture that injects speaker identity via cross-attention, improving speaker similarity.
42
+ * Improved the quality of generated audio by adding noise to the PhoneExtractor output during training.
43
+ * Added a vector quantization step on the PhoneExtractor output similar to [kNN-VC](https://arxiv.org/abs/2305.18975), improving speaker similarity (see the sketch after this list).
44
+ * Added a step that adds slight noise to the waveforms fed to the discriminator, improving training stability.
45
+ * Removed the GradientEqualizer, as no contribution to quality could be confirmed.
46
+ * Added formant shift to the data augmentation pipeline, improving speaker similarity.
47
+ * Fixed a half-frame misalignment in the aperiodicity loss computation.
48
+ * Set the aperiodicity loss to 0 where the volume is very low, improving training stability.
49
+ * Added a loudness loss, improving the quality of generated audio.
50
+ * Changed the learning-rate schedule from cosine to exponential, making it easier to extend training.
51
+ * Checkpoint files are now saved compressed.
52
+ * Added more items that can be set in the config file.
53
+ * Disabled numeric logging to TensorBoard by default, to avoid the misconception that quality can be judged from loss values and similar metrics.
54
+ * Tuned hyperparameters and made several other changes.
55
* **2024-10-20**: Released Beatrice Trainer 2.0.0-beta.2.
56
+ * **Please update the [official VST](https://prj-beatrice.com) and [VCClient](https://github.com/w-okada/voice-changer) to the latest versions. Models generated with the new Trainer do not work with older versions of the official VST or VCClient.**
57
* Improved training stability by introducing [Scaled Weight Standardization](https://arxiv.org/abs/2101.08692).
58
* Fixed an issue where the loss computed on audio very close to silence became NaN, improving training stability.
59
* Changed how the periodic signal is generated, enabling high-quality converted audio in fewer training steps when no pretrained model is used.
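The kNN-VC-style quantization is not spelled out in this commit; a minimal sketch of the general idea, assuming each output frame is replaced by the mean of its k nearest reference frames (the new config's `vq_topk` suggests k = 4, and `training_time_vq` hints the reference set can be the utterance itself or another source):

```python
import torch

def knn_vq(frames: torch.Tensor, reference: torch.Tensor, k: int = 4) -> torch.Tensor:
    """Replace each frame with the mean of its k nearest reference frames.

    frames:    [length, channels]    (e.g. PhoneExtractor outputs)
    reference: [n_entries, channels] (reference feature set)
    """
    # pairwise L2 distances: [length, n_entries]
    dists = torch.cdist(frames, reference)
    # indices of the k nearest entries per frame: [length, k]
    _, idx = dists.topk(k, dim=1, largest=False)
    # average the selected entries: [length, k, channels] -> [length, channels]
    return reference[idx].mean(dim=1)
```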
 
76
However, a GPU is required to create new models efficiently.
77
 
78
Running the training script consumes about 9 GB of VRAM with the default settings.
79
+ With a GeForce RTX 4090, training completes in about 40 minutes.
80
 
81
Even if you do not have a GPU at hand, you can train on Google Colab using the repository below.
82
 
 
96
 
97
  ### 2. Environment Setup
98
 
99
+ Install the dependent libraries using uv or a similar tool.
100
 
101
  ```sh
102
+ uv sync --extra cu128
103
+ . .venv/bin/activate
104
  # Alternatively, you can use pip to install dependencies directly:
105
+ # pip3 install -e .[cu128]
106
  ```
107
+ On Windows, run `.venv\Scripts\activate` instead of `. .venv/bin/activate`.
108
 
109
If the installation succeeded, `python3 beatrice_trainer -h` prints a help message like the following.
110
 
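Putting the steps above together end to end (the `cu128` extra assumes a CUDA 12.8 setup; swap it for your environment):

```sh
cd beatrice-trainer
uv sync --extra cu128
. .venv/bin/activate         # on Windows: .venv\Scripts\activate
python3 beatrice_trainer -h  # prints the usage help if the install succeeded
```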
 
177
  ### 5. After Training
178
 
179
When training completes successfully, a directory named `paraphernalia_(data_dir_name)_(step)` is created inside the output directory.
180
+ Load this directory into the [official VST](https://prj-beatrice.com), [VCClient](https://github.com/w-okada/voice-changer), or [beatrice-client](https://github.com/aq2r/beatrice-client) to perform streaming (real-time) conversion.
181
+ **If it cannot be loaded, your official VST, VCClient, or beatrice-client may be outdated; please update to the latest version.**
182
 
183
  ## Detailed Usage
184
 
 
207
* A directory containing all the files needed for streaming conversion.
208
* Intermediate versions may also be produced during training; it is safe to delete all but the step count you need.
209
* Outputs other than this directory are not used for streaming conversion and can be deleted if not needed.
210
+ * `checkpoint_(data_dir_name)_(step).pt.gz`
211
* A checkpoint for resuming training partway through.
212
+ * Rename it to checkpoint_latest.pt.gz and run the training script with the `-r` option to resume training from that step count (see the example after this list).
213
+ * `checkpoint_latest.pt.gz`
214
+ * A copy of the most recent checkpoint_(data_dir_name)_(step).pt.gz.
215
* `config.json`
216
* The config used for training.
217
  * `events.out.tfevents.*`
 
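For example, to resume from a specific step (the checkpoint file name here is hypothetical):

```sh
cp <output_dir>/checkpoint_my_data_00008000.pt.gz <output_dir>/checkpoint_latest.pt.gz
python3 beatrice_trainer -d <your_training_data_dir> -o <output_dir> -r
```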
219
 
220
  ### Customize Paraphernalia
221
 
222
+ You can change how models appear in the VST, VCClient, and beatrice-client by editing the `beatrice_paraphernalia_*.toml` file in the paraphernalia directory generated by the training script.
223
 
224
Do not change `model.version`; it indicates the format version of the generated model.
225
 
226
Each `description` may not be displayed in full if it is too long.
227
+ Even if it displays correctly now, future changes to the VST, VCClient, or beatrice-client may break the display, so keep it comfortably within character and line limits.
228
 
229
The image set as `portrait` must be a square PNG.
230
 
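The TOML schema itself is not shown in this commit; purely to illustrate the kind of edits described above, a hypothetical minimal layout (every field name other than `model.version`, `description`, and `portrait` is invented here):

```toml
# Hypothetical layout; not an official schema.
[model]
version = "..."                # format version written by the trainer; do not edit

[voice]                        # hypothetical per-speaker table
name = "My Voice"              # hypothetical field
description = "Short text."    # may be truncated if too long
portrait = "portrait.png"      # must be a square PNG
```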
 
256
* Used in the loss function implementation.
257
  * [UnivNet](https://arxiv.org/abs/2106.07889) ([Unofficial implementation by maum-ai](https://github.com/maum-ai/univnet), [BSD 3-Clause License](https://github.com/maum-ai/univnet/blob/master/LICENSE))
258
* Used in the DiscriminatorR implementation.
259
+ * [FragmentVC](https://arxiv.org/abs/2010.14150)
260
+ * Uses the idea of injecting voice characteristics via cross-attention, with features derived from an SSL model as queries.
261
  * [NF-ResNets](https://arxiv.org/abs/2101.08692)
262
* Uses the Scaled Weight Standardization idea.
263
  * [Soft-VC](https://arxiv.org/abs/2111.02392)
264
* Used as the basic idea for the PhoneExtractor.
265
+ * [kNN-VC](https://arxiv.org/abs/2305.18975)
266
+ * Used as a supplementary idea for the voice conversion scheme.
267
  * [Descript Audio Codec](https://arxiv.org/abs/2306.06546)
268
* Uses the multi-scale mel loss idea.
269
  * [StreamVC](https://arxiv.org/abs/2401.03078)
270
* Used as the basic idea for the voice conversion scheme.
271
  * [FIRNet](https://ast-astrec.nict.go.jp/release/preprints/preprint_icassp_2024_ohtani.pdf)
272
+ * Uses the idea of applying FIR filters to the vocoder.
273
  * [EVA-GAN](https://arxiv.org/abs/2402.00892)
274
* Uses the idea of applying SiLU in the vocoder.
275
  * [Subramani et al., 2024](https://arxiv.org/abs/2309.14507)
assets/README.md CHANGED
@@ -15,7 +15,7 @@
15
  ## Pretrained
16
 
17
Pretrained models for Beatrice.
18
- Trained on data from [ReazonSpeech](https://huggingface.co/datasets/reazon-research/reazonspeech), [DNS-Challenge](https://github.com/microsoft/DNS-Challenge), and [LibriTTS-R](https://www.openslr.org/141/).
19
 
20
  ## Test
21
 
 
15
  ## Pretrained
16
 
17
Pretrained models for Beatrice.
18
+ Trained on data from [ReazonSpeech](https://huggingface.co/datasets/reazon-research/reazonspeech), [VocalSet](https://zenodo.org/records/1193957), [DNS-Challenge](https://github.com/microsoft/DNS-Challenge), and [LibriTTS-R](https://www.openslr.org/141/).
19
 
20
  ## Test
21
 
assets/default_config.json CHANGED
@@ -1,33 +1,60 @@
1
  {
2
- "learning_rate_g": 2e-4,
3
- "learning_rate_d": 1e-4,
4
- "min_learning_rate_g": 1e-5,
5
- "min_learning_rate_d": 5e-6,
6
  "adam_betas": [
7
  0.8,
8
  0.99
9
  ],
10
  "adam_eps": 1e-6,
11
  "batch_size": 8,
12
- "grad_weight_mel": 1.0,
13
- "grad_weight_ap": 2.0,
14
- "grad_weight_adv": 3.0,
15
- "grad_weight_fm": 3.0,
 
16
  "grad_balancer_ema_decay": 0.995,
17
  "use_amp": true,
18
  "num_workers": 16,
19
  "n_steps": 10000,
20
- "warmup_steps": 2000,
 
 
21
  "in_sample_rate": 16000,
22
  "out_sample_rate": 24000,
23
  "wav_length": 96000,
24
  "segment_length": 100,
25
- "phone_extractor_file": "assets/pretrained/003b_checkpoint_03000000.pt",
26
- "pitch_estimator_file": "assets/pretrained/008_1_checkpoint_00300000.pt",
 
27
  "in_ir_wav_dir": "assets/ir",
28
  "in_noise_wav_dir": "assets/noise",
29
  "in_test_wav_dir": "assets/test",
30
- "pretrained_file": "assets/pretrained/079_checkpoint_libritts_r_200_02400000.pt",
 
31
  "hidden_channels": 256,
32
  "san": false,
33
  "compile_convnext": false,
 
1
  {
2
+ "learning_rate_g": 5e-5,
3
+ "learning_rate_d": 5e-5,
4
+ "learning_rate_decay": 0.999999,
 
5
  "adam_betas": [
6
  0.8,
7
  0.99
8
  ],
9
  "adam_eps": 1e-6,
10
  "batch_size": 8,
11
+ "grad_weight_loudness": 1.0,
12
+ "grad_weight_mel": 50.0,
13
+ "grad_weight_ap": 100.0,
14
+ "grad_weight_adv": 150.0,
15
+ "grad_weight_fm": 150.0,
16
  "grad_balancer_ema_decay": 0.995,
17
  "use_amp": true,
18
  "num_workers": 16,
19
  "n_steps": 10000,
20
+ "warmup_steps": 5000,
21
+ "evaluation_interval": 2000,
22
+ "save_interval": 2000,
23
  "in_sample_rate": 16000,
24
  "out_sample_rate": 24000,
25
  "wav_length": 96000,
26
  "segment_length": 100,
27
+ "phone_noise_ratio": 0.5,
28
+ "vq_topk": 4,
29
+ "training_time_vq": "none",
30
+ "floor_noise_level": 1e-3,
31
+ "record_metrics": false,
32
+ "augmentation_snr_candidates": [
33
+ 20.0,
34
+ 25.0,
35
+ 30.0,
36
+ 35.0,
37
+ 40.0,
38
+ 45.0
39
+ ],
40
+ "augmentation_formant_shift_probability": 0.5,
41
+ "augmentation_formant_shift_semitone_min": -3.0,
42
+ "augmentation_formant_shift_semitone_max": 3.0,
43
+ "augmentation_reverb_probability": 0.5,
44
+ "augmentation_lpf_probability": 0.2,
45
+ "augmentation_lpf_cutoff_freq_candidates": [
46
+ 2000.0,
47
+ 3000.0,
48
+ 4000.0,
49
+ 6000.0
50
+ ],
51
+ "phone_extractor_file": "assets/pretrained/122_checkpoint_03000000.pt",
52
+ "pitch_estimator_file": "assets/pretrained/104_3_checkpoint_00300000.pt",
53
  "in_ir_wav_dir": "assets/ir",
54
  "in_noise_wav_dir": "assets/noise",
55
  "in_test_wav_dir": "assets/test",
56
+ "pretrained_file": "assets/pretrained/151_checkpoint_libritts_r_200_02750000.pt.gz",
57
+ "pitch_bins": 448,
58
  "hidden_channels": 256,
59
  "san": false,
60
  "compile_convnext": false,
assets/pretrained/{008_1_checkpoint_00300000.pt → 104_3_checkpoint_00300000.pt} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:32174239b2fa3411544a8d6015f970fd5de65b7b512864f6980cbfe6f47043a6
3
- size 6907000
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:174e5411009e0e4f6ee8a8c97c4cd2f646791eae1b9aa2b425acb797e0353ef4
3
+ size 7061178
assets/pretrained/{003b_checkpoint_03000000.pt → 122_checkpoint_03000000.pt} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:48b250b90b482d7510e7f2c1148ccb186160a3f9a1b6289d3c53779cb217cf64
3
- size 26504680
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46e2d609825ace2158c83672cfc9cc1dcb3c2b7c8d294ee911fcb6840a592bae
3
+ size 14657692
assets/pretrained/{079_checkpoint_libritts_r_200_02400000.pt → 151_checkpoint_libritts_r_200_02750000.pt.gz} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3c2c87af05cb645f96fe6df651999f9b20bf66fa4e98af17c84211a742b62fe6
3
- size 186736305
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:14ecdb01e51cf22b80664973daa3dedeeb0bada48bbf5262e58950c818cdcb1a
3
+ size 153189983
beatrice_trainer/__main__.py CHANGED
@@ -4,6 +4,7 @@
4
  # %%
5
  import argparse
6
  import gc
 
7
  import json
8
  import math
9
  import os
@@ -17,7 +18,7 @@ from functools import partial
17
  from pathlib import Path
18
  from pprint import pprint
19
  from random import Random
20
- from typing import BinaryIO, Literal, Optional, Union
21
 
22
  import numpy as np
23
  import pyworld
@@ -40,7 +41,7 @@ if not hasattr(torch.amp, "GradScaler"):
40
 
41
 
42
# not the module version
43
- PARAPHERNALIA_VERSION = "2.0.0-beta.1"
44
 
45
 
46
  def is_notebook() -> bool:
@@ -59,35 +60,51 @@ def repo_root() -> Path:
59
# hyperparameters
60
# things that change per training run, like training data and the output directory, are not included here
61
  dict_default_hparams = {
62
- # train
63
- "learning_rate_g": 2e-4,
64
- "learning_rate_d": 1e-4,
65
- "min_learning_rate_g": 1e-5,
66
- "min_learning_rate_d": 5e-6,
67
  "adam_betas": [0.8, 0.99],
68
  "adam_eps": 1e-6,
69
  "batch_size": 8,
70
- "grad_weight_mel": 1.0, # grad_weight は比が同じなら同じ意味になるはず
71
- "grad_weight_ap": 2.0,
72
- "grad_weight_adv": 3.0,
73
- "grad_weight_fm": 3.0,
 
74
  "grad_balancer_ema_decay": 0.995,
75
  "use_amp": True,
76
  "num_workers": 16,
77
  "n_steps": 10000,
78
- "warmup_steps": 2000,
 
 
79
  "in_sample_rate": 16000, # 変更不可
80
  "out_sample_rate": 24000, # 変更不可
81
  "wav_length": 4 * 24000, # 4s
82
  "segment_length": 100, # 1s
 
83
  # data
84
- "phone_extractor_file": "assets/pretrained/003b_checkpoint_03000000.pt",
85
- "pitch_estimator_file": "assets/pretrained/008_1_checkpoint_00300000.pt",
86
  "in_ir_wav_dir": "assets/ir",
87
  "in_noise_wav_dir": "assets/noise",
88
  "in_test_wav_dir": "assets/test",
89
- "pretrained_file": "assets/pretrained/079_checkpoint_libritts_r_200_02400000.pt", # None も可
90
  # model
 
91
  "hidden_channels": 256, # ファインチューン時変更不可、変更した場合は推論側の対応必要
92
  "san": False, # ファインチューン時変更不可
93
  "compile_convnext": False,
@@ -118,8 +135,8 @@ if __name__ == "__main__":
118
 
119
 
120
  def prepare_training_configs_for_experiment() -> tuple[dict, Path, Path, bool, bool]:
121
- import ipynbname
122
- from IPython import get_ipython
123
 
124
  h = deepcopy(dict_default_hparams)
125
  in_wav_dataset_dir = repo_root() / "../../data/processed/libritts_r_200"
@@ -228,28 +245,38 @@ def dump_layer(layer: nn.Module, f: BinaryIO):
228
  elif isinstance(layer, (nn.Linear, nn.Conv1d, nn.LayerNorm)):
229
  dump(layer.weight)
230
  dump(layer.bias)
231
- elif isinstance(layer, nn.ConvTranspose1d):
232
- dump(layer.weight.transpose(0, 1))
233
- dump(layer.bias)
234
- elif isinstance(layer, nn.GRU):
235
- dump(layer.weight_ih_l0)
236
- dump(layer.bias_ih_l0)
237
- dump(layer.weight_hh_l0)
238
- dump(layer.bias_hh_l0)
239
- for i in range(1, 99999):
240
- if not hasattr(layer, f"weight_ih_l{i}"):
241
- break
242
- dump(getattr(layer, f"weight_ih_l{i}"))
243
- dump(getattr(layer, f"bias_ih_l{i}"))
244
- dump(getattr(layer, f"weight_hh_l{i}"))
245
- dump(getattr(layer, f"bias_hh_l{i}"))
 
 
246
  elif isinstance(layer, nn.Embedding):
247
  dump(layer.weight)
248
  elif isinstance(layer, nn.Parameter):
249
  dump(layer)
250
  elif isinstance(layer, nn.ModuleList):
251
- for l in layer:
252
- dump_layer(l, f)
253
  else:
254
  assert False, layer
255
 
@@ -368,6 +395,136 @@ class WSLinear(nn.Linear):
368
  self.gain.data.fill_(1.0)
369
 
370
 
371
  class ConvNeXtBlock(nn.Module):
372
  def __init__(
373
  self,
@@ -379,10 +536,39 @@ class ConvNeXtBlock(nn.Module):
379
  enable_scaling: bool = False,
380
  pre_scale: float = 1.0,
381
  post_scale: float = 1.0,
 
382
  ):
383
  super().__init__()
384
  self.use_weight_standardization = use_weight_standardization
385
  self.enable_scaling = enable_scaling
 
 
 
386
  self.dwconv = CausalConv1d(
387
  channels, channels, kernel_size=kernel_size, groups=channels
388
  )
@@ -407,7 +593,39 @@ class ConvNeXtBlock(nn.Module):
407
  self.register_buffer("post_scale", torch.tensor(post_scale))
408
  self.post_scale_weight = nn.Parameter(torch.ones(()))
409
 
410
- def forward(self, x: torch.Tensor) -> torch.Tensor:
 
411
  identity = x
412
  if self.enable_scaling:
413
  x = x * self.pre_scale
@@ -426,14 +644,31 @@ class ConvNeXtBlock(nn.Module):
426
  return x
427
 
428
  def merge_weights(self):
 
429
  if self.use_weight_standardization:
430
  self.dwconv.merge_weights()
431
  self.pwconv1.merge_weights()
432
  self.pwconv2.merge_weights()
433
  else:
434
- self.pwconv1.bias.data += (
435
- self.norm.bias.data[None, :] * self.pwconv1.weight.data
436
- ).sum(1)
437
  self.pwconv1.weight.data *= self.norm.weight.data[None, :]
438
  self.norm.bias.data[:] = 0.0
439
  self.norm.weight.data[:] = 1.0
@@ -458,6 +693,8 @@ class ConvNeXtBlock(nn.Module):
458
  if not hasattr(f, "write"):
459
  raise TypeError
460
 
 
 
461
  dump_layer(self.dwconv, f)
462
  dump_layer(self.pwconv1, f)
463
  dump_layer(self.pwconv2, f)
@@ -475,10 +712,16 @@ class ConvNeXtStack(nn.Module):
475
  kernel_size: int,
476
  use_weight_standardization: bool = False,
477
  enable_scaling: bool = False,
 
 
 
478
  ):
479
  super().__init__()
480
  assert delay * 2 + 1 <= embed_kernel_size
 
481
  self.use_weight_standardization = use_weight_standardization
 
 
482
  self.embed = CausalConv1d(in_channels, channels, embed_kernel_size, delay=delay)
483
  self.norm = nn.LayerNorm(channels)
484
  self.convnext = nn.ModuleList()
@@ -494,6 +737,12 @@ class ConvNeXtStack(nn.Module):
494
  enable_scaling=enable_scaling,
495
  pre_scale=pre_scale,
496
  post_scale=post_scale,
 
497
  )
498
  self.convnext.append(block)
499
  self.final_layer_norm = nn.LayerNorm(channels)
@@ -506,11 +755,25 @@ class ConvNeXtStack(nn.Module):
506
  self.norm = nn.Identity()
507
  self.final_layer_norm = nn.Identity()
508
 
509
- def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
510
  x = self.embed(x)
511
  x = self.norm(x.transpose(1, 2)).transpose(1, 2)
 
512
  for conv_block in self.convnext:
513
- x = conv_block(x)
 
 
514
  x = self.final_layer_norm(x.transpose(1, 2)).transpose(1, 2)
515
  return x
516
 
@@ -535,6 +798,23 @@ class ConvNeXtStack(nn.Module):
535
  if not self.use_weight_standardization:
536
  dump_layer(self.final_layer_norm, f)
537
 
538
 
539
  class FeatureExtractor(nn.Module):
540
  def __init__(self, hidden_channels: int):
@@ -588,64 +868,30 @@ class FeatureExtractor(nn.Module):
588
 
589
 
590
  class FeatureProjection(nn.Module):
591
- def __init__(self, in_channels: int, out_channels: int):
592
  super().__init__()
593
- self.norm = nn.LayerNorm(in_channels)
594
- self.projection = nn.Conv1d(in_channels, out_channels, 1)
595
  self.dropout = nn.Dropout(0.1)
596
 
597
  def forward(self, x: torch.Tensor) -> torch.Tensor:
598
  # [batch_size, channels, length]
599
  x = self.norm(x.transpose(1, 2)).transpose(1, 2)
600
- x = self.projection(x)
601
  x = self.dropout(x)
602
  return x
603
 
604
- def merge_weights(self):
605
- self.projection.bias.data += (
606
- (self.norm.bias.data[None, :, None] * self.projection.weight.data)
607
- .sum(1)
608
- .squeeze(1)
609
- )
610
- self.projection.weight.data *= self.norm.weight.data[None, :, None]
611
- self.norm.bias.data[:] = 0.0
612
- self.norm.weight.data[:] = 1.0
613
-
614
- def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
615
- if isinstance(f, (str, bytes, os.PathLike)):
616
- with open(f, "wb") as f:
617
- self.dump(f)
618
- return
619
- if not hasattr(f, "write"):
620
- raise TypeError
621
-
622
- dump_layer(self.projection, f)
623
-
624
 
625
  class PhoneExtractor(nn.Module):
626
  def __init__(
627
  self,
628
- phone_channels: int = 256,
629
- hidden_channels: int = 256,
630
- backbone_embed_kernel_size: int = 7,
631
  kernel_size: int = 17,
632
- n_blocks: int = 8,
633
  ):
634
  super().__init__()
635
  self.feature_extractor = FeatureExtractor(hidden_channels)
636
- self.feature_projection = FeatureProjection(hidden_channels, hidden_channels)
637
- self.n_speaker_encoder_layers = 3
638
- self.speaker_encoder = nn.GRU(
639
- hidden_channels,
640
- hidden_channels,
641
- self.n_speaker_encoder_layers,
642
- batch_first=True,
643
- )
644
- for i in range(self.n_speaker_encoder_layers):
645
- for input_char in "ih":
646
- self.speaker_encoder = weight_norm(
647
- self.speaker_encoder, f"weight_{input_char}h_l{i}"
648
- )
649
  self.backbone = ConvNeXtStack(
650
  in_channels=hidden_channels,
651
  channels=hidden_channels,
@@ -654,6 +900,7 @@ class PhoneExtractor(nn.Module):
654
  delay=0,
655
  embed_kernel_size=backbone_embed_kernel_size,
656
  kernel_size=kernel_size,
 
657
  )
658
  self.head = weight_norm(nn.Conv1d(hidden_channels, phone_channels, 1))
659
 
@@ -670,36 +917,14 @@ class PhoneExtractor(nn.Module):
670
  stats["feature_norm"] = x.detach().norm(dim=1).mean()
671
  # [batch_size, feature_extractor_hidden_channels, length] -> [batch_size, hidden_channels, length]
672
  x = self.feature_projection(x)
673
- # [batch_size, hidden_channels, length] -> [batch_size, length, hidden_channels]
674
- g, _ = self.speaker_encoder(x.transpose(1, 2))
675
- if self.training:
676
- batch_size, length, _ = g.size()
677
- shuffle_sizes_for_each_data = torch.randint(
678
- 0, 50, (batch_size,), device=g.device
679
- )
680
- max_indices = torch.arange(length, device=g.device)[None, :, None]
681
- min_indices = (
682
- max_indices - shuffle_sizes_for_each_data[:, None, None]
683
- ).clamp_(min=0)
684
- with torch.cuda.amp.autocast(False):
685
- indices = (
686
- torch.rand(g.size(), device=g.device)
687
- * (max_indices - min_indices + 1)
688
- ).long() + min_indices
689
- assert indices.min() >= 0, indices.min()
690
- assert indices.max() < length, (indices.max(), length)
691
- g = g.gather(1, indices)
692
-
693
- # [batch_size, length, hidden_channels] -> [batch_size, hidden_channels, length]
694
- g = g.transpose(1, 2).contiguous()
695
  # [batch_size, hidden_channels, length]
696
- x = self.backbone(x + g)
697
  # [batch_size, hidden_channels, length] -> [batch_size, phone_channels, length]
698
  phone = self.head(F.gelu(x, approximate="tanh"))
699
 
700
  results = [phone]
701
  if return_stats:
702
- stats["code_norm"] = phone.detach().norm(dim=1).mean().item()
703
  results.append(stats)
704
 
705
  if len(results) == 1:
@@ -719,15 +944,25 @@ class PhoneExtractor(nn.Module):
719
 
720
  def remove_weight_norm(self):
721
  self.feature_extractor.remove_weight_norm()
722
- for i in range(self.n_speaker_encoder_layers):
723
- for input_char in "ih":
724
- remove_weight_norm(self.speaker_encoder, f"weight_{input_char}h_l{i}")
725
  remove_weight_norm(self.head)
726
 
727
  def merge_weights(self):
728
- self.feature_projection.merge_weights()
729
  self.backbone.merge_weights()
730
 
731
  def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
732
  if isinstance(f, (str, bytes, os.PathLike)):
733
  with open(f, "wb") as f:
@@ -737,12 +972,187 @@ class PhoneExtractor(nn.Module):
737
  raise TypeError
738
 
739
  dump_layer(self.feature_extractor, f)
740
- dump_layer(self.feature_projection, f)
741
- dump_layer(self.speaker_encoder, f)
742
  dump_layer(self.backbone, f)
743
  dump_layer(self.head, f)
744
 
745
 
746
  # %% [markdown]
747
  # ## Pitch Estimator
748
 
@@ -790,7 +1200,6 @@ def extract_pitch_features(
790
  )
791
 
792
# autocorrelation
793
- # would like to try the LPC residual as well, time permitting
794
# originally planned to multiply this by 2.0 / corr_win_length,
795
# but that value is proportional to the squared amplitude, and no good way to
796
# standardize its variance for NN input came to mind, so it was dropped
@@ -836,17 +1245,17 @@ class PitchEstimator(nn.Module):
836
  self,
837
  input_instfreq_channels: int = 192,
838
  input_corr_channels: int = 256,
839
- pitch_channels: int = 384,
840
  channels: int = 192,
841
- intermediate_channels: int = 192 * 3,
842
- n_blocks: int = 6,
843
delay: int = 1,  # 10 ms; 22.5 ms combined with feature extraction
844
  embed_kernel_size: int = 3,
845
  kernel_size: int = 33,
846
- bins_per_octave: int = 96,
847
  ):
848
  super().__init__()
849
- self.bins_per_octave = bins_per_octave
850
 
851
  self.instfreq_embed_0 = nn.Conv1d(input_instfreq_channels, channels, 1)
852
  self.instfreq_embed_1 = nn.Conv1d(channels, channels, 1)
@@ -860,8 +1269,9 @@ class PitchEstimator(nn.Module):
860
  delay,
861
  embed_kernel_size,
862
  kernel_size,
 
863
  )
864
- self.head = nn.Conv1d(channels, pitch_channels, 1)
865
 
866
  def forward(self, wav: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
867
  # wav: [batch_size, 1, wav_length]
@@ -884,32 +1294,30 @@ class PitchEstimator(nn.Module):
884
  corr_diff = F.gelu(self.corr_embed_0(corr_diff), approximate="tanh")
885
  corr_diff = self.corr_embed_1(corr_diff)
886
  # [batch_size, channels, length]
887
- x = instfreq_features + corr_diff  # an activation function was forgotten here
888
  x = self.backbone(x)
889
- # [batch_size, pitch_channels, length]
890
  x = self.head(x)
891
  return x, energy
892
 
893
  def sample_pitch(
894
- self, pitch: torch.Tensor, band_width: int = 48, return_features: bool = False
895
  ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
896
- # pitch: [batch_size, pitch_channels, length]
897
# returned pitch values never include 0
898
- batch_size, pitch_channels, length = pitch.size()
899
  pitch = pitch.softmax(1)
900
  if return_features:
901
  unvoiced_proba = pitch[:, :1, :].clone()
902
  pitch[:, 0, :] = -100.0
903
  pitch = (
904
- pitch.transpose(1, 2)
905
- .contiguous()
906
- .view(batch_size * length, 1, pitch_channels)
907
  )
908
  band_pitch = F.conv1d(
909
  pitch,
910
  torch.ones((1, 1, 1), device=pitch.device).expand(1, 1, band_width),
911
  )
912
- # [batch_size * length, 1, pitch_channels - band_width + 1] -> Long[batch_size * length, 1]
913
  quantized_band_pitch = band_pitch.argmax(2)
914
  if return_features:
915
  # [batch_size * length, 1]
@@ -917,29 +1325,33 @@ class PitchEstimator(nn.Module):
917
  # [batch_size * length, 1]
918
  half_pitch_band_proba = band_pitch.gather(
919
  2,
920
- (quantized_band_pitch - self.bins_per_octave).clamp_(min=1)[:, :, None],
 
 
921
  )
922
- half_pitch_band_proba[quantized_band_pitch <= self.bins_per_octave] = 0.0
 
 
923
  half_pitch_proba = (half_pitch_band_proba / (band_proba + 1e-6)).view(
924
  batch_size, 1, length
925
  )
926
  # [batch_size * length, 1]
927
  double_pitch_band_proba = band_pitch.gather(
928
  2,
929
- (quantized_band_pitch + self.bins_per_octave).clamp_(
930
- max=pitch_channels - band_width
931
  )[:, :, None],
932
  )
933
  double_pitch_band_proba[
934
  quantized_band_pitch
935
- > pitch_channels - band_width - self.bins_per_octave
936
  ] = 0.0
937
  double_pitch_proba = (double_pitch_band_proba / (band_proba + 1e-6)).view(
938
  batch_size, 1, length
939
  )
940
- # Long[1, pitch_channels]
941
- mask = torch.arange(pitch_channels, device=pitch.device)[None, :]
942
- # bool[batch_size * length, pitch_channels]
943
  mask = (quantized_band_pitch <= mask) & (
944
  mask < quantized_band_pitch + band_width
945
  )
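As a consistency check on the release notes: with the 55 Hz base and 96 bins per octave used in the pitch conversion visible further down in this diff (pitch = 55.0 * 2 ** (bin / bins_per_octave)), the old 384 pitch bins top out near A5 and the new 448 bins (`pitch_bins` in the config) near F6, treating the top index as bins − 1 and ignoring any reserved bins (an assumption that both constants are unchanged):

```python
# quick sanity check of the A5 -> F6 release note
for bins in (384, 448):
    f_max = 55.0 * 2.0 ** ((bins - 1) / 96)
    print(bins, round(f_max, 1))
# 384 -> ~873.7 Hz (A5 = 880.0 Hz), 448 -> ~1387.3 Hz (F6 = 1396.9 Hz)
```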
@@ -1088,24 +1500,6 @@ def generate_noise(
1088
  return noise, excitation # [batch_size, length * hop_length]
1089
 
1090
 
1091
- class GradientEqualizerFunction(torch.autograd.Function):
1092
- """ノルムが小さいほど勾配が大きくなってしまうのを補正する"""
1093
-
1094
- @staticmethod
1095
- def forward(ctx, x: torch.Tensor) -> torch.Tensor:
1096
- # x: [batch_size, 1, length]
1097
- rms = x.square().mean(dim=2, keepdim=True).sqrt_()
1098
- ctx.save_for_backward(rms)
1099
- return x
1100
-
1101
- @staticmethod
1102
- def backward(ctx, dx: torch.Tensor) -> torch.Tensor:
1103
- # dx: [batch_size, 1, length]
1104
- (rms,) = ctx.saved_tensors
1105
- dx = dx * (math.sqrt(2.0) * rms + 0.1)
1106
- return dx
1107
-
1108
-
1109
D4C_PREVENT_ZERO_DIVISION = True  # False reproduces the original implementation
1110
 
1111
 
@@ -1493,6 +1887,7 @@ class Vocoder(nn.Module):
1493
  def __init__(
1494
  self,
1495
  channels: int,
 
1496
  hop_length: int = 240,
1497
  n_pre_blocks: int = 4,
1498
  out_sample_rate: float = 24000.0,
@@ -1504,17 +1899,20 @@ class Vocoder(nn.Module):
1504
  self.prenet = ConvNeXtStack(
1505
  in_channels=channels,
1506
  channels=channels,
1507
- intermediate_channels=channels * 3,
1508
  n_blocks=n_pre_blocks,
1509
delay=2,  # 20 ms delay
1510
  embed_kernel_size=7,
1511
  kernel_size=33,
1512
  enable_scaling=True,
 
 
 
1513
  )
1514
  self.ir_generator = ConvNeXtStack(
1515
  in_channels=channels,
1516
  channels=channels,
1517
- intermediate_channels=channels * 3,
1518
  n_blocks=2,
1519
  delay=0,
1520
  embed_kernel_size=3,
@@ -1528,7 +1926,7 @@ class Vocoder(nn.Module):
1528
  self.aperiodicity_generator = ConvNeXtStack(
1529
  in_channels=channels,
1530
  channels=channels,
1531
- intermediate_channels=channels * 3,
1532
  n_blocks=1,
1533
  delay=0,
1534
  embed_kernel_size=3,
@@ -1541,7 +1939,7 @@ class Vocoder(nn.Module):
1541
  self.post_filter_generator = ConvNeXtStack(
1542
  in_channels=channels,
1543
  channels=channels,
1544
- intermediate_channels=channels * 3,
1545
  n_blocks=1,
1546
  delay=0,
1547
  embed_kernel_size=3,
@@ -1553,13 +1951,14 @@ class Vocoder(nn.Module):
1553
  self.register_buffer("post_filter_scale", torch.tensor(0.01))
1554
 
1555
  def forward(
1556
- self, x: torch.Tensor, pitch: torch.Tensor
1557
  ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
1558
  # x: [batch_size, channels, length]
1559
  # pitch: [batch_size, length]
 
1560
  batch_size, _, length = x.size()
1561
 
1562
- x = self.prenet(x)
1563
  ir = self.ir_generator(x)
1564
  ir = F.silu(ir, inplace=True)
1565
  # [batch_size, 512, length]
@@ -1643,8 +2042,6 @@ class Vocoder(nn.Module):
1643
  # [batch_size, 1, length * hop_length]
1644
  y_g_hat = (periodic_signal + aperiodic_signal)[:, None, :]
1645
 
1646
- y_g_hat = GradientEqualizerFunction.apply(y_g_hat)
1647
-
1648
  return y_g_hat, {
1649
  "periodic_signal": periodic_signal.detach(),
1650
  "aperiodic_signal": aperiodic_signal.detach(),
@@ -1761,20 +2158,36 @@ class ConverterNetwork(nn.Module):
1761
  phone_extractor: PhoneExtractor,
1762
  pitch_estimator: PitchEstimator,
1763
  n_speakers: int,
 
1764
  hidden_channels: int,
 
 
 
 
1765
  ):
1766
  super().__init__()
1767
  self.frozen_modules = {
1768
  "phone_extractor": phone_extractor.eval().requires_grad_(False),
1769
  "pitch_estimator": pitch_estimator.eval().requires_grad_(False),
1770
  }
 
 
 
1771
  self.out_sample_rate = out_sample_rate = 24000
1772
- self.embed_phone = nn.Conv1d(256, hidden_channels, 1)
 
1773
  self.embed_phone.weight.data.normal_(0.0, math.sqrt(2.0 / (256 * 5)))
1774
  self.embed_phone.bias.data.zero_()
1775
- self.embed_quantized_pitch = nn.Embedding(384, hidden_channels)
1776
  phase = (
1777
- torch.arange(384, dtype=torch.float)[:, None]
1778
  * (
1779
  torch.arange(0, hidden_channels, 2, dtype=torch.float)
1780
  * (-math.log(10000.0) / hidden_channels)
@@ -1791,8 +2204,22 @@ class ConverterNetwork(nn.Module):
1791
  self.embed_speaker.weight.data.normal_(0.0, math.sqrt(2.0 / 5.0))
1792
  self.embed_formant_shift = nn.Embedding(9, hidden_channels)
1793
  self.embed_formant_shift.weight.data.normal_(0.0, math.sqrt(2.0 / 5.0))
 
1794
  self.vocoder = Vocoder(
1795
  channels=hidden_channels,
 
1796
  hop_length=out_sample_rate // 100,
1797
  n_pre_blocks=4,
1798
  out_sample_rate=out_sample_rate,
@@ -1820,6 +2247,21 @@ class ConverterNetwork(nn.Module):
1820
  )
1821
  )
1822
 
 
1823
  def _get_resampler(
1824
  self, orig_freq, new_freq, device, cache={}
1825
  ) -> torchaudio.transforms.Resample:
@@ -1849,27 +2291,53 @@ class ConverterNetwork(nn.Module):
1849
  # slice_start_indices: [batch_size]
1850
 
1851
  batch_size, _, _ = x.size()
 
1852
 
1853
  with torch.inference_mode():
1854
  phone_extractor: PhoneExtractor = self.frozen_modules["phone_extractor"]
1855
  pitch_estimator: PitchEstimator = self.frozen_modules["pitch_estimator"]
1856
  # [batch_size, 1, wav_length] -> [batch_size, phone_channels, length]
1857
  phone = phone_extractor.units(x).transpose(1, 2)
1858
- # [batch_size, 1, wav_length] -> [batch_size, pitch_channels, length], [batch_size, 1, length]
 
1859
  pitch, energy = pitch_estimator(x)
1860
  # augmentation
1861
  if self.training:
1862
- # [batch_size, pitch_channels - 1]
1863
  weights = pitch.softmax(1)[:, 1:, :].mean(2)
1864
  # [batch_size]
1865
  mean_pitch = (
1866
- weights * torch.arange(1, 384, device=weights.device)
 
1867
  ).sum(1) / weights.sum(1)
1868
  mean_pitch = mean_pitch.round_().long()
1869
  target_pitch = torch.randint_like(mean_pitch, 64, 257)
1870
  shift = target_pitch - mean_pitch
1871
  shift_ratio = (
1872
- 2.0 ** (shift.float() / pitch_estimator.bins_per_octave)
1873
  ).tolist()
1874
  shift = []
1875
  interval_length = 100 # 1s
@@ -1889,7 +2357,8 @@ class ConverterNetwork(nn.Module):
1889
  shift_ratio_i = shift_numer_i / shift_denom_i
1890
  shift_i = int(
1891
  round(
1892
- math.log2(shift_ratio_i) * pitch_estimator.bins_per_octave
 
1893
  )
1894
  )
1895
  shift.append(shift_i)
@@ -1921,7 +2390,7 @@ class ConverterNetwork(nn.Module):
1921
  # [batch_size, 1, sum(wav_length) + batch_size * 16000]
1922
  concatenated_shifted_x = torch.cat(concatenated_shifted_x, dim=2)
1923
  assert concatenated_shifted_x.size(2) % (256 * 160) == 0
1924
- # [1, pitch_channels, length / shift_ratio], [1, 1, length / shift_ratio]
1925
  concatenated_pitch, concatenated_energy = pitch_estimator(
1926
  concatenated_shifted_x
1927
  )
@@ -1963,7 +2432,7 @@ class ConverterNetwork(nn.Module):
1963
  energy[i : i + 1, :, :length] = energy_i[:, :, :length]
1964
  torch.backends.cudnn.benchmark = True
1965
 
1966
- # [batch_size, pitch_channels, length] -> Long[batch_size, length], [batch_size, 3, length]
1967
  quantized_pitch, pitch_features = pitch_estimator.sample_pitch(
1968
  pitch, return_features=True
1969
  )
@@ -1975,14 +2444,14 @@ class ConverterNetwork(nn.Module):
1975
  quantized_pitch
1976
  + (
1977
  pitch_shift_semitone[:, None]
1978
- * (pitch_estimator.bins_per_octave / 12.0)
1979
  )
1980
  .round_()
1981
  .long()
1982
- ).clamp_(1, 383),
1983
  )
1984
  pitch = 55.0 * 2.0 ** (
1985
- quantized_pitch.float() / pitch_estimator.bins_per_octave
1986
  )
1987
# phone looks ahead by 2.5 ms, whereas
1988
# energy looks ahead by 12.5 ms and pitch_features by 22.5 ms, so
@@ -2017,8 +2486,15 @@ class ConverterNetwork(nn.Module):
2017
  # [batch_size, hidden_channels, length] -> [batch_size, hidden_channels, segment_length]
2018
  x = slice_segments(x, slice_start_indices, slice_segment_length)
2019
  x = F.silu(x, inplace=True)
 
2020
  # [batch_size, hidden_channels, segment_length] -> [batch_size, 1, segment_length * 240]
2021
- y_g_hat, stats = self.vocoder(x, pitch)
2022
  stats["pitch"] = pitch
2023
  if return_stats:
2024
  return y_g_hat, stats
@@ -2026,7 +2502,7 @@ class ConverterNetwork(nn.Module):
2026
  return y_g_hat
2027
 
2028
  def _normalize_melsp(self, x):
2029
- return x.clamp(min=1e-10).log_().mul_(0.5)
2030
 
2031
  def forward_and_compute_loss(
2032
  self,
@@ -2037,7 +2513,15 @@ class ConverterNetwork(nn.Module):
2037
  slice_segment_length: int,
2038
  y_all: torch.Tensor,
2039
  enable_loss_ap: bool = False,
2040
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
 
 
 
 
 
 
 
 
2041
  # noisy_wavs_16k: [batch_size, 1, wav_length]
2042
  # target_speaker_id: Long[batch_size]
2043
  # formant_shift_semitone: [batch_size]
@@ -2047,6 +2531,8 @@ class ConverterNetwork(nn.Module):
2047
 
2048
  stats = {}
2049
  loss_mel = 0.0
 
 
2050
 
2051
  # [batch_size, 1, wav_length] -> [batch_size, 1, wav_length * 240]
2052
  y_hat_all, intermediates = self(
@@ -2055,6 +2541,7 @@ class ConverterNetwork(nn.Module):
2055
  formant_shift_semitone,
2056
  return_stats=True,
2057
  )
 
2058
 
2059
  with torch.amp.autocast("cuda", enabled=False):
2060
  periodic_signal = intermediates["periodic_signal"].float()
@@ -2063,9 +2550,25 @@ class ConverterNetwork(nn.Module):
2063
  periodic_signal = periodic_signal[:, : noise_excitation.size(1)]
2064
  aperiodic_signal = aperiodic_signal[:, : noise_excitation.size(1)]
2065
  y_hat_all = y_hat_all.float()
 
 
 
2066
  y_hat_all_truncated = y_hat_all.squeeze(1)[:, : periodic_signal.size(1)]
2067
  y_all_truncated = y_all.squeeze(1)[:, : periodic_signal.size(1)]
2068
 
 
 
2069
  for melspectrogram in self.melspectrograms:
2070
  melsp_periodic_signal = melspectrogram(periodic_signal)
2071
  melsp_aperiodic_signal = melspectrogram(aperiodic_signal)
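The loop above iterates over `self.melspectrograms` at several resolutions (the multi-scale mel loss credited to Descript Audio Codec in the README). The actual FFT sizes are not visible in this diff; a minimal sketch of the idea with assumed window sizes:

```python
import torch
import torch.nn.functional as F
import torchaudio

# several resolutions; the window sizes below are assumptions, not the trainer's values
melspectrograms = [
    torchaudio.transforms.MelSpectrogram(
        sample_rate=24000, n_fft=n_fft, hop_length=n_fft // 4, n_mels=80
    )
    for n_fft in (512, 1024, 2048)
]

def multi_scale_mel_loss(y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Sum of L1 distances between log-mel spectrograms at each resolution."""
    loss = torch.tensor(0.0)
    for melspectrogram in melspectrograms:
        log_mel_hat = melspectrogram(y_hat).clamp(min=1e-10).log()
        log_mel = melspectrogram(y).clamp(min=1e-10).log()
        loss = loss + F.l1_loss(log_mel_hat, log_mel)
    return loss
```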
@@ -2105,6 +2608,7 @@ class ConverterNetwork(nn.Module):
2105
  t = (
2106
  torch.arange(intermediates["pitch"].size(1), device=y_all.device)
2107
  * 0.01
 
2108
  )
2109
  y_coarse_aperiodicity, y_rms = d4c(
2110
  y_all.squeeze(1),
@@ -2126,7 +2630,7 @@ class ConverterNetwork(nn.Module):
2126
  loss_ap = F.mse_loss(
2127
  y_hat_coarse_aperiodicity, y_coarse_aperiodicity, reduction="none"
2128
  )
2129
- loss_ap *= (rms / (rms + 1e-3))[:, :, None]
2130
  loss_ap = loss_ap.mean()
2131
  else:
2132
  loss_ap = torch.tensor(0.0)
@@ -2137,7 +2641,7 @@ class ConverterNetwork(nn.Module):
2137
  )
2138
  # [batch_size, 1, wav_length] -> [batch_size, 1, slice_segment_length * 240]
2139
  y = slice_segments(y_all, slice_start_indices * 240, slice_segment_length * 240)
2140
- return y, y_hat, y_hat_all, loss_mel, loss_ap, stats
2141
 
2142
  def merge_weights(self):
2143
  self.vocoder.merge_weights()
@@ -2155,6 +2659,29 @@ class ConverterNetwork(nn.Module):
2155
  dump_layer(self.embed_pitch_features, f)
2156
  dump_layer(self.vocoder, f)
2157
 
2158
 
2159
  # Discriminator
2160
 
@@ -2288,8 +2815,8 @@ class DiscriminatorP(nn.Module):
2288
  t = t + n_pad
2289
  x = x.view(b, c, t // self.period, self.period)
2290
 
2291
- for l in self.convs:
2292
- x = l(x)
2293
  x = F.silu(x, inplace=True)
2294
  fmap.append(x)
2295
  if self.san:
@@ -2336,8 +2863,8 @@ class DiscriminatorR(nn.Module):
2336
  fmap = []
2337
 
2338
  x = self._spectrogram(x).unsqueeze(1)
2339
- for l in self.convs:
2340
- x = l(x)
2341
  x = F.silu(x, inplace=True)
2342
  fmap.append(x)
2343
  if self.san:
@@ -2457,10 +2984,11 @@ class MultiPeriodDiscriminator(nn.Module):
2457
  # adversarial loss
2458
  adv_loss = 0.0
2459
  for dg, name in zip(y_d_gs, self.discriminator_names):
2460
- dg = dg.float()
2461
  if self.san:
2462
- g_loss = F.softplus(1.0 - dg).square().mean()
 
2463
  else:
 
2464
  g_loss = (1.0 - dg).square().mean()
2465
  stats[f"{name}_gg_loss"] = g_loss.item()
2466
  adv_loss += g_loss
@@ -2678,6 +3206,82 @@ def convolve(signal: torch.Tensor, ir: torch.Tensor) -> torch.Tensor:
2678
  return res[..., : signal.size(-1)]
2679
 
2680
 
2681
  def random_filter(audio: torch.Tensor) -> torch.Tensor:
2682
  assert audio.ndim == 2
2683
  ab = torch.rand(audio.size(0), 6) * 0.75 - 0.375
@@ -2720,7 +3324,7 @@ def get_noise(
2720
 
2721
 
2722
  def get_butterworth_lpf(
2723
- cutoff_freq: int, sample_rate: int, cache={}
2724
  ) -> tuple[torch.Tensor, torch.Tensor]:
2725
  if (cutoff_freq, sample_rate) not in cache:
2726
  q = math.sqrt(0.5)
@@ -2731,8 +3335,9 @@ def get_butterworth_lpf(
2731
  b0 = b1 * 0.5
2732
  a1 = -2.0 * cos_omega / (1.0 + alpha)
2733
  a2 = (1.0 - alpha) / (1.0 + alpha)
2734
- cache[(cutoff_freq, sample_rate)] = torch.tensor([b0, b1, b0]), torch.tensor(
2735
- [1.0, a1, a2]
 
2736
  )
2737
  return cache[(cutoff_freq, sample_rate)]
2738
 
@@ -2742,15 +3347,26 @@ def augment_audio(
2742
  sample_rate: int,
2743
  noise_files: list[Union[str, bytes, os.PathLike]],
2744
  ir_files: list[Union[str, bytes, os.PathLike]],
 
  ) -> torch.Tensor:
2746
  # [1, wav_length]
2747
  assert clean.size(0) == 1
2748
  n_samples = clean.size(1)
2749
 
2750
- snr_candidates = [-20, -25, -30, -35, -40, -45]
2751
-
2752
  original_clean_rms = clean.square().mean().sqrt_()
2753
 
 
2754
# fetch noise and concatenate it with clean
2755
  noise = get_noise(n_samples, sample_rate, noise_files)
2756
  signals = torch.cat([clean, noise])
@@ -2759,7 +3375,7 @@ def augment_audio(
2759
  signals = random_filter(signals)
2760
 
2761
# apply reverb to clean and noise
2762
- if torch.rand(()) < 0.5:
2763
  ir_file = ir_files[torch.randint(0, len(ir_files), ())]
2764
  ir, sr = torchaudio.load(ir_file, backend="soundfile")
2765
  assert ir.size() == (2, sr), ir.size()
@@ -2767,12 +3383,11 @@ def augment_audio(
2767
  signals = convolve(signals, ir)
2768
 
2769
# apply the same LPF to clean and noise
2770
- if torch.rand(()) < 0.2:
2771
  if signals.abs().max() > 0.8:
2772
  signals /= signals.abs().max() * 1.25
2773
- cutoff_freq_candidates = [2000, 3000, 4000, 6000]
2774
- cutoff_freq = cutoff_freq_candidates[
2775
- torch.randint(0, len(cutoff_freq_candidates), ())
2776
  ]
2777
  b, a = get_butterworth_lpf(cutoff_freq, sample_rate)
2778
  signals = torchaudio.functional.lfilter(signals, a, b, clamp=False)
@@ -2782,13 +3397,17 @@ def augment_audio(
2782
  clean_rms = clean.square().mean().sqrt_()
2783
  clean *= original_clean_rms / clean_rms
2784
 
2785
- # measure the clean and noise levels with emphasis on peaks
2786
- clean_level = clean.square().square_().mean().sqrt_().sqrt_()
2787
- noise_level = noise.square().square_().mean().sqrt_().sqrt_()
2788
- # SNR
2789
- snr = snr_candidates[torch.randint(0, len(snr_candidates), ())]
2790
- # generate noisy
2791
- noisy = clean + noise * (10.0 ** (snr / 20.0) * clean_level / (noise_level + 1e-5))
 
 
 
 
2792
  return noisy
2793
 
2794
 
@@ -2802,6 +3421,18 @@ class WavDataset(torch.utils.data.Dataset):
2802
  segment_length: int = 100, # 1s
2803
  noise_files: Optional[list[Union[str, bytes, os.PathLike]]] = None,
2804
  ir_files: Optional[list[Union[str, bytes, os.PathLike]]] = None,
 
2805
  ):
2806
  self.audio_files = audio_files
2807
  self.in_sample_rate = in_sample_rate
@@ -2810,6 +3441,21 @@ class WavDataset(torch.utils.data.Dataset):
2810
  self.segment_length = segment_length
2811
  self.noise_files = noise_files
2812
  self.ir_files = ir_files
 
2813
 
2814
  if (noise_files is None) is not (ir_files is None):
2815
  raise ValueError("noise_files and ir_files must be both None or not None")
@@ -2851,7 +3497,17 @@ class WavDataset(torch.utils.data.Dataset):
2851
  clean_wav
2852
  )
2853
  noisy_wav_16k = augment_audio(
2854
- clean_wav_16k, self.in_sample_rate, self.noise_files, self.ir_files
 
2855
  )
2856
 
2857
  clean_wav = clean_wav.squeeze_(0)
@@ -2937,6 +3593,44 @@ AUDIO_FILE_SUFFIXES = {
2937
  }
2938
 
2939
 
2940
  def prepare_training():
2941
# perform the various preparations
2942
# as a side effect, creates the output directory, TensorBoard log files, and so on
@@ -2961,18 +3655,18 @@ def prepare_training():
2961
  if not in_wav_dataset_dir.is_dir():
2962
  raise ValueError(f"{in_wav_dataset_dir} is not found.")
2963
  if resume:
2964
- latest_checkpoint_file = out_dir / "checkpoint_latest.pt"
2965
  if not latest_checkpoint_file.is_file():
2966
  raise ValueError(f"{latest_checkpoint_file} is not found.")
2967
  else:
2968
  if out_dir.is_dir():
2969
- if (out_dir / "checkpoint_latest.pt").is_file():
2970
  raise ValueError(
2971
- f"{out_dir / 'checkpoint_latest.pt'} already exists. "
2972
  "Please specify a different output directory, or use --resume option."
2973
  )
2974
  for file in out_dir.iterdir():
2975
- if file.suffix == ".pt":
2976
  raise ValueError(
2977
  f"{out_dir} already contains model files. "
2978
  "Please specify a different output directory."
@@ -3084,6 +3778,13 @@ def prepare_training():
3084
  segment_length=h.segment_length,
3085
  noise_files=noise_files,
3086
  ir_files=ir_files,
 
3087
  )
3088
  training_loader = torch.utils.data.DataLoader(
3089
  training_dataset,
@@ -3112,7 +3813,9 @@ def prepare_training():
3112
  print("Computing pitch shifts for test files...")
3113
  test_pitch_shifts = []
3114
  source_f0s = []
3115
- for i, (file, target_ids) in enumerate(tqdm(test_filelist)):
 
 
3116
  source_f0 = compute_mean_f0([file], method="harvest")
3117
  source_f0s.append(source_f0)
3118
  if math.isnan(source_f0):
@@ -3136,7 +3839,9 @@ def prepare_training():
3136
  repo_root() / h.phone_extractor_file, map_location="cpu", weights_only=True
3137
  )
3138
  print(
3139
- phone_extractor.load_state_dict(phone_extractor_checkpoint["phone_extractor"])
 
 
3140
  )
3141
  del phone_extractor_checkpoint
3142
 
@@ -3153,7 +3858,12 @@ def prepare_training():
3153
  phone_extractor,
3154
  pitch_estimator,
3155
  n_speakers,
 
3156
  h.hidden_channels,
 
 
 
 
3157
  ).to(device)
3158
  net_d = MultiPeriodDiscriminator(san=h.san).to(device)
3159
 
@@ -3173,6 +3883,7 @@ def prepare_training():
3173
  grad_scaler = torch.amp.GradScaler("cuda", enabled=h.use_amp)
3174
  grad_balancer = GradBalancer(
3175
  weights={
 
3176
  "loss_mel": h.grad_weight_mel,
3177
  "loss_adv": h.grad_weight_adv,
3178
  "loss_fm": h.grad_weight_fm,
@@ -3187,72 +3898,76 @@ def prepare_training():
3187
# load the checkpoint
3188
 
3189
  initial_iteration = 0
3190
- if resume:
3191
  checkpoint_file = latest_checkpoint_file
3192
- elif h.pretrained_file is not None:
3193
  checkpoint_file = repo_root() / h.pretrained_file
3194
- else:
3195
  checkpoint_file = None
 
3196
  if checkpoint_file is not None:
3197
- checkpoint = torch.load(checkpoint_file, map_location="cpu", weights_only=True)
 
3198
if not resume and not skip_training:  # fine-tuning
3199
- checkpoint_n_speakers = len(checkpoint["net_g"]["embed_speaker.weight"])
3200
- initial_speaker_embedding = checkpoint["net_g"][
3201
- "embed_speaker.weight"
3202
- ].mean(0, keepdim=True)
3203
- if True:
3204
- checkpoint["net_g"]["embed_speaker.weight"] = initial_speaker_embedding[
3205
- [0] * n_speakers
3206
- ]
3207
- else:  # for adding speakers
3208
- assert n_speakers > checkpoint_n_speakers
3209
- print(
3210
- f"embed_speaker.weight was padded: {checkpoint_n_speakers} -> {n_speakers}"
3211
- )
3212
- checkpoint["net_g"]["embed_speaker.weight"] = F.pad(
3213
- checkpoint["net_g"]["embed_speaker.weight"],
3214
- (0, 0, 0, n_speakers - checkpoint_n_speakers),
3215
- )
3216
- checkpoint["net_g"]["embed_speaker.weight"][
3217
- checkpoint_n_speakers:
3218
- ] = initial_speaker_embedding
3219
  print(net_g.load_state_dict(checkpoint["net_g"], strict=False))
3220
  print(net_d.load_state_dict(checkpoint["net_d"], strict=False))
3221
  if resume or skip_training:
3222
- optim_g.load_state_dict(checkpoint["optim_g"])
3223
- optim_d.load_state_dict(checkpoint["optim_d"])
 
 
 
 
3224
  initial_iteration = checkpoint["iteration"]
3225
  grad_balancer.load_state_dict(checkpoint["grad_balancer"])
3226
  grad_scaler.load_state_dict(checkpoint["grad_scaler"])
3227
 
 
 
3228
# scheduler
3229
 
3230
- def get_cosine_annealing_warmup_scheduler(
3231
  optimizer: torch.optim.Optimizer,
3232
  warmup_epochs: int,
3233
- total_epochs: int,
3234
- min_learning_rate: float,
3235
  ) -> torch.optim.lr_scheduler.LambdaLR:
3236
- lr_ratio = min_learning_rate / optimizer.param_groups[0]["lr"]
3237
- m = 0.5 * (1.0 - lr_ratio)
3238
- a = 0.5 * (1.0 + lr_ratio)
3239
-
3240
  def lr_lambda(current_epoch: int) -> float:
3241
  if current_epoch < warmup_epochs:
3242
  return current_epoch / warmup_epochs
3243
- elif current_epoch < total_epochs:
3244
- rate = (current_epoch - warmup_epochs) / (total_epochs - warmup_epochs)
3245
- return math.cos(rate * math.pi) * m + a
3246
  else:
3247
- return min_learning_rate
3248
 
3249
  return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
3250
 
3251
- scheduler_g = get_cosine_annealing_warmup_scheduler(
3252
- optim_g, h.warmup_steps, h.n_steps, h.min_learning_rate_g
3253
  )
3254
- scheduler_d = get_cosine_annealing_warmup_scheduler(
3255
- optim_d, h.warmup_steps, h.n_steps, h.min_learning_rate_d
3256
  )
3257
  with warnings.catch_warnings():
3258
  warnings.filterwarnings(
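The new-side scheduler body is mostly elided in this diff; judging from the surviving warmup branch, the final `LambdaLR`, and the new `learning_rate_decay` config key, a sketch of plain exponential decay after linear warmup (an assumption, not the verbatim code):

```python
import torch

def get_exponential_warmup_scheduler(
    optimizer: torch.optim.Optimizer,
    warmup_epochs: int,
    decay: float,  # e.g. h.learning_rate_decay = 0.999999
) -> torch.optim.lr_scheduler.LambdaLR:
    def lr_lambda(current_epoch: int) -> float:
        if current_epoch < warmup_epochs:
            return current_epoch / warmup_epochs          # linear warmup
        return decay ** (current_epoch - warmup_epochs)   # exponential decay

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```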
@@ -3274,6 +3989,9 @@ def prepare_training():
3274
  writer = None
3275
  else:
3276
  writer = SummaryWriter(out_dir)
 
 
 
3277
  writer.add_text(
3278
  "log",
3279
  f"start training w/ {torch.cuda.get_device_name(device) if torch.cuda.is_available() else 'cpu'}.",
@@ -3367,12 +4085,11 @@ if __name__ == "__main__" and writer is not None:
3367
  if h.profile
3368
  else nullcontext()
3369
  ) as profiler:
3370
-
3371
- for iteration in tqdm(range(initial_iteration, h.n_steps)):
3372
  # === 1. データ前処理 ===
3373
  try:
3374
  batch = next(data_iter)
3375
- except:
3376
  data_iter = iter(training_loader)
3377
  batch = next(data_iter)
3378
  (
@@ -3388,20 +4105,27 @@ if __name__ == "__main__" and writer is not None:
3388
# === 2.1 Generator forward pass ===
3389
  if h.compile_convnext:
3390
  ConvNeXtStack.forward = compiled_convnextstack_forward
3391
- y, y_hat, y_hat_for_backward, loss_mel, loss_ap, generator_stats = (
3392
- net_g.forward_and_compute_loss(
3393
- noisy_wavs_16k[:, None, :],
3394
- speaker_ids,
3395
- formant_shift_semitone,
3396
- slice_start_indices=slice_starts,
3397
- slice_segment_length=h.segment_length,
3398
- y_all=clean_wavs[:, None, :],
3399
- enable_loss_ap=h.grad_weight_ap != 0.0,
3400
- )
 
3401
  )
3402
  if h.compile_convnext:
3403
  ConvNeXtStack.forward = raw_convnextstack_forward
3404
  assert y_hat.isfinite().all()
 
3405
  assert loss_mel.isfinite().all()
3406
  assert loss_ap.isfinite().all()
3407
 
@@ -3432,6 +4156,7 @@ if __name__ == "__main__" and writer is not None:
3432
  assert param.grad is None
3433
  gradient_balancer_stats = grad_balancer.backward(
3434
  {
 
3435
  "loss_mel": loss_mel,
3436
  "loss_adv": loss_adv,
3437
  "loss_fm": loss_fm,
@@ -3441,6 +4166,7 @@ if __name__ == "__main__" and writer is not None:
3441
  grad_scaler,
3442
  skip_update_ema=iteration > 10 and iteration % 5 != 0,
3443
  )
 
3444
  loss_mel = loss_mel.item()
3445
  loss_adv = loss_adv.item()
3446
  loss_fm = loss_fm.item()
@@ -3461,6 +4187,7 @@ if __name__ == "__main__" and writer is not None:
3461
  grad_scaler.update()
3462
 
3463
# === 3. Logging ===
 
3464
  dict_scalars["loss_g/loss_mel"].append(loss_mel)
3465
  if h.grad_weight_ap:
3466
  dict_scalars["loss_g/loss_ap"].append(loss_ap)
@@ -3569,11 +4296,8 @@ if __name__ == "__main__" and writer is not None:
3569
  )
3570
 
3571
# === 4. Validation ===
3572
- if (iteration + 1) % (
3573
- 50000 if h.n_steps > 200000 else 2000
3574
- ) == 0 or iteration + 1 in {
3575
  1,
3576
- 30000,
3577
  h.n_steps,
3578
  }:
3579
  torch.backends.cudnn.benchmark = False
@@ -3670,36 +4394,36 @@ if __name__ == "__main__" and writer is not None:
3670
  torch.cuda.empty_cache()
3671
 
3672
# === 5. Saving ===
3673
- if (iteration + 1) % (
3674
- 50000 if h.n_steps > 200000 else 2000
3675
- ) == 0 or iteration + 1 in {
3676
  1,
3677
- 30000,
3678
  h.n_steps,
3679
  }:
3680
# checkpoint
3681
  name = f"{in_wav_dataset_dir.name}_{iteration + 1:08d}"
3682
- checkpoint_file_save = out_dir / f"checkpoint_{name}.pt"
3683
  if checkpoint_file_save.exists():
3684
  checkpoint_file_save = checkpoint_file_save.with_name(
3685
  f"{checkpoint_file_save.name}_{hash(None):x}"
3686
  )
3687
- torch.save(
3688
- {
3689
- "iteration": iteration + 1,
3690
- "net_g": net_g.state_dict(),
3691
- "phone_extractor": phone_extractor.state_dict(),
3692
- "pitch_estimator": pitch_estimator.state_dict(),
3693
- "net_d": net_d.state_dict(),
3694
- "optim_g": optim_g.state_dict(),
3695
- "optim_d": optim_d.state_dict(),
3696
- "grad_balancer": grad_balancer.state_dict(),
3697
- "grad_scaler": grad_scaler.state_dict(),
3698
- "h": dict(h),
3699
- },
3700
- checkpoint_file_save,
3701
- )
3702
- shutil.copy(checkpoint_file_save, out_dir / "checkpoint_latest.pt")
 
 
 
3703
 
3704
# for inference
3705
  paraphernalia_dir = out_dir / f"paraphernalia_{name}"
@@ -3713,27 +4437,35 @@ if __name__ == "__main__" and writer is not None:
3713
  phone_extractor_fp16.remove_weight_norm()
3714
  phone_extractor_fp16.merge_weights()
3715
  phone_extractor_fp16.half()
3716
- phone_extractor_fp16.dump(paraphernalia_dir / f"phone_extractor.bin")
3717
  del phone_extractor_fp16
3718
  pitch_estimator_fp16 = PitchEstimator()
3719
  pitch_estimator_fp16.load_state_dict(pitch_estimator.state_dict())
3720
  pitch_estimator_fp16.merge_weights()
3721
  pitch_estimator_fp16.half()
3722
- pitch_estimator_fp16.dump(paraphernalia_dir / f"pitch_estimator.bin")
3723
  del pitch_estimator_fp16
3724
  net_g_fp16 = ConverterNetwork(
3725
- nn.Module(), nn.Module(), len(speakers), h.hidden_channels
 
3726
  )
3727
  net_g_fp16.load_state_dict(net_g.state_dict())
3728
  net_g_fp16.merge_weights()
3729
  net_g_fp16.half()
3730
- net_g_fp16.dump(paraphernalia_dir / f"waveform_generator.bin")
3731
- with open(paraphernalia_dir / f"speaker_embeddings.bin", "wb") as f:
3732
- dump_layer(net_g_fp16.embed_speaker, f)
3733
- with open(
3734
- paraphernalia_dir / f"formant_shift_embeddings.bin", "wb"
3735
- ) as f:
3736
- dump_layer(net_g_fp16.embed_formant_shift, f)
3737
  del net_g_fp16
3738
  shutil.copy(
3739
  repo_root() / "assets/images/noimage.png", paraphernalia_dir
 
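The compressed checkpoint write that replaces the removed `torch.save(...)` / `shutil.copy(...)` block above is elided in this diff. Given the new `import gzip` and the `.pt.gz` file names, the likely pattern is the one below (an assumption, not the verbatim code); `torch.save` and `torch.load` both accept file objects, so wrapping them in gzip works directly:

```python
import gzip
import torch

def save_checkpoint_gz(state: dict, path) -> None:
    # write the checkpoint through a gzip stream
    with gzip.open(path, "wb") as f:
        torch.save(state, f)

def load_checkpoint_gz(path) -> dict:
    # read it back the same way
    with gzip.open(path, "rb") as f:
        return torch.load(f, map_location="cpu", weights_only=True)
```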
  # %%
  import argparse
  import gc
+ import gzip
  import json
  import math
  import os

  from pathlib import Path
  from pprint import pprint
  from random import Random
+ from typing import BinaryIO, Literal, Optional, Union, Sequence, Iterable, Callable

  import numpy as np
  import pyworld

  # not the version of this module
+ PARAPHERNALIA_VERSION = "2.0.0-rc.0"


  def is_notebook() -> bool:

  # hyperparameters
  # things that change per run, such as the training data and the output directory, are not included here
  dict_default_hparams = {
+ # training
+ "learning_rate_g": 5e-5,
+ "learning_rate_d": 5e-5,
+ "learning_rate_decay": 0.999999,
  "adam_betas": [0.8, 0.99],
  "adam_eps": 1e-6,
  "batch_size": 8,
+ "grad_weight_loudness": 1.0, # the grad_weights should be equivalent as long as their ratios stay the same
+ "grad_weight_mel": 50.0,
+ "grad_weight_ap": 100.0,
+ "grad_weight_adv": 150.0,
+ "grad_weight_fm": 150.0,
  "grad_balancer_ema_decay": 0.995,
  "use_amp": True,
  "num_workers": 16,
  "n_steps": 10000,
+ "warmup_steps": 5000,
+ "evaluation_interval": 2000,
+ "save_interval": 2000,
  "in_sample_rate": 16000, # must not be changed
  "out_sample_rate": 24000, # must not be changed
  "wav_length": 4 * 24000, # 4s
  "segment_length": 100, # 1s
+ "phone_noise_ratio": 0.5,
+ "vq_topk": 4,
+ "training_time_vq": "none", # "none", "self" or "random"
+ "floor_noise_level": 1e-3,
+ "record_metrics": False,
+ # augmentation
+ "augmentation_snr_candidates": [20.0, 25.0, 30.0, 35.0, 40.0, 45.0],
+ "augmentation_formant_shift_probability": 0.5,
+ "augmentation_formant_shift_semitone_min": -3.0,
+ "augmentation_formant_shift_semitone_max": 3.0,
+ "augmentation_reverb_probability": 0.5,
+ "augmentation_lpf_probability": 0.2,
+ "augmentation_lpf_cutoff_freq_candidates": [2000.0, 3000.0, 4000.0, 6000.0],
  # data
+ "phone_extractor_file": "assets/pretrained/122_checkpoint_03000000.pt",
+ "pitch_estimator_file": "assets/pretrained/104_3_checkpoint_00300000.pt",
  "in_ir_wav_dir": "assets/ir",
  "in_noise_wav_dir": "assets/noise",
  "in_test_wav_dir": "assets/test",
+ "pretrained_file": "assets/pretrained/151_checkpoint_libritts_r_200_02750000.pt.gz", # None is also allowed
  # model
+ "pitch_bins": 448, # must not be changed
  "hidden_channels": 256, # must not be changed when fine-tuning; changing it requires support on the inference side
  "san": False, # must not be changed when fine-tuning
  "compile_convnext": False,


  def prepare_training_configs_for_experiment() -> tuple[dict, Path, Path, bool, bool]:
+ import ipynbname # type: ignore[import]
+ from IPython import get_ipython # type: ignore[import]

  h = deepcopy(dict_default_hparams)
  in_wav_dataset_dir = repo_root() / "../../data/processed/libritts_r_200"

  elif isinstance(layer, (nn.Linear, nn.Conv1d, nn.LayerNorm)):
  dump(layer.weight)
  dump(layer.bias)
+ elif isinstance(layer, nn.MultiheadAttention):
+ embed_dim = layer.embed_dim
+ num_heads = layer.num_heads
+ # [3 * embed_dim, embed_dim]
+ in_proj_weight = layer.in_proj_weight.data.clone()
+ in_proj_weight[: 2 * embed_dim] *= 1.0 / math.sqrt(
+ math.sqrt(embed_dim // num_heads)
+ )
+ in_proj_weight = in_proj_weight.view(
+ 3, num_heads, embed_dim // num_heads, embed_dim
+ )
+ # [num_heads, 3, embed_dim / num_heads, embed_dim]
+ in_proj_weight = in_proj_weight.transpose(0, 1)
+ # [3 * embed_dim]
+ in_proj_bias = layer.in_proj_bias.data.clone()
+ in_proj_bias[: 2 * embed_dim] *= 1.0 / math.sqrt(
+ math.sqrt(embed_dim // num_heads)
+ )
+ in_proj_bias = in_proj_bias.view(3, num_heads, embed_dim // num_heads)
+ # [num_heads, 3, embed_dim / num_heads]
+ in_proj_bias = in_proj_bias.transpose(0, 1)
+ dump(in_proj_weight)
+ dump(in_proj_bias)
+ dump(layer.out_proj.weight)
+ dump(layer.out_proj.bias)
  elif isinstance(layer, nn.Embedding):
  dump(layer.weight)
  elif isinstance(layer, nn.Parameter):
  dump(layer)
  elif isinstance(layer, nn.ModuleList):
+ for layer_i in layer:
+ dump_layer(layer_i, f)
  else:
  assert False, layer

  self.gain.data.fill_(1.0)

+ class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ qk_channels: int,
+ vo_channels: int,
+ num_heads: int,
+ in_q_channels: int,
+ in_kv_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ ):
+ super().__init__()
+ assert qk_channels % num_heads == 0
+ self.qk_channels = qk_channels
+ self.vo_channels = vo_channels
+ self.num_heads = num_heads
+ self.in_q_channels = in_q_channels
+ self.in_kv_channels = in_kv_channels
+ self.out_channels = out_channels
+ self.dropout = dropout
+ self.head_qk_channels = qk_channels // num_heads
+ self.head_vo_channels = vo_channels // num_heads
+ self.q_projection = nn.Linear(in_q_channels, qk_channels)
+ self.q_projection.weight.data.normal_(0.0, math.sqrt(1.0 / in_q_channels))
+ self.q_projection.bias.data.zero_()
+ self.kv_projection = nn.Linear(in_kv_channels, qk_channels + vo_channels)
+ self.kv_projection.weight.data.normal_(0.0, math.sqrt(1.0 / in_kv_channels))
+ self.kv_projection.bias.data.zero_()
+ self.out_projection = nn.Linear(vo_channels, out_channels)
+ self.out_projection.weight.data.normal_(0.0, math.sqrt(1.0 / vo_channels))
+ self.out_projection.bias.data.zero_()
+
+ def forward(
+ self,
+ q: torch.Tensor,
+ kv: torch.Tensor,
+ ) -> torch.Tensor:
+ # q: [batch_size, q_length, in_q_channels]
+ # kv: [batch_size, kv_length, in_kv_channels]
+ batch_size, q_length, _ = q.size()
+ _, kv_length, _ = kv.size()
+ # [batch_size, q_length, qk_channels]
+ q = self.q_projection(q)
+ # [batch_size, kv_length, qk_channels + vo_channels]
+ kv = self.kv_projection(kv)
+ # [batch_size, kv_length, qk_channels], [batch_size, kv_length, vo_channels]
+ k, v = kv.split([self.qk_channels, self.vo_channels], dim=2)
+ q = q.view(
+ batch_size, q_length, self.num_heads, self.head_qk_channels
+ ).transpose(1, 2)
+ k = k.view(
+ batch_size, kv_length, self.num_heads, self.head_qk_channels
+ ).transpose(1, 2)
+ v = v.view(
+ batch_size, kv_length, self.num_heads, self.head_vo_channels
+ ).transpose(1, 2)
+ # [batch_size, num_heads, q_length, head_vo_channels]
+ attn_out = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout)
+ # [batch_size, q_length, vo_channels]
+ attn_out = (
+ attn_out.transpose(1, 2)
+ .contiguous()
+ .view(batch_size, q_length, self.vo_channels)
+ )
+ # [batch_size, q_length, out_channels]
+ attn_out = self.out_projection(attn_out)
+ return attn_out
+
+ def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
+ if isinstance(f, (str, bytes, os.PathLike)):
+ with open(f, "wb") as f:
+ self.dump(f)
+ return
+ if not hasattr(f, "write"):
+ raise TypeError
+
+ q_projection_weight = self.q_projection.weight.data.clone()
+ q_projection_bias = self.q_projection.bias.data.clone()
+ q_projection_weight *= 1.0 / math.sqrt(math.sqrt(self.head_qk_channels))
+ q_projection_bias *= 1.0 / math.sqrt(math.sqrt(self.head_qk_channels))
+ dump_params(q_projection_weight, f)
+ dump_params(q_projection_bias, f)
+ dump_layer(self.out_projection, f)
+
+ def dump_kv(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
+ if isinstance(f, (str, bytes, os.PathLike)):
+ with open(f, "wb") as f:
+ self.dump_kv(f)
+ return
+ if not hasattr(f, "write"):
+ raise TypeError
+
+ kv_projection_weight = self.kv_projection.weight.data.clone()
+ kv_projection_bias = self.kv_projection.bias.data.clone()
+ k_projection_weight, v_projection_weight = kv_projection_weight.split(
+ [self.qk_channels, self.vo_channels]
+ )
+ k_projection_bias, v_projection_bias = kv_projection_bias.split(
+ [self.qk_channels, self.vo_channels]
+ )
+ k_projection_weight *= 1.0 / math.sqrt(math.sqrt(self.head_qk_channels))
+ k_projection_bias *= 1.0 / math.sqrt(math.sqrt(self.head_qk_channels))
+ # [qk_channels, in_kv_channels] -> [num_heads, head_qk_channels, in_kv_channels]
+ k_projection_weight = k_projection_weight.view(
+ self.num_heads, self.head_qk_channels, self.in_kv_channels
+ )
+ # [qk_channels] -> [num_heads, head_qk_channels]
+ k_projection_bias = k_projection_bias.view(
+ self.num_heads, self.head_qk_channels
+ )
+ # [vo_channels, in_kv_channels] -> [num_heads, head_vo_channels, in_kv_channels]
+ v_projection_weight = v_projection_weight.view(
+ self.num_heads, self.head_vo_channels, self.in_kv_channels
+ )
+ # [vo_channels] -> [num_heads, head_vo_channels]
+ v_projection_bias = v_projection_bias.view(
+ self.num_heads, self.head_vo_channels
+ )
+ for i in range(self.num_heads):
+ # [head_qk_channels, in_kv_channels]
+ dump_params(k_projection_weight[i], f)
+ # [head_vo_channels, in_kv_channels]
+ dump_params(v_projection_weight[i], f)
+ for i in range(self.num_heads):
+ # [head_qk_channels]
+ dump_params(k_projection_bias[i], f)
+ # [head_vo_channels]
+ dump_params(v_projection_bias[i], f)
+
+
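
A self-contained sketch of the shape contract `CrossAttention` is built around (queries from the content stream, keys/values from a speaker embedding), using only `F.scaled_dot_product_attention`; all sizes below are illustrative:

```python
import torch
import torch.nn.functional as F

# Illustrative sizes: 4 heads attending over a 384-entry speaker embedding.
batch, q_len, kv_len, heads, head_ch = 2, 100, 384, 4, 32
q = torch.randn(batch, heads, q_len, head_ch)   # projected queries
k = torch.randn(batch, heads, kv_len, head_ch)  # projected keys
v = torch.randn(batch, heads, kv_len, head_ch)  # projected values
out = F.scaled_dot_product_attention(q, k, v)
assert out.shape == (batch, heads, q_len, head_ch)
```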
  class ConvNeXtBlock(nn.Module):
  def __init__(
  self,

  enable_scaling: bool = False,
  pre_scale: float = 1.0,
  post_scale: float = 1.0,
+ use_mha: bool = False,
+ cross_attention: bool = False,
+ num_heads: int = 4,
+ attention_dropout: float = 0.1,
+ attention_channels: Optional[int] = None,
+ kv_channels: Optional[int] = None,
  ):
  super().__init__()
  self.use_weight_standardization = use_weight_standardization
  self.enable_scaling = enable_scaling
+ self.use_mha = use_mha
+ self.cross_attention = cross_attention
+ if use_mha:
+ self.attn_norm = nn.LayerNorm(channels)
+ if cross_attention:
+ self.mha = CrossAttention(
+ qk_channels=attention_channels,
+ vo_channels=attention_channels,
+ num_heads=num_heads,
+ in_q_channels=channels,
+ in_kv_channels=kv_channels,
+ out_channels=channels,
+ dropout=attention_dropout,
+ )
+ else: # self-attention
+ assert attention_channels is None
+ assert kv_channels is None
+ self.mha = nn.MultiheadAttention(
+ embed_dim=channels,
+ num_heads=num_heads,
+ dropout=attention_dropout,
+ batch_first=True,
+ )
  self.dwconv = CausalConv1d(
  channels, channels, kernel_size=kernel_size, groups=channels
  )

  self.register_buffer("post_scale", torch.tensor(post_scale))
  self.post_scale_weight = nn.Parameter(torch.ones(()))

+ def forward(
+ self,
+ x: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ kv: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if self.use_mha:
+ batch_size, channels, length = x.size()
+ if self.cross_attention:
+ assert kv is not None
+ else:
+ assert kv is None
+ assert length % 4 == 0
+ identity = x
+ if self.cross_attention:
+ # kv: [batch_size, kv_length, kv_channels]
+ x = x.transpose(1, 2)
+ x = self.attn_norm(x)
+ x = self.mha(x, kv)
+ x = x.transpose(1, 2)
+ else:
+ x = x.view(batch_size, channels, length // 4, 4)
+ x = x.permute(0, 3, 2, 1)
+ x = x.reshape(batch_size * 4, length // 4, channels)
+ x = self.attn_norm(x)
+ x, _ = self.mha(
+ x, x, x, attn_mask=attn_mask, is_causal=True, need_weights=False
+ )
+ x = x.view(batch_size, 4, length // 4, channels)
+ x = x.permute(0, 3, 2, 1)
+ x = x.reshape(batch_size, channels, length)
+ x += identity
+
  identity = x
  if self.enable_scaling:
  x = x * self.pre_scale

  return x
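
The self-attention branch above folds the time axis into four interleaved streams of length `length // 4` before calling `nn.MultiheadAttention`; a small standalone check that the two reshapes round-trip exactly:

```python
import torch

batch, channels, length = 2, 8, 12
x = torch.arange(batch * channels * length, dtype=torch.float).view(batch, channels, length)
# forward reshape: [B, C, L] -> [B * 4, L // 4, C]
y = x.view(batch, channels, length // 4, 4).permute(0, 3, 2, 1).reshape(batch * 4, length // 4, channels)
# inverse reshape: [B * 4, L // 4, C] -> [B, C, L]
z = y.view(batch, 4, length // 4, channels).permute(0, 3, 2, 1).reshape(batch, channels, length)
assert torch.equal(x, z)
```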

  def merge_weights(self):
+ if self.use_mha:
+ if self.cross_attention:
+ assert isinstance(self.mha, CrossAttention)
+ self.mha.q_projection.bias.data += torch.mv(
+ self.mha.q_projection.weight.data, self.attn_norm.bias.data
+ )
+ self.mha.q_projection.weight.data *= self.attn_norm.weight.data[None, :]
+ self.attn_norm.bias.data[:] = 0.0
+ self.attn_norm.weight.data[:] = 1.0
+ else: # self-attention
+ assert isinstance(self.mha, nn.MultiheadAttention)
+ self.mha.in_proj_bias.data += torch.mv(
+ self.mha.in_proj_weight.data, self.attn_norm.bias.data
+ )
+ self.mha.in_proj_weight.data *= self.attn_norm.weight.data[None, :]
+ self.attn_norm.bias.data[:] = 0.0
+ self.attn_norm.weight.data[:] = 1.0
  if self.use_weight_standardization:
  self.dwconv.merge_weights()
  self.pwconv1.merge_weights()
  self.pwconv2.merge_weights()
  else:
+ self.pwconv1.bias.data += torch.mv(
+ self.pwconv1.weight.data, self.norm.bias.data
+ )
  self.pwconv1.weight.data *= self.norm.weight.data[None, :]
  self.norm.bias.data[:] = 0.0
  self.norm.weight.data[:] = 1.0
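
The folding above relies on the identity W(g ⊙ x + b_ln) + b = (W diag(g)) x + (W b_ln + b), i.e. a LayerNorm's affine parameters can be absorbed into the following linear layer; a quick numeric check:

```python
import torch

torch.manual_seed(0)
c = 16
x = torch.randn(4, c)                     # stands in for the normalized activations
g, b_ln = torch.randn(c), torch.randn(c)  # LayerNorm affine parameters
W, b = torch.randn(c, c), torch.randn(c)  # following linear layer

y_ref = (x * g + b_ln) @ W.t() + b        # LayerNorm affine, then linear
W2 = W * g[None, :]                       # fold the gain into the weight
b2 = b + torch.mv(W, b_ln)                # fold the shift into the bias
y_folded = x @ W2.t() + b2
assert torch.allclose(y_ref, y_folded, atol=1e-5)
```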

  if not hasattr(f, "write"):
  raise TypeError

+ if self.use_mha:
+ dump_layer(self.mha, f)
  dump_layer(self.dwconv, f)
  dump_layer(self.pwconv1, f)
  dump_layer(self.pwconv2, f)

  kernel_size: int,
  use_weight_standardization: bool = False,
  enable_scaling: bool = False,
+ use_mha: bool = False,
+ cross_attention: bool = False,
+ kv_channels: Optional[int] = None,
  ):
  super().__init__()
  assert delay * 2 + 1 <= embed_kernel_size
+ assert not (use_weight_standardization and use_mha) # not supported
  self.use_weight_standardization = use_weight_standardization
+ self.use_mha = use_mha
+ self.cross_attention = cross_attention
  self.embed = CausalConv1d(in_channels, channels, embed_kernel_size, delay=delay)
  self.norm = nn.LayerNorm(channels)
  self.convnext = nn.ModuleList()

  enable_scaling=enable_scaling,
  pre_scale=pre_scale,
  post_scale=post_scale,
+ use_mha=use_mha,
+ cross_attention=cross_attention,
+ num_heads=4,
+ attention_dropout=0.1,
+ attention_channels=kv_channels,
+ kv_channels=kv_channels,
  )
  self.convnext.append(block)
  self.final_layer_norm = nn.LayerNorm(channels)

  self.norm = nn.Identity()
  self.final_layer_norm = nn.Identity()

+ def forward(
+ self, x: torch.Tensor, kv: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
  x = self.embed(x)
  x = self.norm(x.transpose(1, 2)).transpose(1, 2)
+ if self.use_mha and not self.cross_attention:
+ pad_length = -x.size(2) % 4
+ if pad_length:
+ x = F.pad(x, (0, pad_length))
+ t40 = x.size(2) // 4
+ attn_mask = torch.ones((t40, t40), dtype=torch.bool, device=x.device).triu(
+ 1
+ )
+ else:
+ attn_mask = None
  for conv_block in self.convnext:
+ x = conv_block(x, attn_mask=attn_mask, kv=kv)
+ if self.use_mha and not self.cross_attention and pad_length:
+ x = x[:, :, :-pad_length]
  x = self.final_layer_norm(x.transpose(1, 2)).transpose(1, 2)
  return x

  if not self.use_weight_standardization:
  dump_layer(self.final_layer_norm, f)

+ def dump_kv(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
+ if isinstance(f, (str, bytes, os.PathLike)):
+ with open(f, "wb") as f:
+ self.dump_kv(f)
+ return
+ if not hasattr(f, "write"):
+ raise TypeError
+
+ assert self.use_mha and self.cross_attention
+ for conv_block in self.convnext:
+ if not conv_block.use_mha or not conv_block.cross_attention:
+ continue
+ assert isinstance(conv_block, ConvNeXtBlock)
+ assert hasattr(conv_block, "mha")
+ assert isinstance(conv_block.mha, CrossAttention)
+ conv_block.mha.dump_kv(f)
+

  class FeatureExtractor(nn.Module):
  def __init__(self, hidden_channels: int):


  class FeatureProjection(nn.Module):
+ def __init__(self, channels: int):
  super().__init__()
+ self.norm = nn.LayerNorm(channels)
  self.dropout = nn.Dropout(0.1)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
  # [batch_size, channels, length]
  x = self.norm(x.transpose(1, 2)).transpose(1, 2)
  x = self.dropout(x)
  return x


  class PhoneExtractor(nn.Module):
  def __init__(
  self,
+ phone_channels: int = 128,
+ hidden_channels: int = 128,
+ backbone_embed_kernel_size: int = 9,
  kernel_size: int = 17,
+ n_blocks: int = 20,
  ):
  super().__init__()
  self.feature_extractor = FeatureExtractor(hidden_channels)
+ self.feature_projection = FeatureProjection(hidden_channels)

  self.backbone = ConvNeXtStack(
  in_channels=hidden_channels,
  channels=hidden_channels,

  delay=0,
  embed_kernel_size=backbone_embed_kernel_size,
  kernel_size=kernel_size,
+ use_mha=True,
  )
  self.head = weight_norm(nn.Conv1d(hidden_channels, phone_channels, 1))

  stats["feature_norm"] = x.detach().norm(dim=1).mean()
  # [batch_size, feature_extractor_hidden_channels, length] -> [batch_size, hidden_channels, length]
  x = self.feature_projection(x)

  # [batch_size, hidden_channels, length]
+ x = self.backbone(x)
  # [batch_size, hidden_channels, length] -> [batch_size, phone_channels, length]
  phone = self.head(F.gelu(x, approximate="tanh"))

  results = [phone]
  if return_stats:
+ stats["code_norm"] = phone.detach().norm(dim=1).mean()
  results.append(stats)

  if len(results) == 1:

  def remove_weight_norm(self):
  self.feature_extractor.remove_weight_norm()
  remove_weight_norm(self.head)

  def merge_weights(self):
  self.backbone.merge_weights()

+ self.backbone.embed.bias.data += (
+ (
+ self.feature_projection.norm.bias.data[None, :, None]
+ * self.backbone.embed.weight.data # [o, i, k]
+ )
+ .sum(1)
+ .sum(1)
+ )
+ self.backbone.embed.weight.data *= self.feature_projection.norm.weight.data[
+ None, :, None
+ ]
+ self.feature_projection.norm.bias.data[:] = 0.0
+ self.feature_projection.norm.weight.data[:] = 1.0
+
  def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
  if isinstance(f, (str, bytes, os.PathLike)):
  with open(f, "wb") as f:

  raise TypeError

  dump_layer(self.feature_extractor, f)
  dump_layer(self.backbone, f)
  dump_layer(self.head, f)

+ class VectorQuantizer(nn.Module):
+ def __init__(
+ self,
+ n_speakers: int,
+ codebook_size: int,
+ channels: int,
+ topk: int = 4,
+ training_time_vq: Literal["none", "self", "random"] = "none",
+ ):
+ super().__init__()
+ assert 1 <= topk <= codebook_size
+ self.n_speakers = n_speakers
+ self.codebook_size = codebook_size
+ self.channels = channels
+ self.topk = topk
+ self.training_time_vq = training_time_vq
+
+ self.register_buffer(
+ "codebooks",
+ torch.empty(n_speakers, codebook_size, channels, dtype=torch.half),
+ )
+ self.codebooks: torch.Tensor
+
+ # implemented as a hook so that where VQ is applied can be changed easily
+ self._hook_handle: Optional[torch.utils.hooks.RemovableHandle] = None
+ self.target_speaker_ids: Optional[torch.Tensor] = None
+
+ def _hook(_, __, output):
+ return self(output, self.target_speaker_ids)
+
+ self._hook_fn = _hook
+
+ @torch.no_grad()
+ def build_codebooks(
+ self,
+ collector_func: Callable,
+ target_layer: nn.Module,
+ inputs: Sequence[Iterable[torch.Tensor]],
+ kmeans_n_iters: int = 50,
+ ):
+ assert len(inputs) == self.n_speakers
+ assert self._hook_handle is None, "hook already installed"
+ device = next(self.buffers()).device
+
+ for spk_id, inps in enumerate(tqdm(inputs, desc="Building codebooks")):
+ activations: list[torch.Tensor] = []
+
+ # TODO: subsample when there is too much data
+
+ def _collect(_, __, output):
+ # output: [batch_size, channels, length]
+ activations.append(output.detach())
+
+ handle = target_layer.register_forward_hook(_collect)
+ for x in inps:
+ collector_func(x.to(device))
+ handle.remove()
+
+ if not activations:
+ raise RuntimeError(f"No activation collected for speaker {spk_id}")
+
+ # [n_data, channels]
+ activations: torch.Tensor = torch.cat(
+ [
+ a.transpose(1, 2).reshape(a.size(0) * a.size(2), self.channels)
+ for a in activations
+ ]
+ )
+ activations = activations.float()
+ activations = F.normalize(activations, dim=1, eps=1e-6)
+ # [codebook_size, channels]
+ centers = (
+ self._kmeans_plus_plus(activations, self.codebook_size, kmeans_n_iters)
+ if activations.size(0) >= self.codebook_size
+ else self._pad_replicate(activations, self.codebook_size)
+ )
+ self.codebooks[spk_id] = centers.to(self.codebooks.dtype)
+
+ def forward(
+ self, x: torch.Tensor, speaker_ids: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ batch_size, channels, length = x.size()
+ assert channels == self.channels
+ device = x.device
+ dtype = x.dtype
+
+ if self.training:
+ if self.training_time_vq == "none":
+ return x
+ elif self.training_time_vq == "self":
+ if self.target_speaker_ids is None:
+ raise ValueError("target_speaker_ids is not set")
+ elif self.training_time_vq == "random":
+ speaker_ids = torch.randint(
+ 0, self.n_speakers, (batch_size,), device=device
+ )
+ else:
+ raise ValueError(f"Unknown training_time_vq: {self.training_time_vq}")
+ else:
+ if speaker_ids is None:
+ return x
+ speaker_ids = speaker_ids.to(device)
+
+ # [batch_size, channels, length]
+ q = F.normalize(x, dim=1, eps=1e-6)
+ codes = self.codebooks[speaker_ids].to(q.dtype)
+ # [batch_size, length, codebook_size]
+ sim = torch.einsum("bcl,bkc->blk", q, codes)
+
+ # [batch_size, length, topk]
+ _, topk_idx = sim.topk(self.topk, dim=-1)
+ # [batch_size, length, codebook_size, channels]
+ expanded_codes = codes[:, None, :, :].expand(-1, length, -1, -1)
+ # [batch_size, length, topk, channels]
+ expanded_topk_idx = topk_idx[:, :, :, None].expand(-1, -1, -1, channels)
+ # [batch_size, length, topk, channels]
+ gathered = expanded_codes.gather(2, expanded_topk_idx)
+ # [batch_size, length, channels]
+ gathered = gathered.mean(2)
+ # [batch_size, channels, length]
+ return gathered.transpose(1, 2).to(dtype)
+
+ def enable_hook(self, target_layer: nn.Module):
+ if self._hook_handle is not None:
+ raise RuntimeError("hook already installed")
+ self._hook_handle = target_layer.register_forward_hook(self._hook_fn)
+
+ def disable_hook(self):
+ if self._hook_handle is None:
+ raise RuntimeError("hook not installed")
+ self._hook_handle.remove()
+ self._hook_handle = None
+
+ def set_target_speaker_ids(self, speaker_ids: Optional[torch.Tensor]):
+ # see forward() for when these speaker ids are actually used
+ self.target_speaker_ids = speaker_ids
+
+ @staticmethod
+ def _pad_replicate(x: torch.Tensor, n: int) -> torch.Tensor:
+ # when there are fewer than n data points, pad by replication
+ idx = torch.arange(n, device=x.device) % x.size(0)
+ return x[idx]
+
+ @staticmethod
+ def _kmeans_plus_plus(
+ x: torch.Tensor, n_clusters: int, n_iters: int = 50
+ ) -> torch.Tensor:
+ n_data, _ = x.size()
+ center_indices = [torch.randint(0, n_data, ()).item()]
+ min_distances = torch.full((n_data,), math.inf, device=x.device)
+ for _ in range(1, n_clusters):
+ last_center_index = center_indices[-1]
+ min_distances = min_distances.minimum(
+ torch.cdist(x, x[last_center_index : last_center_index + 1])
+ .float()
+ .square_()
+ .squeeze_(1)
+ )
+ probs = min_distances / (min_distances.sum() + 1e-12)
+ center_indices.append(torch.multinomial(probs, 1).item())
+ centers = x[center_indices]
+ del min_distances, probs
+ for _ in range(n_iters):
+ distances = torch.cdist(x, centers) # [n_data, n_clusters]
+ labels = distances.argmin(1) # [n_data]
+ # [n_clusters, dim]
+ new_centers = torch.zeros_like(centers).index_add_(0, labels, x)
+ # [n_clusters]
+ counts = labels.bincount(minlength=n_clusters)
+ if (counts == 0).sum().item() != 0:
+ # TODO: handle clusters with no assigned data points
+ warnings.warn("Some clusters have no assigned data points.")
+ new_centers /= counts[:, None].clamp_(min=1).float()
+ centers = new_centers
+ return centers
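
A self-contained sketch of the top-k codebook lookup used in `forward` above (cosine similarity against a per-speaker codebook, then the mean of the k best entries); all sizes are illustrative:

```python
import torch
import torch.nn.functional as F

batch, channels, length, codebook_size, topk = 2, 8, 5, 16, 4
x = F.normalize(torch.randn(batch, channels, length), dim=1)
codes = F.normalize(torch.randn(batch, codebook_size, channels), dim=2)
sim = torch.einsum("bcl,bkc->blk", x, codes)   # [batch, length, codebook_size]
_, idx = sim.topk(topk, dim=-1)                # [batch, length, topk]
gathered = (
    codes[:, None].expand(-1, length, -1, -1)  # [batch, length, K, channels]
    .gather(2, idx[..., None].expand(-1, -1, -1, channels))
    .mean(2)                                   # [batch, length, channels]
)
assert gathered.shape == (batch, length, channels)
```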
+
+
  # %% [markdown]
  # ## Pitch Estimator

  )

  # autocorrelation
  # Originally this was going to be multiplied by 2.0 / corr_win_length before use,
  # but the value is proportional to the square of the amplitude, and no good way
  # was found to standardize its variance for input to the NN, so this was dropped.

  self,
  input_instfreq_channels: int = 192,
  input_corr_channels: int = 256,
+ pitch_bins: int = 448,
  channels: int = 192,
+ intermediate_channels: int = 192 * 2,
+ n_blocks: int = 9,
  delay: int = 1, # 10 ms; 22.5 ms when combined with feature extraction
  embed_kernel_size: int = 3,
  kernel_size: int = 33,
+ pitch_bins_per_octave: int = 96,
  ):
  super().__init__()
+ self.pitch_bins_per_octave = pitch_bins_per_octave

  self.instfreq_embed_0 = nn.Conv1d(input_instfreq_channels, channels, 1)
  self.instfreq_embed_1 = nn.Conv1d(channels, channels, 1)

  delay,
  embed_kernel_size,
  kernel_size,
+ enable_scaling=True,
  )
+ self.head = nn.Conv1d(channels, pitch_bins, 1)

  def forward(self, wav: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  # wav: [batch_size, 1, wav_length]

  corr_diff = F.gelu(self.corr_embed_0(corr_diff), approximate="tanh")
  corr_diff = self.corr_embed_1(corr_diff)
  # [batch_size, channels, length]
+ x = F.gelu(instfreq_features + corr_diff, approximate="tanh")
  x = self.backbone(x)
+ # [batch_size, pitch_bins, length]
  x = self.head(x)
  return x, energy

  def sample_pitch(
+ self, pitch: torch.Tensor, band_width: int = 4, return_features: bool = False
  ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+ # pitch: [batch_size, pitch_bins, length]
  # the returned pitch values never include 0
+ batch_size, pitch_bins, length = pitch.size()
  pitch = pitch.softmax(1)
  if return_features:
  unvoiced_proba = pitch[:, :1, :].clone()
  pitch[:, 0, :] = -100.0
  pitch = (
+ pitch.transpose(1, 2).contiguous().view(batch_size * length, 1, pitch_bins)
  )
  band_pitch = F.conv1d(
  pitch,
  torch.ones((1, 1, 1), device=pitch.device).expand(1, 1, band_width),
  )
+ # [batch_size * length, 1, pitch_bins - band_width + 1] -> Long[batch_size * length, 1]
  quantized_band_pitch = band_pitch.argmax(2)
  if return_features:
  # [batch_size * length, 1]

  # [batch_size * length, 1]
  half_pitch_band_proba = band_pitch.gather(
  2,
+ (quantized_band_pitch - self.pitch_bins_per_octave).clamp_(min=1)[
+ :, :, None
+ ],
  )
+ half_pitch_band_proba[
+ quantized_band_pitch <= self.pitch_bins_per_octave
+ ] = 0.0
  half_pitch_proba = (half_pitch_band_proba / (band_proba + 1e-6)).view(
  batch_size, 1, length
  )
  # [batch_size * length, 1]
  double_pitch_band_proba = band_pitch.gather(
  2,
+ (quantized_band_pitch + self.pitch_bins_per_octave).clamp_(
+ max=pitch_bins - band_width
  )[:, :, None],
  )
  double_pitch_band_proba[
  quantized_band_pitch
+ > pitch_bins - band_width - self.pitch_bins_per_octave
  ] = 0.0
  double_pitch_proba = (double_pitch_band_proba / (band_proba + 1e-6)).view(
  batch_size, 1, length
  )
+ # Long[1, pitch_bins]
+ mask = torch.arange(pitch_bins, device=pitch.device)[None, :]
+ # bool[batch_size * length, pitch_bins]
  mask = (quantized_band_pitch <= mask) & (
  mask < quantized_band_pitch + band_width
  )
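
For orientation: with `pitch_bins_per_octave = 96`, one bin is 1/8 of a semitone, and the half-/double-pitch checks above look exactly 96 bins away. The bin-to-frequency mapping used later in this file is `f = 55.0 * 2.0 ** (q / 96)`; a tiny worked check:

```python
bins_per_octave = 96
f = lambda q: 55.0 * 2.0 ** (q / bins_per_octave)
assert abs(f(96) - 110.0) < 1e-9                   # 96 bins = one octave above 55 Hz
assert abs(f(104) / f(96) - 2 ** (1 / 12)) < 1e-9  # 8 bins = one semitone
```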

  return noise, excitation # [batch_size, length * hop_length]


  D4C_PREVENT_ZERO_DIVISION = True # set to False for the original (upstream) behavior

  def __init__(
  self,
  channels: int,
+ speaker_embedding_channels: int = 128,
  hop_length: int = 240,
  n_pre_blocks: int = 4,
  out_sample_rate: float = 24000.0,

  self.prenet = ConvNeXtStack(
  in_channels=channels,
  channels=channels,
+ intermediate_channels=channels * 2,
  n_blocks=n_pre_blocks,
  delay=2, # 20 ms delay
  embed_kernel_size=7,
  kernel_size=33,
  enable_scaling=True,
+ use_mha=True,
+ cross_attention=True,
+ kv_channels=speaker_embedding_channels,
  )
  self.ir_generator = ConvNeXtStack(
  in_channels=channels,
  channels=channels,
+ intermediate_channels=channels * 2,
  n_blocks=2,
  delay=0,
  embed_kernel_size=3,

  self.aperiodicity_generator = ConvNeXtStack(
  in_channels=channels,
  channels=channels,
+ intermediate_channels=channels * 2,
  n_blocks=1,
  delay=0,
  embed_kernel_size=3,

  self.post_filter_generator = ConvNeXtStack(
  in_channels=channels,
  channels=channels,
+ intermediate_channels=channels * 2,
  n_blocks=1,
  delay=0,
  embed_kernel_size=3,

  self.register_buffer("post_filter_scale", torch.tensor(0.01))

  def forward(
+ self, x: torch.Tensor, pitch: torch.Tensor, speaker_embedding: torch.Tensor
  ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
  # x: [batch_size, channels, length]
  # pitch: [batch_size, length]
+ # speaker_embedding: [batch_size, speaker_embedding_length, speaker_embedding_channels]
  batch_size, _, length = x.size()

+ x = self.prenet(x, speaker_embedding)
  ir = self.ir_generator(x)
  ir = F.silu(ir, inplace=True)
  # [batch_size, 512, length]

  # [batch_size, 1, length * hop_length]
  y_g_hat = (periodic_signal + aperiodic_signal)[:, None, :]

  return y_g_hat, {
  "periodic_signal": periodic_signal.detach(),
  "aperiodic_signal": aperiodic_signal.detach(),

  phone_extractor: PhoneExtractor,
  pitch_estimator: PitchEstimator,
  n_speakers: int,
+ pitch_bins: int,
  hidden_channels: int,
+ vq_topk: int = 4,
+ training_time_vq: Literal["none", "self", "random"] = "none",
+ phone_noise_ratio: float = 0.5,
+ floor_noise_level: float = 1e-3,
  ):
  super().__init__()
  self.frozen_modules = {
  "phone_extractor": phone_extractor.eval().requires_grad_(False),
  "pitch_estimator": pitch_estimator.eval().requires_grad_(False),
  }
+ self.pitch_bins = pitch_bins
+ self.phone_noise_ratio = phone_noise_ratio
+ self.floor_noise_level = floor_noise_level
  self.out_sample_rate = out_sample_rate = 24000
+ phone_channels = 128
+ self.vq = VectorQuantizer(
+ n_speakers=n_speakers,
+ codebook_size=512,
+ channels=phone_channels,
+ topk=vq_topk,
+ training_time_vq=training_time_vq,
+ )
+ self.embed_phone = nn.Conv1d(phone_channels, hidden_channels, 1)
  self.embed_phone.weight.data.normal_(0.0, math.sqrt(2.0 / (256 * 5)))
  self.embed_phone.bias.data.zero_()
+ self.embed_quantized_pitch = nn.Embedding(pitch_bins, hidden_channels)
  phase = (
+ torch.arange(pitch_bins, dtype=torch.float)[:, None]
  * (
  torch.arange(0, hidden_channels, 2, dtype=torch.float)
  * (-math.log(10000.0) / hidden_channels)

  self.embed_speaker.weight.data.normal_(0.0, math.sqrt(2.0 / 5.0))
  self.embed_formant_shift = nn.Embedding(9, hidden_channels)
  self.embed_formant_shift.weight.data.normal_(0.0, math.sqrt(2.0 / 5.0))
+
+ self.key_value_speaker_embedding_length = 384
+ self.key_value_speaker_embedding_channels = 128
+ self.key_value_speaker_embedding = nn.Embedding(
+ n_speakers,
+ self.key_value_speaker_embedding_length
+ * self.key_value_speaker_embedding_channels,
+ )
+ self.key_value_speaker_embedding.weight.data[0].normal_()
+ self.key_value_speaker_embedding.weight.data[1:] = (
+ self.key_value_speaker_embedding.weight.data[0]
+ )
+
  self.vocoder = Vocoder(
  channels=hidden_channels,
+ speaker_embedding_channels=self.key_value_speaker_embedding_channels,
  hop_length=out_sample_rate // 100,
  n_pre_blocks=4,
  out_sample_rate=out_sample_rate,

  )
  )

+ def initialize_vq(self, inputs: Sequence[Iterable[torch.Tensor]]):
+ collector_func = self.frozen_modules["phone_extractor"].units
+ target_layer = self.frozen_modules["phone_extractor"].head
+
+ self.vq.build_codebooks(
+ collector_func,
+ target_layer,
+ inputs,
+ )
+ self.vq.enable_hook(target_layer)
+
+ def enable_hook(self):
+ target_layer = self.frozen_modules["phone_extractor"].head
+ self.vq.enable_hook(target_layer)
+
  def _get_resampler(
  self, orig_freq, new_freq, device, cache={}
  ) -> torchaudio.transforms.Resample:

  # slice_start_indices: [batch_size]

  batch_size, _, _ = x.size()
+ self.vq.set_target_speaker_ids(target_speaker_id)

  with torch.inference_mode():
  phone_extractor: PhoneExtractor = self.frozen_modules["phone_extractor"]
  pitch_estimator: PitchEstimator = self.frozen_modules["pitch_estimator"]
  # [batch_size, 1, wav_length] -> [batch_size, phone_channels, length]
  phone = phone_extractor.units(x).transpose(1, 2)
+
+ if self.training and self.phone_noise_ratio != 0.0:
+ phone *= (1.0 - self.phone_noise_ratio) / phone.square().mean(
+ 1, keepdim=True
+ ).sqrt_()
+ noise = torch.randn_like(phone)
+ noise *= (
+ self.phone_noise_ratio
+ / noise.square().mean(1, keepdim=True).sqrt_()
+ )
+ phone += noise
+ # F.rms_norm requires PyTorch >= 2.4
+ phone *= (
+ 1.0
+ / phone.square()
+ .mean(1, keepdim=True)
+ .add_(torch.finfo(torch.float).eps)
+ .sqrt_()
+ )
+
2321
+ # [batch_size, 1, wav_length] -> [batch_size, pitch_bins, length], [batch_size, 1, length]
2322
  pitch, energy = pitch_estimator(x)
2323
  # augmentation
2324
  if self.training:
2325
+ # [batch_size, pitch_bins - 1]
2326
  weights = pitch.softmax(1)[:, 1:, :].mean(2)
2327
  # [batch_size]
2328
  mean_pitch = (
2329
+ weights
2330
+ * torch.arange(
2331
+ 1,
2332
+ self.embed_quantized_pitch.num_embeddings,
2333
+ device=weights.device,
2334
+ )
2335
  ).sum(1) / weights.sum(1)
2336
  mean_pitch = mean_pitch.round_().long()
2337
  target_pitch = torch.randint_like(mean_pitch, 64, 257)
2338
  shift = target_pitch - mean_pitch
2339
  shift_ratio = (
2340
+ 2.0 ** (shift.float() / pitch_estimator.pitch_bins_per_octave)
2341
  ).tolist()
2342
  shift = []
2343
  interval_length = 100 # 1s
 
2357
  shift_ratio_i = shift_numer_i / shift_denom_i
2358
  shift_i = int(
2359
  round(
2360
+ math.log2(shift_ratio_i)
2361
+ * pitch_estimator.pitch_bins_per_octave
2362
  )
2363
  )
2364
  shift.append(shift_i)
 
2390
  # [batch_size, 1, sum(wav_length) + batch_size * 16000]
2391
  concatenated_shifted_x = torch.cat(concatenated_shifted_x, dim=2)
2392
  assert concatenated_shifted_x.size(2) % (256 * 160) == 0
2393
+ # [1, pitch_bins, length / shift_ratio], [1, 1, length / shift_ratio]
2394
  concatenated_pitch, concatenated_energy = pitch_estimator(
2395
  concatenated_shifted_x
2396
  )
 
2432
  energy[i : i + 1, :, :length] = energy_i[:, :, :length]
2433
  torch.backends.cudnn.benchmark = True
2434
 
2435
+ # [batch_size, pitch_bins, length] -> Long[batch_size, length], [batch_size, 3, length]
2436
  quantized_pitch, pitch_features = pitch_estimator.sample_pitch(
2437
  pitch, return_features=True
2438
  )
 
2444
  quantized_pitch
2445
  + (
2446
  pitch_shift_semitone[:, None]
2447
+ * (pitch_estimator.pitch_bins_per_octave / 12.0)
2448
  )
2449
  .round_()
2450
  .long()
2451
+ ).clamp_(1, self.pitch_bins - 1),
2452
  )
2453
  pitch = 55.0 * 2.0 ** (
2454
+ quantized_pitch.float() / pitch_estimator.pitch_bins_per_octave
2455
  )
2456
  # phone が 2.5ms 先読みしているのに対して、
2457
  # energy は 12.5ms, pitch_features は 22.5ms 先読みしているので、

  # [batch_size, hidden_channels, length] -> [batch_size, hidden_channels, segment_length]
  x = slice_segments(x, slice_start_indices, slice_segment_length)
  x = F.silu(x, inplace=True)
+
+ speaker_embedding = self.key_value_speaker_embedding(target_speaker_id).view(
+ batch_size,
+ self.key_value_speaker_embedding_length,
+ self.key_value_speaker_embedding_channels,
+ )
+
  # [batch_size, hidden_channels, segment_length] -> [batch_size, 1, segment_length * 240]
+ y_g_hat, stats = self.vocoder(x, pitch, speaker_embedding)
  stats["pitch"] = pitch
  if return_stats:
  return y_g_hat, stats

  return y_g_hat

  def _normalize_melsp(self, x):
+ return x.clamp(min=1e-10).log_()

  def forward_and_compute_loss(
  self,

  slice_segment_length: int,
  y_all: torch.Tensor,
  enable_loss_ap: bool = False,
+ ) -> tuple[
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ dict[str, float],
+ ]:
  # noisy_wavs_16k: [batch_size, 1, wav_length]
  # target_speaker_id: Long[batch_size]
  # formant_shift_semitone: [batch_size]

  stats = {}
  loss_mel = 0.0
+ loss_loudness = 0.0
+ loudness_win_lengths = [512, 1024, 2048, 4096]

  # [batch_size, 1, wav_length] -> [batch_size, 1, wav_length * 240]
  y_hat_all, intermediates = self(

  formant_shift_semitone,
  return_stats=True,
  )
+ y_hat_all = y_hat_all.detach().where(y_all == 0.0, y_hat_all)

  with torch.amp.autocast("cuda", enabled=False):
  periodic_signal = intermediates["periodic_signal"].float()

  periodic_signal = periodic_signal[:, : noise_excitation.size(1)]
  aperiodic_signal = aperiodic_signal[:, : noise_excitation.size(1)]
  y_hat_all = y_hat_all.float()
+ floor_noise = torch.randn_like(y_all) * self.floor_noise_level
+ y_all = y_all + floor_noise
+ y_hat_all += floor_noise
  y_hat_all_truncated = y_hat_all.squeeze(1)[:, : periodic_signal.size(1)]
  y_all_truncated = y_all.squeeze(1)[:, : periodic_signal.size(1)]

+ y_loudness = compute_loudness(
+ y_all_truncated, self.out_sample_rate, loudness_win_lengths
+ )
+ y_hat_loudness = compute_loudness(
+ y_hat_all_truncated, self.out_sample_rate, loudness_win_lengths
+ )
+ for win_length, y_loudness_i, y_hat_loudness_i in zip(
+ loudness_win_lengths, y_loudness, y_hat_loudness
+ ):
+ loss_loudness_i = F.mse_loss(y_hat_loudness_i, y_loudness_i)
+ loss_loudness += loss_loudness_i * math.sqrt(win_length)
+ stats[f"loss_loudness_{win_length}"] = loss_loudness_i.item()
+
  for melspectrogram in self.melspectrograms:
  melsp_periodic_signal = melspectrogram(periodic_signal)
  melsp_aperiodic_signal = melspectrogram(aperiodic_signal)

  t = (
  torch.arange(intermediates["pitch"].size(1), device=y_all.device)
  * 0.01
+ + 0.005
  )
  y_coarse_aperiodicity, y_rms = d4c(
  y_all.squeeze(1),

  loss_ap = F.mse_loss(
  y_hat_coarse_aperiodicity, y_coarse_aperiodicity, reduction="none"
  )
+ loss_ap *= (rms / (rms + 1e-3) * (rms > 1e-5))[:, :, None]
  loss_ap = loss_ap.mean()
  else:
  loss_ap = torch.tensor(0.0)

  )
  # [batch_size, 1, wav_length] -> [batch_size, 1, slice_segment_length * 240]
  y = slice_segments(y_all, slice_start_indices * 240, slice_segment_length * 240)
+ return y, y_hat, y_hat_all, loss_loudness, loss_mel, loss_ap, stats

  def merge_weights(self):
  self.vocoder.merge_weights()

  dump_layer(self.embed_pitch_features, f)
  dump_layer(self.vocoder, f)

+ def dump_speaker_embeddings(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
+ if isinstance(f, (str, bytes, os.PathLike)):
+ with open(f, "wb") as f:
+ self.dump_speaker_embeddings(f)
+ return
+ if not hasattr(f, "write"):
+ raise TypeError
+
+ dump_params(self.vq.codebooks, f)
+ dump_layer(self.embed_speaker, f)
+ dump_layer(self.embed_formant_shift, f)
+ dump_layer(self.key_value_speaker_embedding, f)
+
+ def dump_embedding_setter(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
+ if isinstance(f, (str, bytes, os.PathLike)):
+ with open(f, "wb") as f:
+ self.dump_embedding_setter(f)
+ return
+ if not hasattr(f, "write"):
+ raise TypeError
+
+ self.vocoder.prenet.dump_kv(f)
+

  # Discriminator

  t = t + n_pad
  x = x.view(b, c, t // self.period, self.period)

+ for conv in self.convs:
+ x = conv(x)
  x = F.silu(x, inplace=True)
  fmap.append(x)
  if self.san:

  fmap = []

  x = self._spectrogram(x).unsqueeze(1)
+ for conv in self.convs:
+ x = conv(x)
  x = F.silu(x, inplace=True)
  fmap.append(x)
  if self.san:

  # adversarial loss
  adv_loss = 0.0
  for dg, name in zip(y_d_gs, self.discriminator_names):
  if self.san:
+ dg_fun = dg[0].float()
+ g_loss = F.softplus(1.0 - dg_fun).square().mean()
  else:
+ dg = dg.float()
  g_loss = (1.0 - dg).square().mean()
  stats[f"{name}_gg_loss"] = g_loss.item()
  adv_loss += g_loss
  return res[..., : signal.size(-1)]


+ def random_formant_shift(
+ wav: torch.Tensor,
+ sample_rate: int,
+ formant_shift_semitone_min: float = -3.0,
+ formant_shift_semitone_max: float = 3.0,
+ ) -> torch.Tensor:
+ assert wav.ndim == 2
+ assert wav.size(0) == 1
+
+ device = wav.device
+
+ hop_length = 256
+
+ # [wav_length]
+ wav_np = wav.ravel().double().cpu().numpy()
+ f0, t = pyworld.dio(
+ wav_np,
+ sample_rate,
+ f0_floor=55,
+ f0_ceil=1400,
+ frame_period=hop_length * 1000 / sample_rate,
+ )
+ f0 = pyworld.stonemask(wav_np, f0, t, sample_rate)
+ world_sp = pyworld.cheaptrick(wav_np, f0, t, sample_rate)
+ world_sp = (
+ torch.from_numpy(world_sp).float().to(device).sqrt_()[None]
+ ) # [1, length, n_fft // 2 + 1]
+
+ n_fft = win_length = (world_sp.size(2) - 1) * 2
+
+ window = torch.hann_window(win_length, device=device)
+
+ # [1, n_fft // 2 + 1, length]
+ stft_sp = torch.stft(
+ wav,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ window=window,
+ return_complex=True,
+ )
+ assert world_sp.size(1) == stft_sp.size(2), (world_sp.size(), stft_sp.size())
+ assert world_sp.size(2) == stft_sp.size(1), (world_sp.size(), stft_sp.size())
+
+ shift_semitones = (
+ torch.rand(()).item()
+ * (formant_shift_semitone_max - formant_shift_semitone_min)
+ + formant_shift_semitone_min
+ )
+ shift_ratio = 2.0 ** (shift_semitones / 12.0)
+ shifted_world_sp = F.interpolate(
+ world_sp, scale_factor=shift_ratio, mode="linear", align_corners=True
+ )
+
+ if shifted_world_sp.size(2) > n_fft // 2 + 1:
+ shifted_world_sp = shifted_world_sp[:, :, : n_fft // 2 + 1]
+ elif shifted_world_sp.size(2) < n_fft // 2 + 1:
+ shifted_world_sp = F.pad(
+ shifted_world_sp, (0, n_fft // 2 + 1 - shifted_world_sp.size(2))
+ )
+
+ ratio = ((shifted_world_sp + 1e-5) / (world_sp + 1e-5)).clamp(0.1, 10.0)
+ stft_sp *= ratio.transpose(-2, -1) # [1, n_fft // 2 + 1, length]
+
+ out = torch.istft(
+ stft_sp,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ window=window,
+ length=wav.size(-1),
+ )
+
+ return out
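
The core of the function above: a WORLD (CheapTrick) spectral envelope is stretched along the frequency axis by `2 ** (semitones / 12)` and the resulting ratio is applied to the STFT. The stretch itself is just 1-D linear interpolation; a standalone sketch with illustrative sizes:

```python
import torch
import torch.nn.functional as F

env = torch.randn(1, 10, 513)        # [1, frames, n_fft // 2 + 1]
ratio = 2.0 ** (3.0 / 12.0)          # +3 semitones, roughly a 1.19x stretch
stretched = F.interpolate(env, scale_factor=ratio, mode="linear", align_corners=True)
print(stretched.shape)               # torch.Size([1, 10, 610])
```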
+
+
  def random_filter(audio: torch.Tensor) -> torch.Tensor:
  assert audio.ndim == 2
  ab = torch.rand(audio.size(0), 6) * 0.75 - 0.375


  def get_butterworth_lpf(
+ cutoff_freq: float, sample_rate: int, cache={}
  ) -> tuple[torch.Tensor, torch.Tensor]:
  if (cutoff_freq, sample_rate) not in cache:
  q = math.sqrt(0.5)

  b0 = b1 * 0.5
  a1 = -2.0 * cos_omega / (1.0 + alpha)
  a2 = (1.0 - alpha) / (1.0 + alpha)
+ cache[(cutoff_freq, sample_rate)] = (
+ torch.tensor([b0, b1, b0]),
+ torch.tensor([1.0, a1, a2]),
  )
  return cache[(cutoff_freq, sample_rate)]

  sample_rate: int,
  noise_files: list[Union[str, bytes, os.PathLike]],
  ir_files: list[Union[str, bytes, os.PathLike]],
+ snr_candidates: list[float] = [20.0, 25.0, 30.0, 35.0, 40.0, 45.0],
+ formant_shift_probability: float = 0.5,
+ formant_shift_semitone_min: float = -3.0,
+ formant_shift_semitone_max: float = 3.0,
+ reverb_probability: float = 0.5,
+ lpf_probability: float = 0.2,
+ lpf_cutoff_freq_candidates: list[float] = [2000.0, 3000.0, 4000.0, 6000.0],
  ) -> torch.Tensor:
  # [1, wav_length]
  assert clean.size(0) == 1
  n_samples = clean.size(1)

  original_clean_rms = clean.square().mean().sqrt_()

+ # formant-shift clean
+ if torch.rand(()) < formant_shift_probability:
+ clean = random_formant_shift(
+ clean, sample_rate, formant_shift_semitone_min, formant_shift_semitone_max
+ )
+
  # get noise and concatenate it with clean
  noise = get_noise(n_samples, sample_rate, noise_files)
  signals = torch.cat([clean, noise])

  signals = random_filter(signals)

  # apply reverb to clean and noise
+ if torch.rand(()) < reverb_probability:
  ir_file = ir_files[torch.randint(0, len(ir_files), ())]
  ir, sr = torchaudio.load(ir_file, backend="soundfile")
  assert ir.size() == (2, sr), ir.size()

  signals = convolve(signals, ir)

  # apply the same LPF to clean and noise
+ if torch.rand(()) < lpf_probability:
  if signals.abs().max() > 0.8:
  signals /= signals.abs().max() * 1.25
+ cutoff_freq = lpf_cutoff_freq_candidates[
+ torch.randint(0, len(lpf_cutoff_freq_candidates), ())
  ]
  b, a = get_butterworth_lpf(cutoff_freq, sample_rate)
  signals = torchaudio.functional.lfilter(signals, a, b, clamp=False)

  clean_rms = clean.square().mean().sqrt_()
  clean *= original_clean_rms / clean_rms

+ if len(snr_candidates) >= 1:
+ # measure the levels of clean and noise with emphasis on peaks
+ clean_level = clean.square().square_().mean().sqrt_().sqrt_()
+ noise_level = noise.square().square_().mean().sqrt_().sqrt_()
+ # SNR
+ snr = snr_candidates[torch.randint(0, len(snr_candidates), ())]
+ # generate noisy
+ noisy = clean + noise * (
+ 0.1 ** (snr / 20.0) * clean_level / (noise_level + 1e-5)
+ )
+
  return noisy
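
Note the scaling rule: `0.1 ** (snr / 20.0)` equals `10 ** (-snr / 20.0)`, so the (peak-weighted) noise level ends up `snr` dB below the clean level:

```python
snr = 20.0
scale = 0.1 ** (snr / 20.0)
assert abs(scale - 10.0 ** (-snr / 20.0)) < 1e-12  # == 0.1 for 20 dB
```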


  segment_length: int = 100, # 1s
  noise_files: Optional[list[Union[str, bytes, os.PathLike]]] = None,
  ir_files: Optional[list[Union[str, bytes, os.PathLike]]] = None,
+ augmentation_snr_candidates: list[float] = [20.0, 25.0, 30.0, 35.0, 40.0, 45.0],
+ augmentation_formant_shift_probability: float = 0.5,
+ augmentation_formant_shift_semitone_min: float = -3.0,
+ augmentation_formant_shift_semitone_max: float = 3.0,
+ augmentation_reverb_probability: float = 0.5,
+ augmentation_lpf_probability: float = 0.2,
+ augmentation_lpf_cutoff_freq_candidates: list[float] = [
+ 2000.0,
+ 3000.0,
+ 4000.0,
+ 6000.0,
+ ],
  ):
  self.audio_files = audio_files
  self.in_sample_rate = in_sample_rate

  self.segment_length = segment_length
  self.noise_files = noise_files
  self.ir_files = ir_files
+ self.augmentation_snr_candidates = augmentation_snr_candidates
+ self.augmentation_formant_shift_probability = (
+ augmentation_formant_shift_probability
+ )
+ self.augmentation_formant_shift_semitone_min = (
+ augmentation_formant_shift_semitone_min
+ )
+ self.augmentation_formant_shift_semitone_max = (
+ augmentation_formant_shift_semitone_max
+ )
+ self.augmentation_reverb_probability = augmentation_reverb_probability
+ self.augmentation_lpf_probability = augmentation_lpf_probability
+ self.augmentation_lpf_cutoff_freq_candidates = (
+ augmentation_lpf_cutoff_freq_candidates
+ )

  if (noise_files is None) is not (ir_files is None):
  raise ValueError("noise_files and ir_files must be both None or not None")

  clean_wav
  )
  noisy_wav_16k = augment_audio(
+ clean_wav_16k,
+ self.in_sample_rate,
+ self.noise_files,
+ self.ir_files,
+ self.augmentation_snr_candidates,
+ self.augmentation_formant_shift_probability,
+ self.augmentation_formant_shift_semitone_min,
+ self.augmentation_formant_shift_semitone_max,
+ self.augmentation_reverb_probability,
+ self.augmentation_lpf_probability,
+ self.augmentation_lpf_cutoff_freq_candidates,
  )

  clean_wav = clean_wav.squeeze_(0)

  }


+ def get_compressed_optimizer_state_dict(
+ optimizer: torch.optim.Optimizer,
+ ) -> dict:
+ state_dict = {}
+ for k0, v0 in optimizer.state_dict().items():
+ if k0 != "state":
+ state_dict[k0] = v0
+ continue
+ state_dict[k0] = {}
+ for k1, v1 in v0.items():
+ state_dict[k0][k1] = {}
+ for k2, v2 in v1.items():
+ if isinstance(v2, torch.Tensor):
+ state_dict[k0][k1][k2] = v2.bfloat16()
+ assert state_dict[k0][k1][k2].isfinite().all()
+ else:
+ state_dict[k0][k1][k2] = v2
+ return state_dict
+
+
+ def get_decompressed_optimizer_state_dict(compressed_state_dict: dict) -> dict:
+ state_dict = {}
+ for k0, v0 in compressed_state_dict.items():
+ if k0 != "state":
+ state_dict[k0] = v0
+ continue
+ state_dict[k0] = {}
+ for k1, v1 in v0.items():
+ state_dict[k0][k1] = {}
+ for k2, v2 in v1.items():
+ if isinstance(v2, torch.Tensor):
+ state_dict[k0][k1][k2] = v2.float()
+ assert state_dict[k0][k1][k2].isfinite().all()
+ else:
+ state_dict[k0][k1][k2] = v2
+ return state_dict
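
Usage sketch for the pair above: compress the optimizer state to bfloat16 before saving, then decompress after loading (`optim_g` here stands for any `torch.optim.Optimizer`; the file name is arbitrary):

```python
compressed = get_compressed_optimizer_state_dict(optim_g)
torch.save({"optim_g": compressed}, "checkpoint.pt")
checkpoint = torch.load("checkpoint.pt", map_location="cpu", weights_only=True)
optim_g.load_state_dict(get_decompressed_optimizer_state_dict(checkpoint["optim_g"]))
```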
+
+
  def prepare_training():
  # do the various preparation steps
  # as a side effect, this creates the output directory, TensorBoard log files, etc.

  if not in_wav_dataset_dir.is_dir():
  raise ValueError(f"{in_wav_dataset_dir} is not found.")
  if resume:
+ latest_checkpoint_file = out_dir / "checkpoint_latest.pt.gz"
  if not latest_checkpoint_file.is_file():
  raise ValueError(f"{latest_checkpoint_file} is not found.")
  else:
  if out_dir.is_dir():
+ if (out_dir / "checkpoint_latest.pt.gz").is_file():
  raise ValueError(
+ f"{out_dir / 'checkpoint_latest.pt.gz'} already exists. "
  "Please specify a different output directory, or use --resume option."
  )
  for file in out_dir.iterdir():
+ if file.name.endswith(".pt.gz"):
  raise ValueError(
  f"{out_dir} already contains model files. "
  "Please specify a different output directory."

  segment_length=h.segment_length,
  noise_files=noise_files,
  ir_files=ir_files,
+ augmentation_snr_candidates=h.augmentation_snr_candidates,
+ augmentation_formant_shift_probability=h.augmentation_formant_shift_probability,
+ augmentation_formant_shift_semitone_min=h.augmentation_formant_shift_semitone_min,
+ augmentation_formant_shift_semitone_max=h.augmentation_formant_shift_semitone_max,
+ augmentation_reverb_probability=h.augmentation_reverb_probability,
+ augmentation_lpf_probability=h.augmentation_lpf_probability,
+ augmentation_lpf_cutoff_freq_candidates=h.augmentation_lpf_cutoff_freq_candidates,
  )
  training_loader = torch.utils.data.DataLoader(
  training_dataset,

  print("Computing pitch shifts for test files...")
  test_pitch_shifts = []
  source_f0s = []
+ for i, (file, target_ids) in enumerate(
+ tqdm(test_filelist, desc="Computing pitch shifts")
+ ):
  source_f0 = compute_mean_f0([file], method="harvest")
  source_f0s.append(source_f0)
  if math.isnan(source_f0):

  repo_root() / h.phone_extractor_file, map_location="cpu", weights_only=True
  )
  print(
+ phone_extractor.load_state_dict(
+ phone_extractor_checkpoint["phone_extractor"], strict=False
+ )
  )
  del phone_extractor_checkpoint

  phone_extractor,
  pitch_estimator,
  n_speakers,
+ h.pitch_bins,
  h.hidden_channels,
+ h.vq_topk,
+ h.training_time_vq,
+ h.phone_noise_ratio,
+ h.floor_noise_level,
  ).to(device)
  net_d = MultiPeriodDiscriminator(san=h.san).to(device)

  grad_scaler = torch.amp.GradScaler("cuda", enabled=h.use_amp)
  grad_balancer = GradBalancer(
  weights={
+ "loss_loudness": h.grad_weight_loudness,
  "loss_mel": h.grad_weight_mel,
  "loss_adv": h.grad_weight_adv,
  "loss_fm": h.grad_weight_fm,

  # load checkpoint

  initial_iteration = 0
+ if resume: # resume training
  checkpoint_file = latest_checkpoint_file
+ elif h.pretrained_file is not None: # fine-tuning
  checkpoint_file = repo_root() / h.pretrained_file
+ else: # pre-training
  checkpoint_file = None
+
  if checkpoint_file is not None:
+ with gzip.open(checkpoint_file, "rb") as f:
+ checkpoint = torch.load(f, map_location="cpu", weights_only=True)
  if not resume and not skip_training: # fine-tuning
+ initial_speaker_embedding = checkpoint["net_g"]["embed_speaker.weight"][:1]
+ initial_speaker_embedding_for_cross_attention = checkpoint["net_g"][
+ "key_value_speaker_embedding.weight"
+ ][:1]
+ checkpoint["net_g"]["embed_speaker.weight"] = initial_speaker_embedding[
+ [0] * n_speakers
+ ]
+ checkpoint["net_g"]["key_value_speaker_embedding.weight"] = (
+ initial_speaker_embedding_for_cross_attention[[0] * n_speakers]
+ )
+ checkpoint["net_g"]["vq.codebooks"] = checkpoint["net_g"]["vq.codebooks"][
+ [0] * n_speakers
+ ]
  print(net_g.load_state_dict(checkpoint["net_g"], strict=False))
  print(net_d.load_state_dict(checkpoint["net_d"], strict=False))
  if resume or skip_training:
+ optim_g.load_state_dict(
+ get_decompressed_optimizer_state_dict(checkpoint["optim_g"])
+ )
+ optim_d.load_state_dict(
+ get_decompressed_optimizer_state_dict(checkpoint["optim_d"])
+ )
  initial_iteration = checkpoint["iteration"]
  grad_balancer.load_state_dict(checkpoint["grad_balancer"])
  grad_scaler.load_state_dict(checkpoint["grad_scaler"])

+ def wav_iterator(files):
+ for file in files:
+ wav, sr = torchaudio.load(file, backend="soundfile")
+ wav = wav.to(device)
+ if sr != h.in_sample_rate:
+ wav = get_resampler(sr, h.in_sample_rate, device)(wav)
+ yield wav[:, None, :]
+
+ if resume:
+ net_g.enable_hook()
+ else:
+ net_g.initialize_vq([wav_iterator(files) for files in speaker_audio_files])
+
  # scheduler

+ def get_exponential_warmup_scheduler(
  optimizer: torch.optim.Optimizer,
  warmup_epochs: int,
+ decay: float,
  ) -> torch.optim.lr_scheduler.LambdaLR:

  def lr_lambda(current_epoch: int) -> float:
  if current_epoch < warmup_epochs:
  return current_epoch / warmup_epochs
  else:
+ return decay ** (current_epoch - warmup_epochs)

  return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
3965
 
3966
+ scheduler_g = get_exponential_warmup_scheduler(
3967
+ optim_g, h.warmup_steps, h.learning_rate_decay
3968
  )
3969
+ scheduler_d = get_exponential_warmup_scheduler(
3970
+ optim_d, h.warmup_steps, h.learning_rate_decay
3971
  )
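The helper above produces a linear warmup to the base learning rate followed by per-step exponential decay. A small self-contained illustration of the resulting schedule, with toy values for `warmup_steps` and `decay`:

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=1.0)
warmup_steps, decay = 4, 0.5

def lr_lambda(step: int) -> float:
    # Linear ramp for the first warmup_steps, then exponential decay.
    return step / warmup_steps if step < warmup_steps else decay ** (step - warmup_steps)

sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
lrs = []
for _ in range(8):
    lrs.append(sched.get_last_lr()[0])
    opt.step()
    sched.step()
print(lrs)  # [0.0, 0.25, 0.5, 0.75, 1.0, 0.5, 0.25, 0.125]
```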
3972
  with warnings.catch_warnings():
3973
  warnings.filterwarnings(
 
3989
  writer = None
3990
  else:
3991
  writer = SummaryWriter(out_dir)
3992
+ if not h.record_metrics:
3993
+ writer.add_scalar = lambda *args, **kwargs: None
3994
+ writer.add_histogram = lambda *args, **kwargs: None
3995
  writer.add_text(
3996
  "log",
3997
  f"start training w/ {torch.cuda.get_device_name(device) if torch.cuda.is_available() else 'cpu'}.",
 
4085
  if h.profile
4086
  else nullcontext()
4087
  ) as profiler:
4088
+ for iteration in tqdm(range(initial_iteration, h.n_steps), desc="Training"):
 
4089
  # === 1. Data preprocessing ===
4090
  try:
4091
  batch = next(data_iter)
4092
+ except (NameError, StopIteration):
4093
  data_iter = iter(training_loader)
4094
  batch = next(data_iter)
4095
  (
 
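Catching `NameError` covers the very first iteration (when `data_iter` does not exist yet) and `StopIteration` restarts the loader once an epoch is exhausted, so the loop is driven by a fixed step count rather than by epochs. The same effect can be had with a tiny generator, sketched here purely for illustration:

```python
def infinite_batches(loader):
    # Endless batch stream: fall through to a fresh DataLoader iterator
    # whenever the previous one is exhausted.
    while True:
        yield from loader
```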
4105
  # === 2.1 Generator forward pass ===
4106
  if h.compile_convnext:
4107
  ConvNeXtStack.forward = compiled_convnextstack_forward
4108
+ (
4109
+ y,
4110
+ y_hat,
4111
+ y_hat_for_backward,
4112
+ loss_loudness,
4113
+ loss_mel,
4114
+ loss_ap,
4115
+ generator_stats,
4116
+ ) = net_g.forward_and_compute_loss(
4117
+ noisy_wavs_16k[:, None, :],
4118
+ speaker_ids,
4119
+ formant_shift_semitone,
4120
+ slice_start_indices=slice_starts,
4121
+ slice_segment_length=h.segment_length,
4122
+ y_all=clean_wavs[:, None, :],
4123
+ enable_loss_ap=h.grad_weight_ap != 0.0,
4124
  )
4125
  if h.compile_convnext:
4126
  ConvNeXtStack.forward = raw_convnextstack_forward
4127
  assert y_hat.isfinite().all()
4128
+ assert loss_loudness.isfinite().all()
4129
  assert loss_mel.isfinite().all()
4130
  assert loss_ap.isfinite().all()
4131
 
 
4156
  assert param.grad is None
4157
  gradient_balancer_stats = grad_balancer.backward(
4158
  {
4159
+ "loss_loudness": loss_loudness,
4160
  "loss_mel": loss_mel,
4161
  "loss_adv": loss_adv,
4162
  "loss_fm": loss_fm,
 
4166
  grad_scaler,
4167
  skip_update_ema=iteration > 10 and iteration % 5 != 0,
4168
  )
4169
+ loss_loudness = loss_loudness.item()
4170
  loss_mel = loss_mel.item()
4171
  loss_adv = loss_adv.item()
4172
  loss_fm = loss_fm.item()
 
4187
  grad_scaler.update()
4188
 
4189
  # === 3. Logging ===
4190
+ dict_scalars["loss_g/loss_loudness"].append(loss_loudness)
4191
  dict_scalars["loss_g/loss_mel"].append(loss_mel)
4192
  if h.grad_weight_ap:
4193
  dict_scalars["loss_g/loss_ap"].append(loss_ap)
 
4296
  )
4297
 
4298
  # === 4. Validation ===
4299
+ if (iteration + 1) % h.evaluation_interval == 0 or iteration + 1 in {
 
 
4300
  1,
 
4301
  h.n_steps,
4302
  }:
4303
  torch.backends.cudnn.benchmark = False
 
4394
  torch.cuda.empty_cache()
4395
 
4396
  # === 5. Saving ===
4397
+ if (iteration + 1) % h.save_interval == 0 or iteration + 1 in {
 
 
4398
  1,
 
4399
  h.n_steps,
4400
  }:
4401
  # Checkpoint
4402
  name = f"{in_wav_dataset_dir.name}_{iteration + 1:08d}"
4403
+ checkpoint_file_save = out_dir / f"checkpoint_{name}.pt.gz"
4404
  if checkpoint_file_save.exists():
4405
  checkpoint_file_save = checkpoint_file_save.with_name(
4406
  f"{checkpoint_file_save.name}_{hash(None):x}"
4407
  )
4408
+ with gzip.open(checkpoint_file_save, "wb") as f:
4409
+ torch.save(
4410
+ {
4411
+ "iteration": iteration + 1,
4412
+ "net_g": net_g.state_dict(),
4413
+ "phone_extractor": phone_extractor.state_dict(),
4414
+ "pitch_estimator": pitch_estimator.state_dict(),
4415
+ "net_d": {
4416
+ k: v.half() for k, v in net_d.state_dict().items()
4417
+ },
4418
+ "optim_g": get_compressed_optimizer_state_dict(optim_g),
4419
+ "optim_d": get_compressed_optimizer_state_dict(optim_d),
4420
+ "grad_balancer": grad_balancer.state_dict(),
4421
+ "grad_scaler": grad_scaler.state_dict(),
4422
+ "h": dict(h),
4423
+ },
4424
+ f,
4425
+ )
4426
+ shutil.copy(checkpoint_file_save, out_dir / "checkpoint_latest.pt.gz")
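Checkpoints are now written through `gzip.open`, with the discriminator cast to fp16 and the optimizer states compressed, which keeps the `*.pt.gz` files noticeably smaller than plain `torch.save` output. A minimal round-trip of the gzip part (file name and contents are illustrative only):

```python
import gzip

import torch

state = {"iteration": 1, "weights": torch.zeros(8)}

# torch.save / torch.load accept any file-like object, so wrapping them in
# gzip trades a little CPU time for a smaller checkpoint on disk.
with gzip.open("checkpoint_latest.pt.gz", "wb") as f:
    torch.save(state, f)
with gzip.open("checkpoint_latest.pt.gz", "rb") as f:
    restored = torch.load(f, map_location="cpu", weights_only=True)
```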
4427
 
4428
  # For inference
4429
  paraphernalia_dir = out_dir / f"paraphernalia_{name}"
 
4437
  phone_extractor_fp16.remove_weight_norm()
4438
  phone_extractor_fp16.merge_weights()
4439
  phone_extractor_fp16.half()
4440
+ phone_extractor_fp16.dump(paraphernalia_dir / "phone_extractor.bin")
4441
  del phone_extractor_fp16
4442
  pitch_estimator_fp16 = PitchEstimator()
4443
  pitch_estimator_fp16.load_state_dict(pitch_estimator.state_dict())
4444
  pitch_estimator_fp16.merge_weights()
4445
  pitch_estimator_fp16.half()
4446
+ pitch_estimator_fp16.dump(paraphernalia_dir / "pitch_estimator.bin")
4447
  del pitch_estimator_fp16
4448
  net_g_fp16 = ConverterNetwork(
4449
+ nn.Module(),
4450
+ nn.Module(),
4451
+ len(speakers),
4452
+ h.pitch_bins,
4453
+ h.hidden_channels,
4454
+ h.vq_topk,
4455
+ h.training_time_vq,
4456
+ h.phone_noise_ratio,
4457
+ h.floor_noise_level,
4458
  )
4459
  net_g_fp16.load_state_dict(net_g.state_dict())
4460
  net_g_fp16.merge_weights()
4461
  net_g_fp16.half()
4462
+ net_g_fp16.dump(paraphernalia_dir / "waveform_generator.bin")
4463
+ net_g_fp16.dump_speaker_embeddings(
4464
+ paraphernalia_dir / "speaker_embeddings.bin"
4465
+ )
4466
+ net_g_fp16.dump_embedding_setter(
4467
+ paraphernalia_dir / "embedding_setter.bin"
4468
+ )
4469
  del net_g_fp16
4470
  shutil.copy(
4471
  repo_root() / "assets/images/noimage.png", paraphernalia_dir
pyproject.toml CHANGED
@@ -1,34 +1,95 @@
1
- [tool.poetry]
2
  name = "beatrice-trainer"
3
- version = "2.0.0b2"
4
  description = "A tool to train Beatrice models"
5
- license = "MIT"
6
- authors = ["Project Beatrice <167534685+prj-beatrice@users.noreply.github.com>"]
 
 
7
  readme = "README.md"
8
- homepage = "https://prj-beatrice.com/"
9
- repository = "https://huggingface.co/fierce-cats/beatrice-trainer"
 
10
 
11
- [tool.poetry.dependencies]
12
- python = ">=3.9"
13
  torch = [
14
- { version = ">=2.1", markers = "sys_platform == 'win32'", source = "torch-cuda" },
15
- { version = ">=2.1", markers = "sys_platform != 'win32'" },
 
 
16
  ]
17
  torchaudio = [
18
- { version = ">=2.1", markers = "sys_platform == 'win32'", source = "torch-cuda" },
19
- { version = ">=2.1", markers = "sys_platform != 'win32'" },
 
 
20
  ]
21
- tqdm = ">=4"
22
- numpy = "^1"
23
- tensorboard = ">=2"
24
- soundfile = ">=0.11"
25
- pyworld = ">=0.3.2"
26
 
27
- [[tool.poetry.source]]
28
- name = "torch-cuda"
29
- url = "https://download.pytorch.org/whl/cu121"
30
- priority = "explicit"
31
 
32
  [build-system]
33
- requires = ["poetry-core"]
34
- build-backend = "poetry.core.masonry.api"
 
1
+ [project]
2
  name = "beatrice-trainer"
3
+ version = "2.0.0rc0"
4
  description = "A tool to train Beatrice models"
5
+ authors = [
6
+ { name = "Project Beatrice", email = "167534685+prj-beatrice@users.noreply.github.com" },
7
+ ]
8
+ requires-python = ">=3.9"
9
  readme = "README.md"
10
+ license = "MIT"
11
+ dependencies = [
12
+ "torch>=2.1",
13
+ "torchaudio>=2.1,<2.9",
14
+ "tqdm>=4",
15
+ "numpy>=1",
16
+ "tensorboard>=2",
17
+ "soundfile>=0.11",
18
+ "pyworld>=0.3.2",
19
+ ]
20
+
21
+ [project.optional-dependencies]
22
+ cpu = ["torch>=2.1", "torchaudio>=2.1,<2.9"]
23
+ cu118 = ["torch>=2.1", "torchaudio>=2.1,<2.9"]
24
+ cu126 = ["torch>=2.1", "torchaudio>=2.1,<2.9"]
25
+ cu128 = ["torch>=2.1", "torchaudio>=2.1,<2.9"]
26
+
27
+ [project.urls]
28
+ Homepage = "https://prj-beatrice.com/"
29
+ Repository = "https://huggingface.co/fierce-cats/beatrice-trainer"
30
 
31
+ [tool.uv]
32
+ conflicts = [
33
+ [
34
+ { extra = "cpu" },
35
+ { extra = "cu118" },
36
+ ],
37
+ [
38
+ { extra = "cpu" },
39
+ { extra = "cu126" },
40
+ ],
41
+ [
42
+ { extra = "cpu" },
43
+ { extra = "cu128" },
44
+ ],
45
+ [
46
+ { extra = "cu118" },
47
+ { extra = "cu126" },
48
+ ],
49
+ [
50
+ { extra = "cu118" },
51
+ { extra = "cu128" },
52
+ ],
53
+ [
54
+ { extra = "cu126" },
55
+ { extra = "cu128" },
56
+ ],
57
+ ]
58
+
59
+ [tool.uv.sources]
60
  torch = [
61
+ { index = "pytorch-cpu", extra = "cpu" },
62
+ { index = "pytorch-cu118", extra = "cu118" },
63
+ { index = "pytorch-cu126", extra = "cu126" },
64
+ { index = "pytorch-cu128", extra = "cu128" },
65
  ]
66
  torchaudio = [
67
+ { index = "pytorch-cpu", extra = "cpu" },
68
+ { index = "pytorch-cu118", extra = "cu118" },
69
+ { index = "pytorch-cu126", extra = "cu126" },
70
+ { index = "pytorch-cu128", extra = "cu128" },
71
  ]
 
 
 
 
 
72
 
73
+ [[tool.uv.index]]
74
+ name = "pytorch-cpu"
75
+ url = "https://download.pytorch.org/whl/cpu"
76
+ explicit = true
77
+
78
+ [[tool.uv.index]]
79
+ name = "pytorch-cu118"
80
+ url = "https://download.pytorch.org/whl/cu118"
81
+ explicit = true
82
+
83
+ [[tool.uv.index]]
84
+ name = "pytorch-cu126"
85
+ url = "https://download.pytorch.org/whl/cu126"
86
+ explicit = true
87
+
88
+ [[tool.uv.index]]
89
+ name = "pytorch-cu128"
90
+ url = "https://download.pytorch.org/whl/cu128"
91
+ explicit = true
92
 
93
  [build-system]
94
+ requires = ["hatchling"]
95
+ build-backend = "hatchling.build"
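With the `[project.optional-dependencies]` and `[tool.uv]` tables above, the PyTorch build is selected at install time: for example, `uv sync --extra cu128` should resolve torch/torchaudio from the `pytorch-cu128` index, while `uv sync --extra cpu` installs the CPU-only wheels; the `conflicts` table prevents two of these extras from being enabled at once.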