Убраны комментарии и отформатирован код
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- MVSepLess_Epsilon_Colab.ipynb +16 -6
- mvsepless/__init__.py +0 -0
- mvsepless/__main__.py +61 -14
- mvsepless/audio.py +789 -781
- mvsepless/downloader.py +90 -92
- mvsepless/ensemble.py +206 -224
- mvsepless/infer.py +116 -65
- mvsepless/infer_utils.py +41 -69
- mvsepless/model_manager.py +682 -609
- mvsepless/models.json +0 -0
- mvsepless/models/bandit/core/__init__.py +669 -691
- mvsepless/models/bandit/core/data/__init__.py +2 -2
- mvsepless/models/bandit/core/data/_types.py +17 -17
- mvsepless/models/bandit/core/data/augmentation.py +102 -102
- mvsepless/models/bandit/core/data/augmented.py +34 -34
- mvsepless/models/bandit/core/data/base.py +60 -60
- mvsepless/models/bandit/core/data/dnr/datamodule.py +64 -68
- mvsepless/models/bandit/core/data/dnr/dataset.py +360 -366
- mvsepless/models/bandit/core/data/dnr/preprocess.py +51 -51
- mvsepless/models/bandit/core/data/musdb/datamodule.py +75 -75
- mvsepless/models/bandit/core/data/musdb/dataset.py +241 -273
- mvsepless/models/bandit/core/data/musdb/preprocess.py +223 -226
- mvsepless/models/bandit/core/data/musdb/validation.yaml +14 -14
- mvsepless/models/bandit/core/loss/__init__.py +8 -8
- mvsepless/models/bandit/core/loss/_complex.py +27 -27
- mvsepless/models/bandit/core/loss/_multistem.py +43 -43
- mvsepless/models/bandit/core/loss/_timefreq.py +94 -95
- mvsepless/models/bandit/core/loss/snr.py +131 -139
- mvsepless/models/bandit/core/metrics/__init__.py +7 -9
- mvsepless/models/bandit/core/metrics/_squim.py +350 -443
- mvsepless/models/bandit/core/metrics/snr.py +124 -127
- mvsepless/models/bandit/core/model/__init__.py +3 -3
- mvsepless/models/bandit/core/model/_spectral.py +54 -54
- mvsepless/models/bandit/core/model/bsrnn/__init__.py +23 -23
- mvsepless/models/bandit/core/model/bsrnn/bandsplit.py +119 -135
- mvsepless/models/bandit/core/model/bsrnn/core.py +619 -651
- mvsepless/models/bandit/core/model/bsrnn/maskestim.py +327 -351
- mvsepless/models/bandit/core/model/bsrnn/tfmodel.py +287 -320
- mvsepless/models/bandit/core/model/bsrnn/utils.py +518 -525
- mvsepless/models/bandit/core/model/bsrnn/wrapper.py +828 -829
- mvsepless/models/bandit/core/utils/audio.py +324 -412
- mvsepless/models/bandit/model_from_config.py +26 -26
- mvsepless/models/bandit_v2/bandit.py +360 -363
- mvsepless/models/bandit_v2/bandsplit.py +127 -130
- mvsepless/models/bandit_v2/film.py +23 -23
- mvsepless/models/bandit_v2/maskestim.py +269 -281
- mvsepless/models/bandit_v2/tfmodel.py +141 -145
- mvsepless/models/bandit_v2/utils.py +384 -523
- mvsepless/models/bs_roformer/__init__.py +6 -6
- mvsepless/models/bs_roformer/attend.py +120 -126
MVSepLess_Epsilon_Colab.ipynb
CHANGED
|
@@ -51,7 +51,7 @@
|
|
| 51 |
"prodigyopt\n",
|
| 52 |
"torch_log_wmse\n",
|
| 53 |
"rotary_embedding_torch\n",
|
| 54 |
-
"gradio\n",
|
| 55 |
"omegaconf\n",
|
| 56 |
"beartype\n",
|
| 57 |
"spafe\n",
|
|
@@ -395,6 +395,18 @@
|
|
| 395 |
"!python mvsepless/model_manager.py vbach list"
|
| 396 |
]
|
| 397 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 398 |
{
|
| 399 |
"cell_type": "markdown",
|
| 400 |
"metadata": {
|
|
@@ -461,9 +473,7 @@
|
|
| 461 |
},
|
| 462 |
{
|
| 463 |
"cell_type": "markdown",
|
| 464 |
-
"metadata": {
|
| 465 |
-
"id": "4GSGHM4rYSVp"
|
| 466 |
-
},
|
| 467 |
"source": [
|
| 468 |
"## Инференс"
|
| 469 |
]
|
|
@@ -485,7 +495,7 @@
|
|
| 485 |
"# @markdown ---\n",
|
| 486 |
"# @markdown ### Hubert\n",
|
| 487 |
"# @markdown * Стэк\n",
|
| 488 |
-
"stack = \"
|
| 489 |
"# @markdown * Имя модели для fairseq\n",
|
| 490 |
"fairseq_embedder = \"hubert_base\" # @param [\"hubert_base\",\"contentvec_base\",\"korean_hubert_base\",\"chinese_hubert_base\",\"portuguese_hubert_base\",\"japanese_hubert_base\"]\n",
|
| 491 |
"# @markdown * Имя модели для transformers\n",
|
|
@@ -497,7 +507,7 @@
|
|
| 497 |
"# @markdown * Стерео режим\n",
|
| 498 |
"stereo_mode = \"mono\" # @param [\"mono\",\"left/right\",\"sim/dif\"]\n",
|
| 499 |
"# @markdown * Метод определения тона\n",
|
| 500 |
-
"method_pitch = \"rmvpe+\" # @param [\"rmvpe+\",\"mangio-crepe\",\"fcpe\"]\n",
|
| 501 |
"# @markdown * Изменение высоты тона (полутона)\n",
|
| 502 |
"pitch = 0 # @param {\"type\":\"slider\",\"min\":-48,\"max\":48,\"step\":1}\n",
|
| 503 |
"# @markdown * Длина шага (для mangio-crepe)\n",
|
|
|
|
| 51 |
"prodigyopt\n",
|
| 52 |
"torch_log_wmse\n",
|
| 53 |
"rotary_embedding_torch\n",
|
| 54 |
+
"gradio<=6.0\n",
|
| 55 |
"omegaconf\n",
|
| 56 |
"beartype\n",
|
| 57 |
"spafe\n",
|
|
|
|
| 395 |
"!python mvsepless/model_manager.py vbach list"
|
| 396 |
]
|
| 397 |
},
|
| 398 |
+
{
|
| 399 |
+
"cell_type": "code",
|
| 400 |
+
"execution_count": null,
|
| 401 |
+
"metadata": {},
|
| 402 |
+
"outputs": [],
|
| 403 |
+
"source": [
|
| 404 |
+
"#@title Удаление голосовой модели\n",
|
| 405 |
+
"%cd $mvsepless_dir\n",
|
| 406 |
+
"voicemodel_name = \"\" # @param {\"type\":\"string\",\"placeholder\":\"Имя модели\"}\n",
|
| 407 |
+
"!python mvsepless/model_manager.py vbach remove --model_name \"$voicemodel_name\""
|
| 408 |
+
]
|
| 409 |
+
},
|
| 410 |
{
|
| 411 |
"cell_type": "markdown",
|
| 412 |
"metadata": {
|
|
|
|
| 473 |
},
|
| 474 |
{
|
| 475 |
"cell_type": "markdown",
|
| 476 |
+
"metadata": {},
|
|
|
|
|
|
|
| 477 |
"source": [
|
| 478 |
"## Инференс"
|
| 479 |
]
|
|
|
|
| 495 |
"# @markdown ---\n",
|
| 496 |
"# @markdown ### Hubert\n",
|
| 497 |
"# @markdown * Стэк\n",
|
| 498 |
+
"stack = \"fairseq\" # @param [\"fairseq\",\"transformers\"]\n",
|
| 499 |
"# @markdown * Имя модели для fairseq\n",
|
| 500 |
"fairseq_embedder = \"hubert_base\" # @param [\"hubert_base\",\"contentvec_base\",\"korean_hubert_base\",\"chinese_hubert_base\",\"portuguese_hubert_base\",\"japanese_hubert_base\"]\n",
|
| 501 |
"# @markdown * Имя модели для transformers\n",
|
|
|
|
| 507 |
"# @markdown * Стерео режим\n",
|
| 508 |
"stereo_mode = \"mono\" # @param [\"mono\",\"left/right\",\"sim/dif\"]\n",
|
| 509 |
"# @markdown * Метод определения тона\n",
|
| 510 |
+
"method_pitch = \"rmvpe+\" # @param [\"rmvpe+\",\"mangio-crepe\",\"mangio-crepe-tiny\",\"fcpe\",'harvest\",\"pm\",\"pyin\"]\n",
|
| 511 |
"# @markdown * Изменение высоты тона (полутона)\n",
|
| 512 |
"pitch = 0 # @param {\"type\":\"slider\",\"min\":-48,\"max\":48,\"step\":1}\n",
|
| 513 |
"# @markdown * Длина шага (для mangio-crepe)\n",
|
mvsepless/__init__.py
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
mvsepless/__main__.py
CHANGED
|
@@ -11,18 +11,59 @@ if __name__ == "__main__":
|
|
| 11 |
parser = argparse.ArgumentParser(description="MVSepless")
|
| 12 |
subparsers = parser.add_subparsers(dest="command")
|
| 13 |
app_parser = subparsers.add_parser("app", help="Приложение MVSepless")
|
| 14 |
-
app_parser.add_argument(
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
cli_parser = subparsers.add_parser("cli", help="CLI MVSepless Lite")
|
| 17 |
-
cli_parser.add_argument(
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
cli_parser.add_argument(
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
cli_parser.add_argument(
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
args = parser.parse_args()
|
| 27 |
|
| 28 |
if args.command == "app":
|
|
@@ -39,7 +80,13 @@ if __name__ == "__main__":
|
|
| 39 |
],
|
| 40 |
)
|
| 41 |
mvsepless_lite_app = mvsepless_app(theme=theme)
|
| 42 |
-
mvsepless_lite_app.launch(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
elif args.command == "cli":
|
| 44 |
input_data = args.input
|
| 45 |
if os.path.isdir(input_data):
|
|
@@ -62,6 +109,6 @@ if __name__ == "__main__":
|
|
| 62 |
output_format=args.output_format,
|
| 63 |
output_bitrate=args.output_bitrate,
|
| 64 |
template=args.template,
|
| 65 |
-
selected_stems=args.selected_stems
|
| 66 |
)
|
| 67 |
-
print("Разделение завершено.")
|
|
|
|
| 11 |
parser = argparse.ArgumentParser(description="MVSepless")
|
| 12 |
subparsers = parser.add_subparsers(dest="command")
|
| 13 |
app_parser = subparsers.add_parser("app", help="Приложение MVSepless")
|
| 14 |
+
app_parser.add_argument(
|
| 15 |
+
"--port", type=int, default=None, help="Порт для запуска сервера Gradio."
|
| 16 |
+
)
|
| 17 |
+
app_parser.add_argument(
|
| 18 |
+
"--share",
|
| 19 |
+
action="store_true",
|
| 20 |
+
help="Создать публичную ссылку для приложения Gradio.",
|
| 21 |
+
)
|
| 22 |
cli_parser = subparsers.add_parser("cli", help="CLI MVSepless Lite")
|
| 23 |
+
cli_parser.add_argument(
|
| 24 |
+
"--input", type=str, required=True, help="Входной аудиофайл или каталог."
|
| 25 |
+
)
|
| 26 |
+
cli_parser.add_argument(
|
| 27 |
+
"--output_dir", type=str, default=None, help="Каталог для выходных файлов."
|
| 28 |
+
)
|
| 29 |
+
cli_parser.add_argument(
|
| 30 |
+
"--model_type",
|
| 31 |
+
type=str,
|
| 32 |
+
default="mel_band_roformer",
|
| 33 |
+
help="Тип модели разделения.",
|
| 34 |
+
)
|
| 35 |
+
cli_parser.add_argument(
|
| 36 |
+
"--model_name",
|
| 37 |
+
type=str,
|
| 38 |
+
default="Mel-Band-Roformer_Vocals_kimberley_jensen",
|
| 39 |
+
help="Имя модели разделения.",
|
| 40 |
+
)
|
| 41 |
+
cli_parser.add_argument(
|
| 42 |
+
"--ext_inst", action="store_true", help="Извлечь инструментал."
|
| 43 |
+
)
|
| 44 |
+
cli_parser.add_argument(
|
| 45 |
+
"--output_format",
|
| 46 |
+
type=str,
|
| 47 |
+
default="mp3",
|
| 48 |
+
choices=Separator.audio.output_formats,
|
| 49 |
+
help="Формат выходного файла.",
|
| 50 |
+
)
|
| 51 |
+
cli_parser.add_argument(
|
| 52 |
+
"--output_bitrate", type=str, default="320k", help="Битрейт выходного файла."
|
| 53 |
+
)
|
| 54 |
+
cli_parser.add_argument(
|
| 55 |
+
"--template",
|
| 56 |
+
type=str,
|
| 57 |
+
default="NAME (STEM) MODEL",
|
| 58 |
+
help="Шаблон именования выходных файлов.",
|
| 59 |
+
)
|
| 60 |
+
cli_parser.add_argument(
|
| 61 |
+
"--selected_stems",
|
| 62 |
+
type=str,
|
| 63 |
+
nargs="*",
|
| 64 |
+
default=None,
|
| 65 |
+
help="Выбранные стемы для разделения.",
|
| 66 |
+
)
|
| 67 |
args = parser.parse_args()
|
| 68 |
|
| 69 |
if args.command == "app":
|
|
|
|
| 80 |
],
|
| 81 |
)
|
| 82 |
mvsepless_lite_app = mvsepless_app(theme=theme)
|
| 83 |
+
mvsepless_lite_app.launch(
|
| 84 |
+
server_name="0.0.0.0",
|
| 85 |
+
server_port=args.port,
|
| 86 |
+
share=args.share,
|
| 87 |
+
allowed_paths=["/"],
|
| 88 |
+
debug=True,
|
| 89 |
+
)
|
| 90 |
elif args.command == "cli":
|
| 91 |
input_data = args.input
|
| 92 |
if os.path.isdir(input_data):
|
|
|
|
| 109 |
output_format=args.output_format,
|
| 110 |
output_bitrate=args.output_bitrate,
|
| 111 |
template=args.template,
|
| 112 |
+
selected_stems=args.selected_stems,
|
| 113 |
)
|
| 114 |
+
print("Разделение завершено.")
|
mvsepless/audio.py
CHANGED
|
@@ -1,781 +1,789 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from pathlib import Path
|
| 3 |
-
import sys
|
| 4 |
-
import json
|
| 5 |
-
import subprocess
|
| 6 |
-
import numpy as np
|
| 7 |
-
from typing import Literal
|
| 8 |
-
from collections.abc import Callable
|
| 9 |
-
from pathlib import Path
|
| 10 |
-
from numpy.typing import DTypeLike
|
| 11 |
-
import tempfile
|
| 12 |
-
import librosa
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
class
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
self.
|
| 58 |
-
"mp3"
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
"
|
| 62 |
-
"
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
"
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
},
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
"
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
)
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
]
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
of
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
else:
|
| 495 |
-
raise
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
"
|
| 529 |
-
"
|
| 530 |
-
"
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
"
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
"
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
"
|
| 543 |
-
"
|
| 544 |
-
"
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
"
|
| 549 |
-
"
|
| 550 |
-
"
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
"
|
| 555 |
-
"
|
| 556 |
-
"
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
"
|
| 561 |
-
"
|
| 562 |
-
"
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
"
|
| 567 |
-
"
|
| 568 |
-
"
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
"
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
)
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
"
|
| 636 |
-
"
|
| 637 |
-
"
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
)
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
)
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import sys
|
| 4 |
+
import json
|
| 5 |
+
import subprocess
|
| 6 |
+
import numpy as np
|
| 7 |
+
from typing import Literal
|
| 8 |
+
from collections.abc import Callable
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from numpy.typing import DTypeLike
|
| 11 |
+
import tempfile
|
| 12 |
+
import librosa
|
| 13 |
+
|
| 14 |
+
if not __package__:
|
| 15 |
+
from namer import Namer
|
| 16 |
+
else:
|
| 17 |
+
from .namer import Namer
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class NotInputFileSpecified(Exception):
|
| 21 |
+
pass
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class NotOutputFileSpecified(Exception):
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class NotSupportedDataType(Exception):
|
| 29 |
+
pass
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ErrorDecode(Exception):
|
| 33 |
+
pass
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ErrorEncode(Exception):
|
| 37 |
+
pass
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class NotSupportedFormat(Exception):
|
| 41 |
+
pass
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class SampleRateError(Exception):
|
| 45 |
+
pass
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class FileIsNotAudio(Exception):
|
| 49 |
+
pass
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class Audio(Namer):
|
| 53 |
+
def __init__(self):
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.ffmpeg_path = os.environ.get("MVSEPLESS_FFMPEG", "ffmpeg")
|
| 56 |
+
self.ffprobe_path = os.environ.get("MVSEPLESS_FFPROBE", "ffprobe")
|
| 57 |
+
self.output_formats = (
|
| 58 |
+
"mp3",
|
| 59 |
+
"wav",
|
| 60 |
+
"flac",
|
| 61 |
+
"ogg",
|
| 62 |
+
"opus",
|
| 63 |
+
"m4a",
|
| 64 |
+
"aac",
|
| 65 |
+
"ac3",
|
| 66 |
+
"aiff",
|
| 67 |
+
)
|
| 68 |
+
self.input_formats = (
|
| 69 |
+
"mp3",
|
| 70 |
+
"wav",
|
| 71 |
+
"flac",
|
| 72 |
+
"ogg",
|
| 73 |
+
"opus",
|
| 74 |
+
"m4a",
|
| 75 |
+
"aac",
|
| 76 |
+
"ac3",
|
| 77 |
+
"aiff",
|
| 78 |
+
"mp4",
|
| 79 |
+
"mkv",
|
| 80 |
+
"webm",
|
| 81 |
+
"avi",
|
| 82 |
+
"mov",
|
| 83 |
+
"ts",
|
| 84 |
+
)
|
| 85 |
+
self.supported_dtypes = ("int16", "int32", "float32", "float64")
|
| 86 |
+
self.dtypes_dict = {
|
| 87 |
+
"int16": "s16le",
|
| 88 |
+
"int32": "s32le",
|
| 89 |
+
"float32": "f32le",
|
| 90 |
+
"float64": "f64le",
|
| 91 |
+
np.int16: "s16le",
|
| 92 |
+
np.int32: "s32le",
|
| 93 |
+
np.float32: "f32le",
|
| 94 |
+
np.float64: "f64le",
|
| 95 |
+
}
|
| 96 |
+
self.bitrate_limit = {
|
| 97 |
+
"mp3": {"min": 8, "max": 320},
|
| 98 |
+
"aac": {"min": 8, "max": 512},
|
| 99 |
+
"m4a": {"min": 8, "max": 512},
|
| 100 |
+
"ac3": {"min": 32, "max": 640},
|
| 101 |
+
"ogg": {"min": 64, "max": 500},
|
| 102 |
+
"opus": {"min": 6, "max": 512},
|
| 103 |
+
}
|
| 104 |
+
self.sample_rates = {
|
| 105 |
+
"mp3": {
|
| 106 |
+
"supported": (
|
| 107 |
+
44100,
|
| 108 |
+
48000,
|
| 109 |
+
32000,
|
| 110 |
+
22050,
|
| 111 |
+
24000,
|
| 112 |
+
16000,
|
| 113 |
+
11025,
|
| 114 |
+
12000,
|
| 115 |
+
8000,
|
| 116 |
+
)
|
| 117 |
+
},
|
| 118 |
+
"opus": {"supported": (48000, 24000, 16000, 12000, 8000)},
|
| 119 |
+
"m4a": {
|
| 120 |
+
"supported": (
|
| 121 |
+
96000,
|
| 122 |
+
88200,
|
| 123 |
+
64000,
|
| 124 |
+
48000,
|
| 125 |
+
44100,
|
| 126 |
+
32000,
|
| 127 |
+
24000,
|
| 128 |
+
22050,
|
| 129 |
+
16000,
|
| 130 |
+
12000,
|
| 131 |
+
11025,
|
| 132 |
+
8000,
|
| 133 |
+
7350,
|
| 134 |
+
)
|
| 135 |
+
},
|
| 136 |
+
"aac": {
|
| 137 |
+
"supported": (
|
| 138 |
+
96000,
|
| 139 |
+
88200,
|
| 140 |
+
64000,
|
| 141 |
+
48000,
|
| 142 |
+
44100,
|
| 143 |
+
32000,
|
| 144 |
+
24000,
|
| 145 |
+
22050,
|
| 146 |
+
16000,
|
| 147 |
+
12000,
|
| 148 |
+
11025,
|
| 149 |
+
8000,
|
| 150 |
+
7350,
|
| 151 |
+
)
|
| 152 |
+
},
|
| 153 |
+
"ac3": {
|
| 154 |
+
"supported": (
|
| 155 |
+
48000,
|
| 156 |
+
44100,
|
| 157 |
+
32000,
|
| 158 |
+
)
|
| 159 |
+
},
|
| 160 |
+
"ogg": {"min": 6, "max": 192000},
|
| 161 |
+
"wav": {"min": 0, "max": float("inf")},
|
| 162 |
+
"aiff": {"min": 0, "max": float("inf")},
|
| 163 |
+
"flac": {"min": 0, "max": 192000},
|
| 164 |
+
}
|
| 165 |
+
self.check_ffmpeg()
|
| 166 |
+
self.check_ffprobe()
|
| 167 |
+
|
| 168 |
+
def check_ffmpeg(self):
|
| 169 |
+
try:
|
| 170 |
+
ffmpeg_version_output = subprocess.check_output(
|
| 171 |
+
[self.ffmpeg_path, "-version"], text=True
|
| 172 |
+
)
|
| 173 |
+
except FileNotFoundError:
|
| 174 |
+
if "PYTEST_CURRENT_TEST" not in os.environ:
|
| 175 |
+
raise FileNotFoundError(
|
| 176 |
+
"FFMPEG не установлен. Укажите путь к установленному FFMPEG через переменную окружения MVSEPLESS_FFMPEG"
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
def check_ffprobe(self):
|
| 180 |
+
try:
|
| 181 |
+
ffmpeg_version_output = subprocess.check_output(
|
| 182 |
+
[self.ffprobe_path, "-version"], text=True
|
| 183 |
+
)
|
| 184 |
+
except FileNotFoundError:
|
| 185 |
+
if "PYTEST_CURRENT_TEST" not in os.environ:
|
| 186 |
+
raise FileNotFoundError(
|
| 187 |
+
"FFPROBE не установлен. Укажите путь к установленному FFPROBE через переменную окружения MVSEPLESS_FFPROBE"
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
def fit_sr(
|
| 191 |
+
self,
|
| 192 |
+
f: (
|
| 193 |
+
str
|
| 194 |
+
| Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"]
|
| 195 |
+
) = "mp3",
|
| 196 |
+
sr: int = 44100,
|
| 197 |
+
) -> int:
|
| 198 |
+
format_info = self.sample_rates.get(f.lower())
|
| 199 |
+
|
| 200 |
+
if not format_info:
|
| 201 |
+
return None
|
| 202 |
+
|
| 203 |
+
if "supported" in format_info:
|
| 204 |
+
supported_rates = format_info["supported"]
|
| 205 |
+
if sr in supported_rates:
|
| 206 |
+
return sr
|
| 207 |
+
|
| 208 |
+
return min(supported_rates, key=lambda x: abs(x - sr))
|
| 209 |
+
|
| 210 |
+
elif "min" in format_info and "max" in format_info:
|
| 211 |
+
min_rate = format_info["min"]
|
| 212 |
+
max_rate = format_info["max"]
|
| 213 |
+
|
| 214 |
+
if sr < min_rate:
|
| 215 |
+
return min_rate
|
| 216 |
+
elif sr > max_rate:
|
| 217 |
+
return max_rate
|
| 218 |
+
else:
|
| 219 |
+
return sr
|
| 220 |
+
|
| 221 |
+
return None
|
| 222 |
+
|
| 223 |
+
def fit_br(
|
| 224 |
+
self,
|
| 225 |
+
f: (
|
| 226 |
+
str
|
| 227 |
+
| Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"]
|
| 228 |
+
) = "mp3",
|
| 229 |
+
br: int = 320,
|
| 230 |
+
) -> int:
|
| 231 |
+
if f not in self.bitrate_limit:
|
| 232 |
+
raise NotSupportedFormat(f"Формат {f} не поддерживается")
|
| 233 |
+
|
| 234 |
+
limits = self.bitrate_limit[f]
|
| 235 |
+
|
| 236 |
+
if br < limits["min"]:
|
| 237 |
+
return limits["min"]
|
| 238 |
+
elif br > limits["max"]:
|
| 239 |
+
return limits["max"]
|
| 240 |
+
else:
|
| 241 |
+
return br
|
| 242 |
+
|
| 243 |
+
def get_info(
|
| 244 |
+
self,
|
| 245 |
+
i: str | os.PathLike | Callable | None = None,
|
| 246 |
+
) -> dict[int, dict[int, float]]:
|
| 247 |
+
audio_info = {}
|
| 248 |
+
if i:
|
| 249 |
+
if isinstance(i, Path):
|
| 250 |
+
i = str(i)
|
| 251 |
+
if os.path.exists(i):
|
| 252 |
+
cmd = [
|
| 253 |
+
self.ffprobe_path,
|
| 254 |
+
"-i",
|
| 255 |
+
i,
|
| 256 |
+
"-v",
|
| 257 |
+
"quiet",
|
| 258 |
+
"-hide_banner",
|
| 259 |
+
"-show_entries",
|
| 260 |
+
"stream=index,sample_rate,duration",
|
| 261 |
+
"-select_streams",
|
| 262 |
+
"a",
|
| 263 |
+
"-of",
|
| 264 |
+
"json",
|
| 265 |
+
]
|
| 266 |
+
|
| 267 |
+
process = subprocess.Popen(
|
| 268 |
+
cmd,
|
| 269 |
+
stdin=subprocess.PIPE,
|
| 270 |
+
stdout=subprocess.PIPE,
|
| 271 |
+
stderr=subprocess.PIPE,
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
stdout, stderr = process.communicate()
|
| 275 |
+
|
| 276 |
+
if process.returncode != 0:
|
| 277 |
+
print(f"STDERR: {stderr.decode('utf-8')}")
|
| 278 |
+
print(f"STDOUT: {stdout.decode('utf-8')}")
|
| 279 |
+
|
| 280 |
+
json_output = json.loads(stdout)
|
| 281 |
+
streams = json_output["streams"]
|
| 282 |
+
if not streams:
|
| 283 |
+
pass
|
| 284 |
+
|
| 285 |
+
else:
|
| 286 |
+
for a, stream in enumerate(streams):
|
| 287 |
+
audio_info[a] = {
|
| 288 |
+
"sample_rate": int(stream.get("sample_rate", 0)),
|
| 289 |
+
"duration": float(stream.get("duration", 0)),
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
return audio_info
|
| 293 |
+
|
| 294 |
+
else:
|
| 295 |
+
raise FileExistsError("Указанного файла не существует")
|
| 296 |
+
|
| 297 |
+
else:
|
| 298 |
+
raise NotInputFileSpecified("Не указан путь к файлу")
|
| 299 |
+
|
| 300 |
+
def check(self, i: str | os.PathLike | Callable | None = None) -> bool:
|
| 301 |
+
if i:
|
| 302 |
+
if isinstance(i, Path):
|
| 303 |
+
i = str(i)
|
| 304 |
+
if os.path.exists(i):
|
| 305 |
+
info = self.get_info(i=i)
|
| 306 |
+
if info:
|
| 307 |
+
list_streams = list(info.keys())
|
| 308 |
+
if len(list_streams) > 0:
|
| 309 |
+
if info[0].get("sample_rate") > 0:
|
| 310 |
+
return True
|
| 311 |
+
else:
|
| 312 |
+
return False
|
| 313 |
+
else:
|
| 314 |
+
return False
|
| 315 |
+
else:
|
| 316 |
+
return False
|
| 317 |
+
else:
|
| 318 |
+
raise FileExistsError("Указанного файла не существует")
|
| 319 |
+
else:
|
| 320 |
+
raise NotInputFileSpecified("Не указан путь к файлу")
|
| 321 |
+
|
| 322 |
+
def read(
|
| 323 |
+
self,
|
| 324 |
+
i: str | os.PathLike | Callable | None = None,
|
| 325 |
+
sr: int | None = None,
|
| 326 |
+
mono: bool = False,
|
| 327 |
+
dtype: DTypeLike = np.float32,
|
| 328 |
+
s: int = 0,
|
| 329 |
+
) -> tuple[np.ndarray, int, float]:
|
| 330 |
+
output_format = self.dtypes_dict.get(dtype, None)
|
| 331 |
+
if not output_format:
|
| 332 |
+
raise NotSupportedDataType(f"Этот тип данных не поддерживается {dtype}")
|
| 333 |
+
if i:
|
| 334 |
+
if isinstance(i, Path):
|
| 335 |
+
i = str(i)
|
| 336 |
+
if os.path.exists(i):
|
| 337 |
+
audio_info = self.get_info(i=i)
|
| 338 |
+
list_streams = list(audio_info.keys())
|
| 339 |
+
if audio_info.get(s, False):
|
| 340 |
+
stream = s
|
| 341 |
+
else:
|
| 342 |
+
if len(list_streams) > 0:
|
| 343 |
+
stream = 0
|
| 344 |
+
else:
|
| 345 |
+
raise FileIsNotAudio("В входном файле нет аудио потоков")
|
| 346 |
+
|
| 347 |
+
sample_rate_input = audio_info[stream]["sample_rate"]
|
| 348 |
+
if sample_rate_input == 0:
|
| 349 |
+
raise FileIsNotAudio("В входном файле нет аудио потоков")
|
| 350 |
+
|
| 351 |
+
cmd = [
|
| 352 |
+
self.ffmpeg_path,
|
| 353 |
+
"-i",
|
| 354 |
+
i,
|
| 355 |
+
"-map",
|
| 356 |
+
f"0:a:{stream}",
|
| 357 |
+
"-vn",
|
| 358 |
+
"-f",
|
| 359 |
+
output_format,
|
| 360 |
+
"-ac",
|
| 361 |
+
"1" if mono else "2",
|
| 362 |
+
]
|
| 363 |
+
|
| 364 |
+
if sr:
|
| 365 |
+
cmd.extend(["-ar", str(sr)])
|
| 366 |
+
else:
|
| 367 |
+
sr = sample_rate_input
|
| 368 |
+
|
| 369 |
+
cmd.append("pipe:1")
|
| 370 |
+
|
| 371 |
+
process = subprocess.Popen(
|
| 372 |
+
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=10**8
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
try:
|
| 376 |
+
|
| 377 |
+
raw_audio, stderr = process.communicate(timeout=300)
|
| 378 |
+
|
| 379 |
+
if process.returncode != 0:
|
| 380 |
+
raise ErrorDecode(f"FFmpeg error: {stderr.decode()}")
|
| 381 |
+
|
| 382 |
+
except subprocess.TimeoutExpired:
|
| 383 |
+
process.kill()
|
| 384 |
+
raise ErrorDecode("FFmpeg timeout при чтении файла")
|
| 385 |
+
|
| 386 |
+
audio_array = np.frombuffer(raw_audio, dtype=dtype)
|
| 387 |
+
|
| 388 |
+
channels = 1 if mono else 2
|
| 389 |
+
audio_array = audio_array.reshape((-1, channels)).T
|
| 390 |
+
if audio_array.ndim > 1 and channels == 1:
|
| 391 |
+
audio_array = np.mean(
|
| 392 |
+
audio_array, axis=tuple(range(audio_array.ndim - 1))
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
len_samples = float(audio_array.shape[-1])
|
| 396 |
+
|
| 397 |
+
duration = len_samples / sr
|
| 398 |
+
|
| 399 |
+
print(f"Частота дискретизации: {sr}")
|
| 400 |
+
|
| 401 |
+
return audio_array.copy(), sr, duration
|
| 402 |
+
else:
|
| 403 |
+
raise FileExistsError("Указанного файла не существует")
|
| 404 |
+
|
| 405 |
+
else:
|
| 406 |
+
raise NotInputFileSpecified("Не указан путь к файлу")
|
| 407 |
+
|
| 408 |
+
def write(
|
| 409 |
+
self,
|
| 410 |
+
o: str | os.PathLike | Callable | None = None,
|
| 411 |
+
array: np.ndarray = np.array([], dtype=np.float32),
|
| 412 |
+
sr: int = 44100,
|
| 413 |
+
of: (
|
| 414 |
+
str
|
| 415 |
+
| Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"]
|
| 416 |
+
| None
|
| 417 |
+
) = None,
|
| 418 |
+
br: str | int | None = None,
|
| 419 |
+
) -> str:
|
| 420 |
+
if isinstance(array, np.ndarray):
|
| 421 |
+
|
| 422 |
+
if len(array.shape) == 1:
|
| 423 |
+
array = array.reshape(-1, 1)
|
| 424 |
+
elif len(array.shape) == 2:
|
| 425 |
+
if array.shape[0] == 2:
|
| 426 |
+
array = array.T
|
| 427 |
+
else:
|
| 428 |
+
raise ValueError(
|
| 429 |
+
"numpy-массив должен быть либо одномерным, либо двухмерным"
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
if array.dtype == np.int16:
|
| 433 |
+
input_format = "s16le"
|
| 434 |
+
elif array.dtype == np.int32:
|
| 435 |
+
input_format = "s32le"
|
| 436 |
+
elif array.dtype == np.float32:
|
| 437 |
+
input_format = "f32le"
|
| 438 |
+
elif array.dtype == np.float64:
|
| 439 |
+
input_format = "f64le"
|
| 440 |
+
else:
|
| 441 |
+
raise NotSupportedDataType(
|
| 442 |
+
f"Этот тип данных не поддерживается {array.dtype}"
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
if array.shape[1] == 1:
|
| 446 |
+
audio_bytes = array.tobytes()
|
| 447 |
+
|
| 448 |
+
channels = 1
|
| 449 |
+
|
| 450 |
+
elif array.shape[1] == 2:
|
| 451 |
+
audio_bytes = array.tobytes()
|
| 452 |
+
|
| 453 |
+
channels = 2
|
| 454 |
+
else:
|
| 455 |
+
raise ValueError("numpy-массив должен содержать 1 или 2 канала")
|
| 456 |
+
|
| 457 |
+
else:
|
| 458 |
+
raise ValueError("Вход должен быть numpy-массивом")
|
| 459 |
+
|
| 460 |
+
if o:
|
| 461 |
+
if isinstance(o, Path):
|
| 462 |
+
o = str(o)
|
| 463 |
+
output_dir = os.path.dirname(o)
|
| 464 |
+
output_base = os.path.basename(o)
|
| 465 |
+
output_name, output_ext = os.path.splitext(output_base)
|
| 466 |
+
if output_dir != "":
|
| 467 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 468 |
+
if output_ext == "":
|
| 469 |
+
if of:
|
| 470 |
+
o += f".{of}"
|
| 471 |
+
else:
|
| 472 |
+
o += f".mp3"
|
| 473 |
+
elif output_ext == ".":
|
| 474 |
+
if of:
|
| 475 |
+
o += f"{of}"
|
| 476 |
+
else:
|
| 477 |
+
o += f"mp3"
|
| 478 |
+
else:
|
| 479 |
+
raise NotOutputFileSpecified("Не указан путь к выходному файлу")
|
| 480 |
+
|
| 481 |
+
if of:
|
| 482 |
+
if of in self.output_formats:
|
| 483 |
+
output_name, output_ext = os.path.splitext(o)
|
| 484 |
+
if output_ext == f".{of}":
|
| 485 |
+
pass
|
| 486 |
+
else:
|
| 487 |
+
o = f"{os.path.join(output_dir, output_name)}.{of}"
|
| 488 |
+
else:
|
| 489 |
+
raise NotSupportedFormat(f"Неподдерживаемый формат: {of}")
|
| 490 |
+
else:
|
| 491 |
+
of = os.path.splitext(o)[1].strip(".")
|
| 492 |
+
if of in self.output_formats:
|
| 493 |
+
pass
|
| 494 |
+
else:
|
| 495 |
+
raise NotSupportedFormat(f"Неподдерживаемый формат: {of}")
|
| 496 |
+
|
| 497 |
+
if sr:
|
| 498 |
+
if isinstance(sr, int):
|
| 499 |
+
sample_rate_fixed = self.fit_sr(f=of, sr=sr)
|
| 500 |
+
elif isinstance(sr, float):
|
| 501 |
+
sr = int(sr)
|
| 502 |
+
sample_rate_fixed = self.fit_sr(f=of, sr=sr)
|
| 503 |
+
else:
|
| 504 |
+
raise SampleRateError(
|
| 505 |
+
f"Частота дискретизации должна быть числом\n\nЗначение: {sr}\nТип: {type(sr)}"
|
| 506 |
+
)
|
| 507 |
+
else:
|
| 508 |
+
raise SampleRateError("Не указана частота дискретизации")
|
| 509 |
+
|
| 510 |
+
bitrate_fixed = "320k"
|
| 511 |
+
|
| 512 |
+
if of not in ["wav", "flac", "aiff"]:
|
| 513 |
+
if br:
|
| 514 |
+
if isinstance(br, int):
|
| 515 |
+
bitrate_fixed = self.fit_br(f=of, br=br)
|
| 516 |
+
elif isinstance(br, float):
|
| 517 |
+
bitrate_fixed = self.fit_br(f=of, br=int(br))
|
| 518 |
+
elif isinstance(br, str):
|
| 519 |
+
bitrate_fixed = self.fit_br(f=of, br=int(br.strip("k").strip("K")))
|
| 520 |
+
else:
|
| 521 |
+
bitrate_fixed = self.fit_br(f=of, br=320)
|
| 522 |
+
else:
|
| 523 |
+
bitrate_fixed = self.fit_br(of, 320)
|
| 524 |
+
|
| 525 |
+
format_settings = {
|
| 526 |
+
"wav": [
|
| 527 |
+
"-c:a",
|
| 528 |
+
"pcm_f32le",
|
| 529 |
+
"-sample_fmt",
|
| 530 |
+
"flt",
|
| 531 |
+
],
|
| 532 |
+
"aiff": [
|
| 533 |
+
"-c:a",
|
| 534 |
+
"pcm_f32be",
|
| 535 |
+
"-sample_fmt",
|
| 536 |
+
"flt",
|
| 537 |
+
],
|
| 538 |
+
"flac": [
|
| 539 |
+
"-c:a",
|
| 540 |
+
"flac",
|
| 541 |
+
"-compression_level",
|
| 542 |
+
"12",
|
| 543 |
+
"-sample_fmt",
|
| 544 |
+
"s32",
|
| 545 |
+
],
|
| 546 |
+
"mp3": [
|
| 547 |
+
"-c:a",
|
| 548 |
+
"libmp3lame",
|
| 549 |
+
"-b:a",
|
| 550 |
+
f"{bitrate_fixed}k",
|
| 551 |
+
],
|
| 552 |
+
"ogg": [
|
| 553 |
+
"-c:a",
|
| 554 |
+
"libvorbis",
|
| 555 |
+
"-b:a",
|
| 556 |
+
f"{bitrate_fixed}k",
|
| 557 |
+
],
|
| 558 |
+
"opus": [
|
| 559 |
+
"-c:a",
|
| 560 |
+
"libopus",
|
| 561 |
+
"-b:a",
|
| 562 |
+
f"{bitrate_fixed}k",
|
| 563 |
+
],
|
| 564 |
+
"m4a": [
|
| 565 |
+
"-c:a",
|
| 566 |
+
"aac",
|
| 567 |
+
"-b:a",
|
| 568 |
+
f"{bitrate_fixed}k",
|
| 569 |
+
],
|
| 570 |
+
"aac": [
|
| 571 |
+
"-c:a",
|
| 572 |
+
"aac",
|
| 573 |
+
"-b:a",
|
| 574 |
+
f"{bitrate_fixed}k",
|
| 575 |
+
],
|
| 576 |
+
"ac3": [
|
| 577 |
+
"-c:a",
|
| 578 |
+
"ac3",
|
| 579 |
+
"-b:a",
|
| 580 |
+
f"{bitrate_fixed}k",
|
| 581 |
+
],
|
| 582 |
+
}
|
| 583 |
+
|
| 584 |
+
cmd = [
|
| 585 |
+
self.ffmpeg_path,
|
| 586 |
+
"-y",
|
| 587 |
+
"-f",
|
| 588 |
+
input_format,
|
| 589 |
+
"-ar",
|
| 590 |
+
str(sr),
|
| 591 |
+
"-ac",
|
| 592 |
+
str(channels),
|
| 593 |
+
"-i",
|
| 594 |
+
"pipe:0",
|
| 595 |
+
"-ac",
|
| 596 |
+
str(channels),
|
| 597 |
+
]
|
| 598 |
+
|
| 599 |
+
cmd.extend(["-ar", str(sample_rate_fixed)])
|
| 600 |
+
cmd.extend(format_settings[of])
|
| 601 |
+
o_dir, o_base = os.path.split(o)
|
| 602 |
+
o_base_n, o_base_ext = os.path.splitext(o_base)
|
| 603 |
+
o_base_n = self.sanitize(o_base_n)
|
| 604 |
+
o_base_n = self.short(o_base_n)
|
| 605 |
+
o = os.path.join(o_dir, f"{o_base_n}{o_base_ext}")
|
| 606 |
+
o = self.iter(o)
|
| 607 |
+
cmd.append(o)
|
| 608 |
+
|
| 609 |
+
process = subprocess.Popen(
|
| 610 |
+
cmd,
|
| 611 |
+
stdin=subprocess.PIPE,
|
| 612 |
+
stdout=subprocess.PIPE,
|
| 613 |
+
stderr=subprocess.PIPE,
|
| 614 |
+
)
|
| 615 |
+
|
| 616 |
+
try:
|
| 617 |
+
stdout, stderr = process.communicate(input=audio_bytes, timeout=300)
|
| 618 |
+
except subprocess.TimeoutExpired:
|
| 619 |
+
process.kill()
|
| 620 |
+
raise ErrorEncode("FFmpeg timeout: операция заняла слишком много времени")
|
| 621 |
+
|
| 622 |
+
if process.returncode != 0:
|
| 623 |
+
raise ErrorEncode(
|
| 624 |
+
f"FFmpeg завершился с ошибкой (код: {process.returncode})"
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
return os.path.abspath(o)
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
class Inverter(Audio):
|
| 631 |
+
def __init__(self):
|
| 632 |
+
super().__init__()
|
| 633 |
+
self.test = "test"
|
| 634 |
+
self.w_types = [
|
| 635 |
+
"boxcar",
|
| 636 |
+
"triang",
|
| 637 |
+
"blackman",
|
| 638 |
+
"hamming",
|
| 639 |
+
"hann",
|
| 640 |
+
"bartlett",
|
| 641 |
+
"flattop",
|
| 642 |
+
"parzen",
|
| 643 |
+
"bohman",
|
| 644 |
+
"blackmanharris",
|
| 645 |
+
"nuttall",
|
| 646 |
+
"barthann",
|
| 647 |
+
"cosine",
|
| 648 |
+
"exponential",
|
| 649 |
+
"tukey",
|
| 650 |
+
"taylor",
|
| 651 |
+
"lanczos",
|
| 652 |
+
]
|
| 653 |
+
|
| 654 |
+
def load_audio(self, filepath):
|
| 655 |
+
try:
|
| 656 |
+
y, sr, _ = self.read(i=filepath, sr=None, mono=False)
|
| 657 |
+
return y, sr
|
| 658 |
+
except Exception as e:
|
| 659 |
+
print(f"Ошибка загрузки аудио: {e}")
|
| 660 |
+
return None, None
|
| 661 |
+
|
| 662 |
+
def process_channel(
|
| 663 |
+
self, y1_ch, y2_ch, sr, method, w_size=2048, overlap=2, w_type="hann"
|
| 664 |
+
):
|
| 665 |
+
HOP_LENGTH = w_size // overlap
|
| 666 |
+
if method == "waveform":
|
| 667 |
+
return y1_ch - y2_ch
|
| 668 |
+
|
| 669 |
+
elif method == "spectrogram":
|
| 670 |
+
S1 = librosa.stft(
|
| 671 |
+
y1_ch, n_fft=w_size, hop_length=HOP_LENGTH, win_length=w_size
|
| 672 |
+
)
|
| 673 |
+
S2 = librosa.stft(
|
| 674 |
+
y2_ch, n_fft=w_size, hop_length=HOP_LENGTH, win_length=w_size
|
| 675 |
+
)
|
| 676 |
+
|
| 677 |
+
mag1 = np.abs(S1)
|
| 678 |
+
mag2 = np.abs(S2)
|
| 679 |
+
|
| 680 |
+
mag_result = np.maximum(mag1 - mag2, 0)
|
| 681 |
+
|
| 682 |
+
phase = np.angle(S1)
|
| 683 |
+
|
| 684 |
+
S_result = mag_result * np.exp(1j * phase)
|
| 685 |
+
|
| 686 |
+
return librosa.istft(
|
| 687 |
+
S_result,
|
| 688 |
+
n_fft=w_size,
|
| 689 |
+
hop_length=HOP_LENGTH,
|
| 690 |
+
win_length=w_size,
|
| 691 |
+
length=len(y1_ch),
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
def process_audio(
|
| 695 |
+
self,
|
| 696 |
+
audio1_path,
|
| 697 |
+
audio2_path,
|
| 698 |
+
out_format,
|
| 699 |
+
method,
|
| 700 |
+
output_path="./inverted.mp3",
|
| 701 |
+
w_size=2048,
|
| 702 |
+
overlap=2,
|
| 703 |
+
w_type="hann",
|
| 704 |
+
):
|
| 705 |
+
y1, sr1 = self.load_audio(audio1_path)
|
| 706 |
+
y2, sr2 = self.load_audio(audio2_path)
|
| 707 |
+
|
| 708 |
+
if sr1 is None or sr2 is None:
|
| 709 |
+
raise Exception("Произошла ошибка при чтении файлов")
|
| 710 |
+
|
| 711 |
+
channels1 = 1 if y1.ndim == 1 else y1.shape[0]
|
| 712 |
+
channels2 = 1 if y2.ndim == 1 else y2.shape[0]
|
| 713 |
+
|
| 714 |
+
if channels1 > 1:
|
| 715 |
+
y1 = y1.T
|
| 716 |
+
else:
|
| 717 |
+
y1 = y1.reshape(-1, 1)
|
| 718 |
+
|
| 719 |
+
if channels2 > 1:
|
| 720 |
+
y2 = y2.T
|
| 721 |
+
else:
|
| 722 |
+
y2 = y2.reshape(-1, 1)
|
| 723 |
+
|
| 724 |
+
if sr1 != sr2:
|
| 725 |
+
if channels2 > 1:
|
| 726 |
+
y2_resampled_list = []
|
| 727 |
+
for c in range(channels2):
|
| 728 |
+
channel_resampled = librosa.resample(
|
| 729 |
+
y2[:, c], orig_sr=sr2, target_sr=sr1
|
| 730 |
+
)
|
| 731 |
+
y2_resampled_list.append(channel_resampled)
|
| 732 |
+
|
| 733 |
+
min_channel_length = min(len(ch) for ch in y2_resampled_list)
|
| 734 |
+
|
| 735 |
+
y2_resampled = np.zeros(
|
| 736 |
+
(min_channel_length, channels2), dtype=np.float32
|
| 737 |
+
)
|
| 738 |
+
for c, channel in enumerate(y2_resampled_list):
|
| 739 |
+
y2_resampled[:, c] = channel[:min_channel_length]
|
| 740 |
+
|
| 741 |
+
y2 = y2_resampled
|
| 742 |
+
else:
|
| 743 |
+
y2 = librosa.resample(y2[:, 0], orig_sr=sr2, target_sr=sr1)
|
| 744 |
+
y2 = y2.reshape(-1, 1)
|
| 745 |
+
sr2 = sr1
|
| 746 |
+
|
| 747 |
+
min_len = min(len(y1), len(y2))
|
| 748 |
+
y1 = y1[:min_len]
|
| 749 |
+
y2 = y2[:min_len]
|
| 750 |
+
|
| 751 |
+
result_channels = []
|
| 752 |
+
|
| 753 |
+
if channels1 == 1 and channels2 > 1:
|
| 754 |
+
y2 = y2.mean(axis=1, keepdims=True)
|
| 755 |
+
channels2 = 1
|
| 756 |
+
|
| 757 |
+
for c in range(channels1):
|
| 758 |
+
y1_ch = y1[:, c]
|
| 759 |
+
|
| 760 |
+
if channels2 == 1:
|
| 761 |
+
y2_ch = y2[:, 0]
|
| 762 |
+
else:
|
| 763 |
+
y2_ch = y2[:, min(c, channels2 - 1)]
|
| 764 |
+
|
| 765 |
+
result_ch = self.process_channel(
|
| 766 |
+
y1_ch, y2_ch, sr1, method, w_size=w_size, overlap=overlap, w_type=w_type
|
| 767 |
+
)
|
| 768 |
+
result_channels.append(result_ch)
|
| 769 |
+
|
| 770 |
+
if len(result_channels) > 1:
|
| 771 |
+
result = np.column_stack(result_channels)
|
| 772 |
+
else:
|
| 773 |
+
result = np.array(result_channels[0])
|
| 774 |
+
|
| 775 |
+
if result.ndim > 1:
|
| 776 |
+
for c in range(result.shape[1]):
|
| 777 |
+
channel = result[:, c]
|
| 778 |
+
max_val = np.max(np.abs(channel))
|
| 779 |
+
if max_val > 0:
|
| 780 |
+
result[:, c] = channel * 0.9 / max_val
|
| 781 |
+
else:
|
| 782 |
+
max_val = np.max(np.abs(result))
|
| 783 |
+
if max_val > 0:
|
| 784 |
+
result = result * 0.9 / max_val
|
| 785 |
+
|
| 786 |
+
inverted = self.write(
|
| 787 |
+
o=output_path, array=result.T, sr=sr1, of=out_format, br="320k"
|
| 788 |
+
)
|
| 789 |
+
return inverted
|
mvsepless/downloader.py
CHANGED
|
@@ -1,92 +1,90 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import yt_dlp
|
| 3 |
-
import time
|
| 4 |
-
from tqdm import tqdm
|
| 5 |
-
import urllib.request
|
| 6 |
-
|
| 7 |
-
DOWNLOAD_DIR = os.environ.get(
|
| 8 |
-
"MVSEPLESS_DOWNLOAD_DIR", os.path.join(os.getcwd(), "downloaded")
|
| 9 |
-
)
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
break
|
| 36 |
-
except Exception as e:
|
| 37 |
-
print(f"Попытка {attempt + 1}/{retries} не удалась. Ошибка: {e}")
|
| 38 |
-
if attempt < retries - 1:
|
| 39 |
-
print("Повторная попытка...")
|
| 40 |
-
time.sleep(2)
|
| 41 |
-
else:
|
| 42 |
-
print("Все попытки загрузки завершились неудачно")
|
| 43 |
-
raise
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
outtmpl = "%(title)s.%(ext)s" if title is None else f"{title}.%(ext)s"
|
| 55 |
-
|
| 56 |
-
ydl_opts = {
|
| 57 |
-
"format": "bestaudio/best",
|
| 58 |
-
"outtmpl": os.path.join(
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
"
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
"
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
entry = info["entries"][0]
|
| 81 |
-
filename = ydl.prepare_filename(entry)
|
| 82 |
-
else:
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
return
|
| 91 |
-
except Exception as e:
|
| 92 |
-
return None
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import yt_dlp
|
| 3 |
+
import time
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import urllib.request
|
| 6 |
+
|
| 7 |
+
DOWNLOAD_DIR = os.environ.get(
|
| 8 |
+
"MVSEPLESS_DOWNLOAD_DIR", os.path.join(os.getcwd(), "downloaded")
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def dw_file(url_model: str, local_path: str, retries: int = 180):
|
| 13 |
+
dir_name = os.path.dirname(local_path)
|
| 14 |
+
if dir_name != "":
|
| 15 |
+
os.makedirs(dir_name, exist_ok=True)
|
| 16 |
+
|
| 17 |
+
class TqdmUpTo(tqdm):
|
| 18 |
+
def update_to(self, b=1, bsize=1, tsize=None):
|
| 19 |
+
if tsize is not None:
|
| 20 |
+
self.total = tsize
|
| 21 |
+
self.update(b * bsize - self.n)
|
| 22 |
+
|
| 23 |
+
for attempt in range(retries):
|
| 24 |
+
try:
|
| 25 |
+
with TqdmUpTo(
|
| 26 |
+
unit="B",
|
| 27 |
+
unit_scale=True,
|
| 28 |
+
unit_divisor=1024,
|
| 29 |
+
miniters=1,
|
| 30 |
+
desc=os.path.basename(local_path),
|
| 31 |
+
) as t:
|
| 32 |
+
urllib.request.urlretrieve(
|
| 33 |
+
url_model, local_path, reporthook=t.update_to
|
| 34 |
+
)
|
| 35 |
+
break
|
| 36 |
+
except Exception as e:
|
| 37 |
+
print(f"Попытка {attempt + 1}/{retries} не удалась. Ошибка: {e}")
|
| 38 |
+
if attempt < retries - 1:
|
| 39 |
+
print("Повторная попытка...")
|
| 40 |
+
time.sleep(2)
|
| 41 |
+
else:
|
| 42 |
+
print("Все попытки загрузки завершились неудачно")
|
| 43 |
+
raise
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def dw_yt_dlp(
|
| 47 |
+
url,
|
| 48 |
+
output_dir=None,
|
| 49 |
+
cookie=None,
|
| 50 |
+
output_format="mp3",
|
| 51 |
+
output_bitrate="320",
|
| 52 |
+
title=None,
|
| 53 |
+
):
|
| 54 |
+
outtmpl = "%(title)s.%(ext)s" if title is None else f"{title}.%(ext)s"
|
| 55 |
+
|
| 56 |
+
ydl_opts = {
|
| 57 |
+
"format": "bestaudio/best",
|
| 58 |
+
"outtmpl": os.path.join(
|
| 59 |
+
DOWNLOAD_DIR if not output_dir else output_dir, outtmpl
|
| 60 |
+
),
|
| 61 |
+
"postprocessors": [
|
| 62 |
+
{
|
| 63 |
+
"key": "FFmpegExtractAudio",
|
| 64 |
+
"preferredcodec": output_format,
|
| 65 |
+
"preferredquality": output_bitrate,
|
| 66 |
+
}
|
| 67 |
+
],
|
| 68 |
+
"noplaylist": True,
|
| 69 |
+
"quiet": True,
|
| 70 |
+
"no_warnings": True,
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
if cookie and os.path.exists(cookie):
|
| 74 |
+
ydl_opts["cookiefile"] = cookie
|
| 75 |
+
|
| 76 |
+
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
| 77 |
+
try:
|
| 78 |
+
info = ydl.extract_info(url, download=True)
|
| 79 |
+
if "_type" in info and info["_type"] == "playlist":
|
| 80 |
+
entry = info["entries"][0]
|
| 81 |
+
filename = ydl.prepare_filename(entry)
|
| 82 |
+
else:
|
| 83 |
+
filename = ydl.prepare_filename(info)
|
| 84 |
+
|
| 85 |
+
base, _ = os.path.splitext(filename)
|
| 86 |
+
audio_file = base + f".{output_format}"
|
| 87 |
+
|
| 88 |
+
return os.path.join(DOWNLOAD_DIR, audio_file)
|
| 89 |
+
except Exception as e:
|
| 90 |
+
return None
|
|
|
|
|
|
mvsepless/ensemble.py
CHANGED
|
@@ -1,224 +1,206 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
import
|
| 5 |
-
import
|
| 6 |
-
import
|
| 7 |
-
import
|
| 8 |
-
import
|
| 9 |
-
|
| 10 |
-
if not __package__:
|
| 11 |
-
from audio import Audio
|
| 12 |
-
from namer import Namer
|
| 13 |
-
else:
|
| 14 |
-
from .audio import Audio
|
| 15 |
-
from .namer import Namer
|
| 16 |
-
|
| 17 |
-
audio = Audio()
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
dims
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
indices = list(indices)
|
| 44 |
-
indices.insert(axis % len(a.shape), argmax)
|
| 45 |
-
return a[tuple(indices)]
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
def absmin(a, *, axis):
|
| 49 |
-
dims = list(a.shape)
|
| 50 |
-
dims.pop(axis)
|
| 51 |
-
indices = np.ogrid[tuple(slice(0, d) for d in dims)]
|
| 52 |
-
argmax = np.abs(a).argmin(axis=axis)
|
| 53 |
-
indices.insert((len(a.shape) + axis) % len(a.shape), argmax)
|
| 54 |
-
return a[tuple(indices)]
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def lambda_max(arr, axis=None, key=None, keepdims=False):
|
| 58 |
-
idxs = np.argmax(key(arr), axis)
|
| 59 |
-
if axis is not None:
|
| 60 |
-
idxs = np.expand_dims(idxs, axis)
|
| 61 |
-
result = np.take_along_axis(arr, idxs, axis)
|
| 62 |
-
if not keepdims:
|
| 63 |
-
result = np.squeeze(result, axis=axis)
|
| 64 |
-
return result
|
| 65 |
-
else:
|
| 66 |
-
return arr.flatten()[idxs]
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def lambda_min(arr, axis=None, key=None, keepdims=False):
|
| 70 |
-
idxs = np.argmin(key(arr), axis)
|
| 71 |
-
if axis is not None:
|
| 72 |
-
idxs = np.expand_dims(idxs, axis)
|
| 73 |
-
result = np.take_along_axis(arr, idxs, axis)
|
| 74 |
-
if not keepdims:
|
| 75 |
-
result = np.squeeze(result, axis=axis)
|
| 76 |
-
return result
|
| 77 |
-
else:
|
| 78 |
-
return arr.flatten()[idxs]
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
def average_waveforms(pred_track, weights, algorithm):
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
pred_track =
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
pred_track =
|
| 111 |
-
elif algorithm in ["
|
| 112 |
-
pred_track =
|
| 113 |
-
pred_track =
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
pred_track =
|
| 117 |
-
|
| 118 |
-
pred_track =
|
| 119 |
-
|
| 120 |
-
pred_track =
|
| 121 |
-
|
| 122 |
-
pred_track =
|
| 123 |
-
|
| 124 |
-
pred_track =
|
| 125 |
-
|
| 126 |
-
pred_track =
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
""
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
:
|
| 143 |
-
|
| 144 |
-
:
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
if
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
wav = wav[:, :max_length]
|
| 208 |
-
|
| 209 |
-
data.append(wav)
|
| 210 |
-
|
| 211 |
-
data = np.array(data)
|
| 212 |
-
res = average_waveforms(data, weights, ensemble_type)
|
| 213 |
-
print("Форма результата: {}".format(res.shape))
|
| 214 |
-
|
| 215 |
-
output_wav = f"{output}_orig.wav"
|
| 216 |
-
output = f"{output}.{out_format}"
|
| 217 |
-
|
| 218 |
-
output = audio.write(o=output, array=res.T, sr=sr, of=out_format, br="320k")
|
| 219 |
-
if add_wav:
|
| 220 |
-
output_wav = audio.write(o=output_wav, array=res.T, sr=sr, of="wav")
|
| 221 |
-
return output, output_wav
|
| 222 |
-
|
| 223 |
-
else:
|
| 224 |
-
return output
|
|
|
|
| 1 |
+
__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import librosa
|
| 6 |
+
import tempfile
|
| 7 |
+
import numpy as np
|
| 8 |
+
import argparse
|
| 9 |
+
|
| 10 |
+
if not __package__:
|
| 11 |
+
from audio import Audio
|
| 12 |
+
from namer import Namer
|
| 13 |
+
else:
|
| 14 |
+
from .audio import Audio
|
| 15 |
+
from .namer import Namer
|
| 16 |
+
|
| 17 |
+
audio = Audio()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def stft(wave, nfft, hl):
|
| 21 |
+
wave_left = np.asfortranarray(wave[0])
|
| 22 |
+
wave_right = np.asfortranarray(wave[1])
|
| 23 |
+
spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
|
| 24 |
+
spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
|
| 25 |
+
spec = np.asfortranarray([spec_left, spec_right])
|
| 26 |
+
return spec
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def istft(spec, hl, length):
|
| 30 |
+
spec_left = np.asfortranarray(spec[0])
|
| 31 |
+
spec_right = np.asfortranarray(spec[1])
|
| 32 |
+
wave_left = librosa.istft(spec_left, hop_length=hl, length=length)
|
| 33 |
+
wave_right = librosa.istft(spec_right, hop_length=hl, length=length)
|
| 34 |
+
wave = np.asfortranarray([wave_left, wave_right])
|
| 35 |
+
return wave
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def absmax(a, *, axis):
|
| 39 |
+
dims = list(a.shape)
|
| 40 |
+
dims.pop(axis)
|
| 41 |
+
indices = np.ogrid[tuple(slice(0, d) for d in dims)]
|
| 42 |
+
argmax = np.abs(a).argmax(axis=axis)
|
| 43 |
+
indices = list(indices)
|
| 44 |
+
indices.insert(axis % len(a.shape), argmax)
|
| 45 |
+
return a[tuple(indices)]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def absmin(a, *, axis):
|
| 49 |
+
dims = list(a.shape)
|
| 50 |
+
dims.pop(axis)
|
| 51 |
+
indices = np.ogrid[tuple(slice(0, d) for d in dims)]
|
| 52 |
+
argmax = np.abs(a).argmin(axis=axis)
|
| 53 |
+
indices.insert((len(a.shape) + axis) % len(a.shape), argmax)
|
| 54 |
+
return a[tuple(indices)]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def lambda_max(arr, axis=None, key=None, keepdims=False):
|
| 58 |
+
idxs = np.argmax(key(arr), axis)
|
| 59 |
+
if axis is not None:
|
| 60 |
+
idxs = np.expand_dims(idxs, axis)
|
| 61 |
+
result = np.take_along_axis(arr, idxs, axis)
|
| 62 |
+
if not keepdims:
|
| 63 |
+
result = np.squeeze(result, axis=axis)
|
| 64 |
+
return result
|
| 65 |
+
else:
|
| 66 |
+
return arr.flatten()[idxs]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def lambda_min(arr, axis=None, key=None, keepdims=False):
|
| 70 |
+
idxs = np.argmin(key(arr), axis)
|
| 71 |
+
if axis is not None:
|
| 72 |
+
idxs = np.expand_dims(idxs, axis)
|
| 73 |
+
result = np.take_along_axis(arr, idxs, axis)
|
| 74 |
+
if not keepdims:
|
| 75 |
+
result = np.squeeze(result, axis=axis)
|
| 76 |
+
return result
|
| 77 |
+
else:
|
| 78 |
+
return arr.flatten()[idxs]
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def average_waveforms(pred_track, weights, algorithm):
|
| 82 |
+
|
| 83 |
+
pred_track = np.array(pred_track)
|
| 84 |
+
final_length = pred_track.shape[-1]
|
| 85 |
+
|
| 86 |
+
mod_track = []
|
| 87 |
+
for i in range(pred_track.shape[0]):
|
| 88 |
+
if algorithm == "avg_wave":
|
| 89 |
+
mod_track.append(pred_track[i] * weights[i])
|
| 90 |
+
elif algorithm in ["median_wave", "min_wave", "max_wave"]:
|
| 91 |
+
mod_track.append(pred_track[i])
|
| 92 |
+
elif algorithm in ["avg_fft", "min_fft", "max_fft", "median_fft"]:
|
| 93 |
+
spec = stft(pred_track[i], nfft=2048, hl=1024)
|
| 94 |
+
if algorithm in ["avg_fft"]:
|
| 95 |
+
mod_track.append(spec * weights[i])
|
| 96 |
+
else:
|
| 97 |
+
mod_track.append(spec)
|
| 98 |
+
pred_track = np.array(mod_track)
|
| 99 |
+
|
| 100 |
+
if algorithm in ["avg_wave"]:
|
| 101 |
+
pred_track = pred_track.sum(axis=0)
|
| 102 |
+
pred_track /= np.array(weights).sum().T
|
| 103 |
+
elif algorithm in ["median_wave"]:
|
| 104 |
+
pred_track = np.median(pred_track, axis=0)
|
| 105 |
+
elif algorithm in ["min_wave"]:
|
| 106 |
+
pred_track = np.array(pred_track)
|
| 107 |
+
pred_track = lambda_min(pred_track, axis=0, key=np.abs)
|
| 108 |
+
elif algorithm in ["max_wave"]:
|
| 109 |
+
pred_track = np.array(pred_track)
|
| 110 |
+
pred_track = lambda_max(pred_track, axis=0, key=np.abs)
|
| 111 |
+
elif algorithm in ["avg_fft"]:
|
| 112 |
+
pred_track = pred_track.sum(axis=0)
|
| 113 |
+
pred_track /= np.array(weights).sum()
|
| 114 |
+
pred_track = istft(pred_track, 1024, final_length)
|
| 115 |
+
elif algorithm in ["min_fft"]:
|
| 116 |
+
pred_track = np.array(pred_track)
|
| 117 |
+
pred_track = lambda_min(pred_track, axis=0, key=np.abs)
|
| 118 |
+
pred_track = istft(pred_track, 1024, final_length)
|
| 119 |
+
elif algorithm in ["max_fft"]:
|
| 120 |
+
pred_track = np.array(pred_track)
|
| 121 |
+
pred_track = absmax(pred_track, axis=0)
|
| 122 |
+
pred_track = istft(pred_track, 1024, final_length)
|
| 123 |
+
elif algorithm in ["median_fft"]:
|
| 124 |
+
pred_track = np.array(pred_track)
|
| 125 |
+
pred_track = np.median(pred_track, axis=0)
|
| 126 |
+
pred_track = istft(pred_track, 1024, final_length)
|
| 127 |
+
return pred_track
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def ensemble_audio_files(
|
| 131 |
+
files,
|
| 132 |
+
output="res.wav",
|
| 133 |
+
ensemble_type="avg_wave",
|
| 134 |
+
weights=None,
|
| 135 |
+
out_format="wav",
|
| 136 |
+
add_wav=False,
|
| 137 |
+
) -> str | tuple[str, str]:
|
| 138 |
+
print("Алгоритм склеивания: {}".format(ensemble_type))
|
| 139 |
+
print("Количество входных файлов: {}".format(len(files)))
|
| 140 |
+
if weights is not None:
|
| 141 |
+
weights = np.array(weights)
|
| 142 |
+
else:
|
| 143 |
+
weights = np.ones(len(files))
|
| 144 |
+
print("Весы: {}".format(weights))
|
| 145 |
+
print("Имя выходного файла: {}".format(output))
|
| 146 |
+
|
| 147 |
+
data = []
|
| 148 |
+
sr = None
|
| 149 |
+
max_length = 0
|
| 150 |
+
max_channels = 0
|
| 151 |
+
|
| 152 |
+
for f in files:
|
| 153 |
+
if not os.path.isfile(f):
|
| 154 |
+
print("Не удается найти файл: {}. Check paths.".format(f))
|
| 155 |
+
exit()
|
| 156 |
+
print("Читается файл: {}".format(f))
|
| 157 |
+
wav, current_sr, _ = audio.read(i=f, sr=None, mono=False)
|
| 158 |
+
if sr is None:
|
| 159 |
+
sr = current_sr
|
| 160 |
+
elif sr != current_sr:
|
| 161 |
+
print("Частота дискретизации на всех файлах должна быть одинаковой")
|
| 162 |
+
exit()
|
| 163 |
+
|
| 164 |
+
if wav.ndim == 1:
|
| 165 |
+
channels = 1
|
| 166 |
+
length = len(wav)
|
| 167 |
+
else:
|
| 168 |
+
channels = wav.shape[0]
|
| 169 |
+
length = wav.shape[1]
|
| 170 |
+
|
| 171 |
+
max_length = max(max_length, length)
|
| 172 |
+
max_channels = max(max_channels, channels)
|
| 173 |
+
print("Форма сигнала: {} частота дискретизации: {}".format(wav.shape, sr))
|
| 174 |
+
|
| 175 |
+
for f in files:
|
| 176 |
+
wav, current_sr, _ = audio.read(i=f, sr=None, mono=False)
|
| 177 |
+
|
| 178 |
+
if wav.ndim == 1:
|
| 179 |
+
wav = np.vstack([wav, wav])
|
| 180 |
+
elif wav.shape[0] == 1:
|
| 181 |
+
wav = np.vstack([wav[0], wav[0]])
|
| 182 |
+
elif wav.shape[0] > 2:
|
| 183 |
+
wav = wav[:2, :]
|
| 184 |
+
|
| 185 |
+
if wav.shape[1] < max_length:
|
| 186 |
+
pad_width = ((0, 0), (0, max_length - wav.shape[1]))
|
| 187 |
+
wav = np.pad(wav, pad_width, mode="constant")
|
| 188 |
+
elif wav.shape[1] > max_length:
|
| 189 |
+
wav = wav[:, :max_length]
|
| 190 |
+
|
| 191 |
+
data.append(wav)
|
| 192 |
+
|
| 193 |
+
data = np.array(data)
|
| 194 |
+
res = average_waveforms(data, weights, ensemble_type)
|
| 195 |
+
print("Форма результата: {}".format(res.shape))
|
| 196 |
+
|
| 197 |
+
output_wav = f"{output}_orig.wav"
|
| 198 |
+
output = f"{output}.{out_format}"
|
| 199 |
+
|
| 200 |
+
output = audio.write(o=output, array=res.T, sr=sr, of=out_format, br="320k")
|
| 201 |
+
if add_wav:
|
| 202 |
+
output_wav = audio.write(o=output_wav, array=res.T, sr=sr, of="wav")
|
| 203 |
+
return output, output_wav
|
| 204 |
+
|
| 205 |
+
else:
|
| 206 |
+
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mvsepless/infer.py
CHANGED
|
@@ -20,17 +20,13 @@ from namer import Namer
|
|
| 20 |
namer = Namer()
|
| 21 |
audio = Audio()
|
| 22 |
|
| 23 |
-
from infer_utils import
|
| 24 |
-
prefer_target_instrument,
|
| 25 |
-
demix,
|
| 26 |
-
get_model_from_config
|
| 27 |
-
)
|
| 28 |
|
| 29 |
|
| 30 |
def normalize_peak(audio, peak):
|
| 31 |
current_peak = np.max(np.abs(audio))
|
| 32 |
if current_peak == 0:
|
| 33 |
-
return audio
|
| 34 |
scale_factor = peak / current_peak
|
| 35 |
return audio * scale_factor
|
| 36 |
|
|
@@ -57,10 +53,18 @@ def cleanup_model(model):
|
|
| 57 |
torch.cuda.ipc_collect()
|
| 58 |
|
| 59 |
gc.collect()
|
| 60 |
-
sys.stdout.write(
|
|
|
|
|
|
|
|
|
|
| 61 |
sys.stdout.flush()
|
| 62 |
except Exception as e:
|
| 63 |
-
sys.stdout.write(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
sys.stdout.flush()
|
| 65 |
|
| 66 |
|
|
@@ -87,22 +91,32 @@ def once_inference(
|
|
| 87 |
model_id: int = 0,
|
| 88 |
):
|
| 89 |
results = []
|
| 90 |
-
sys.stdout.write(json.dumps({"reading": path}, ensure_ascii=False) +
|
| 91 |
sys.stdout.flush()
|
| 92 |
-
sys.stdout.write(
|
|
|
|
|
|
|
| 93 |
sys.stdout.flush()
|
| 94 |
-
sys.stdout.write(
|
|
|
|
|
|
|
| 95 |
sys.stdout.flush()
|
| 96 |
|
| 97 |
if config.training.target_instrument is not None:
|
| 98 |
-
sys.stdout.write(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
sys.stdout.flush()
|
| 100 |
|
| 101 |
try:
|
| 102 |
mix, sr, _ = audio.read(i=path, sr=sample_rate, mono=False)
|
| 103 |
except Exception as e:
|
| 104 |
error_msg = f"Не удалось прочитать аудио: {path}\nОшибка: {e}"
|
| 105 |
-
sys.stdout.write(json.dumps({"error": error_msg}, ensure_ascii=False) +
|
| 106 |
sys.stdout.flush()
|
| 107 |
return results
|
| 108 |
|
|
@@ -128,13 +142,19 @@ def once_inference(
|
|
| 128 |
|
| 129 |
full_result.append(waveforms)
|
| 130 |
except Exception as e:
|
| 131 |
-
sys.stdout.write(
|
|
|
|
|
|
|
|
|
|
| 132 |
sys.stdout.flush()
|
| 133 |
del m
|
| 134 |
gc.collect()
|
| 135 |
|
| 136 |
if not full_result:
|
| 137 |
-
sys.stdout.write(
|
|
|
|
|
|
|
|
|
|
| 138 |
sys.stdout.flush()
|
| 139 |
return results
|
| 140 |
|
|
@@ -151,9 +171,7 @@ def once_inference(
|
|
| 151 |
for el in waveforms:
|
| 152 |
waveforms[el] /= len(full_result)
|
| 153 |
|
| 154 |
-
if
|
| 155 |
-
extract_instrumental and config.training.target_instrument is not None
|
| 156 |
-
): # Если включен "Extract Instrumental / Извлечь инструментал" и найден целевой инструмент
|
| 157 |
second_stem = [
|
| 158 |
s
|
| 159 |
for s in config.training.instruments
|
|
@@ -169,8 +187,7 @@ def once_inference(
|
|
| 169 |
extract_instrumental
|
| 170 |
and selected_instruments
|
| 171 |
and config.training.target_instrument is None
|
| 172 |
-
):
|
| 173 |
-
|
| 174 |
|
| 175 |
all_instruments = config.training.instruments
|
| 176 |
if len(all_instruments) > 2:
|
|
@@ -178,21 +195,19 @@ def once_inference(
|
|
| 178 |
waveforms["inverted -"] = mix_orig.copy()
|
| 179 |
for instr in instruments:
|
| 180 |
if instr in waveforms:
|
| 181 |
-
waveforms["inverted -"] -= waveforms[
|
| 182 |
-
instr
|
| 183 |
-
] # стем "inverted -": вычитание выбранного стема из оригинального сигнала (не всегда хорошо)
|
| 184 |
|
| 185 |
if "inverted -" not in instruments:
|
| 186 |
instruments.append("inverted -")
|
| 187 |
|
| 188 |
-
unselected_stems = [
|
|
|
|
|
|
|
| 189 |
if unselected_stems:
|
| 190 |
waveforms["inverted +"] = np.zeros_like(mix_orig)
|
| 191 |
for stem in unselected_stems:
|
| 192 |
if stem in waveforms:
|
| 193 |
-
waveforms["inverted +"] += waveforms[
|
| 194 |
-
stem
|
| 195 |
-
] # стем "inverted +": сложение не выбранных инструментов в один стем
|
| 196 |
if "inverted +" not in instruments:
|
| 197 |
instruments.append("inverted +")
|
| 198 |
|
|
@@ -216,9 +231,7 @@ def once_inference(
|
|
| 216 |
):
|
| 217 |
|
| 218 |
waveforms["instrumental -"] = mix_orig.copy()
|
| 219 |
-
waveforms["instrumental -"] -= waveforms[
|
| 220 |
-
"vocals"
|
| 221 |
-
] # стем "inverted -": вычитание выбранного стема из оригинального сигнала (не всегда хорошо)
|
| 222 |
|
| 223 |
if "instrumental -" not in instruments:
|
| 224 |
instruments.append("instrumental -")
|
|
@@ -229,9 +242,7 @@ def once_inference(
|
|
| 229 |
waveforms["instrumental +"] = np.zeros_like(mix_orig)
|
| 230 |
for stem in non_vocal_stems:
|
| 231 |
if stem in waveforms:
|
| 232 |
-
waveforms["instrumental +"] += waveforms[
|
| 233 |
-
stem
|
| 234 |
-
] # стем "inverted +": сложение не выбранных инструментов в один стем
|
| 235 |
if "instrumental +" not in instruments:
|
| 236 |
instruments.append("instrumental +")
|
| 237 |
|
|
@@ -249,25 +260,40 @@ def once_inference(
|
|
| 249 |
estimates = estimates * std + mean
|
| 250 |
|
| 251 |
file_name = os.path.splitext(os.path.basename(path))[0]
|
| 252 |
-
file_name_shorted = namer.short_input_name_template(
|
|
|
|
|
|
|
| 253 |
custom_name = namer.template(
|
| 254 |
-
template,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
)
|
| 256 |
output_path = os.path.join(store_dir, f"{custom_name}.{output_format}")
|
| 257 |
|
| 258 |
-
sys.stdout.write(
|
|
|
|
|
|
|
| 259 |
sys.stdout.flush()
|
| 260 |
|
| 261 |
output_path = audio.write(
|
| 262 |
-
o=output_path,
|
| 263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
-
results.append(
|
| 266 |
-
(instr, output_path)
|
| 267 |
-
) # запись информации о разделении: (название стема, путь к файлу)
|
| 268 |
del estimates
|
| 269 |
except Exception as e:
|
| 270 |
-
sys.stdout.write(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
sys.stdout.flush()
|
| 272 |
gc.collect()
|
| 273 |
|
|
@@ -307,7 +333,15 @@ def run_inference(
|
|
| 307 |
instruments = prefer_target_instrument(config)
|
| 308 |
|
| 309 |
if config.training.target_instrument is not None:
|
| 310 |
-
sys.stdout.write(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
sys.stdout.flush()
|
| 312 |
else:
|
| 313 |
if selected_instruments is not None and selected_instruments != []:
|
|
@@ -315,7 +349,10 @@ def run_inference(
|
|
| 315 |
instr for instr in instruments if instr in selected_instruments
|
| 316 |
]
|
| 317 |
if verbose:
|
| 318 |
-
sys.stdout.write(
|
|
|
|
|
|
|
|
|
|
| 319 |
sys.stdout.flush()
|
| 320 |
|
| 321 |
os.makedirs(store_dir, exist_ok=True)
|
|
@@ -345,9 +382,11 @@ def run_inference(
|
|
| 345 |
|
| 346 |
time.sleep(1)
|
| 347 |
time_taken = time.time() - start_time
|
| 348 |
-
sys.stdout.write(
|
|
|
|
|
|
|
| 349 |
sys.stdout.flush()
|
| 350 |
-
sys.stdout.write(json.dumps({"done": results}, ensure_ascii=False) +
|
| 351 |
sys.stdout.flush()
|
| 352 |
return results
|
| 353 |
|
|
@@ -357,7 +396,15 @@ def load_model(model_type, config_path, start_check_point, device_ids, force_cpu
|
|
| 357 |
if force_cpu:
|
| 358 |
device = "cpu"
|
| 359 |
elif torch.cuda.is_available():
|
| 360 |
-
sys.stdout.write(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
sys.stdout.flush()
|
| 362 |
device = "cuda"
|
| 363 |
|
|
@@ -372,7 +419,7 @@ def load_model(model_type, config_path, start_check_point, device_ids, force_cpu
|
|
| 372 |
elif torch.backends.mps.is_available():
|
| 373 |
device = "mps"
|
| 374 |
|
| 375 |
-
sys.stdout.write(json.dumps({"device": device}, ensure_ascii=False) +
|
| 376 |
sys.stdout.flush()
|
| 377 |
|
| 378 |
model_load_start_time = time.time()
|
|
@@ -382,29 +429,30 @@ def load_model(model_type, config_path, start_check_point, device_ids, force_cpu
|
|
| 382 |
|
| 383 |
if model_type == "vr":
|
| 384 |
model.load_checkpoint(start_check_point, device)
|
| 385 |
-
model.settings(
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
|
|
|
| 390 |
primary_stem=config.training.instruments[0],
|
| 391 |
-
secondary_stem=config.training.instruments[1]
|
| 392 |
)
|
| 393 |
return model, config, device
|
| 394 |
|
| 395 |
elif model_type == "mdxnet":
|
| 396 |
if start_check_point != "":
|
| 397 |
-
sys.stdout.write(json.dumps({"checkpoint": start_check_point}) +
|
| 398 |
sys.stdout.flush()
|
| 399 |
model.init_onnx_session(start_check_point, device)
|
| 400 |
|
| 401 |
return model, config, device
|
| 402 |
-
|
| 403 |
else:
|
| 404 |
if start_check_point != "":
|
| 405 |
-
sys.stdout.write(json.dumps({"checkpoint": start_check_point}) +
|
| 406 |
sys.stdout.flush()
|
| 407 |
-
|
| 408 |
if model_type in ["htdemucs", "apollo"]:
|
| 409 |
state_dict = torch.load(
|
| 410 |
start_check_point, map_location=device, weights_only=False
|
|
@@ -428,7 +476,10 @@ def load_model(model_type, config_path, start_check_point, device_ids, force_cpu
|
|
| 428 |
except RuntimeError:
|
| 429 |
model.load_state_dict(state_dict, strict=False)
|
| 430 |
|
| 431 |
-
sys.stdout.write(
|
|
|
|
|
|
|
|
|
|
| 432 |
sys.stdout.flush()
|
| 433 |
|
| 434 |
if (
|
|
@@ -443,7 +494,10 @@ def load_model(model_type, config_path, start_check_point, device_ids, force_cpu
|
|
| 443 |
|
| 444 |
load_time = time.time() - model_load_start_time
|
| 445 |
|
| 446 |
-
sys.stdout.write(
|
|
|
|
|
|
|
|
|
|
| 447 |
sys.stdout.flush()
|
| 448 |
|
| 449 |
return model, config, device
|
|
@@ -503,13 +557,11 @@ def parse_args():
|
|
| 503 |
description="Модифицированный Music-Source-Separation-Training для разделения аудио на источники"
|
| 504 |
)
|
| 505 |
|
| 506 |
-
# Обязательные аргументы
|
| 507 |
parser.add_argument("--input", type=str, help="Путь к входному файлу или папке")
|
| 508 |
parser.add_argument(
|
| 509 |
"--store_dir", type=str, required=True, help="Путь для сохранения результатов"
|
| 510 |
)
|
| 511 |
|
| 512 |
-
# Основные параметры модели
|
| 513 |
parser.add_argument(
|
| 514 |
"--model_type",
|
| 515 |
type=str,
|
|
@@ -523,7 +575,7 @@ def parse_args():
|
|
| 523 |
"bandit",
|
| 524 |
"bandit_v2",
|
| 525 |
"mdxnet",
|
| 526 |
-
"vr"
|
| 527 |
],
|
| 528 |
help="Тип модели (по умолчанию: htdemucs)",
|
| 529 |
)
|
|
@@ -537,7 +589,6 @@ def parse_args():
|
|
| 537 |
"--start_check_point", type=str, required=True, help="Путь к чекпоинту модели"
|
| 538 |
)
|
| 539 |
|
| 540 |
-
# Параметры вывода
|
| 541 |
parser.add_argument(
|
| 542 |
"--output_format",
|
| 543 |
type=str,
|
|
@@ -620,4 +671,4 @@ def main():
|
|
| 620 |
|
| 621 |
|
| 622 |
if __name__ == "__main__":
|
| 623 |
-
main()
|
|
|
|
| 20 |
namer = Namer()
|
| 21 |
audio = Audio()
|
| 22 |
|
| 23 |
+
from infer_utils import prefer_target_instrument, demix, get_model_from_config
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
def normalize_peak(audio, peak):
|
| 27 |
current_peak = np.max(np.abs(audio))
|
| 28 |
if current_peak == 0:
|
| 29 |
+
return audio
|
| 30 |
scale_factor = peak / current_peak
|
| 31 |
return audio * scale_factor
|
| 32 |
|
|
|
|
| 53 |
torch.cuda.ipc_collect()
|
| 54 |
|
| 55 |
gc.collect()
|
| 56 |
+
sys.stdout.write(
|
| 57 |
+
json.dumps({"cleanup": "Модель выгружена из памяти"}, ensure_ascii=False)
|
| 58 |
+
+ "\n"
|
| 59 |
+
)
|
| 60 |
sys.stdout.flush()
|
| 61 |
except Exception as e:
|
| 62 |
+
sys.stdout.write(
|
| 63 |
+
json.dumps(
|
| 64 |
+
{"error": f"Ошибка при выгрузке модели: {str(e)}"}, ensure_ascii=False
|
| 65 |
+
)
|
| 66 |
+
+ "\n"
|
| 67 |
+
)
|
| 68 |
sys.stdout.flush()
|
| 69 |
|
| 70 |
|
|
|
|
| 91 |
model_id: int = 0,
|
| 92 |
):
|
| 93 |
results = []
|
| 94 |
+
sys.stdout.write(json.dumps({"reading": path}, ensure_ascii=False) + "\n")
|
| 95 |
sys.stdout.flush()
|
| 96 |
+
sys.stdout.write(
|
| 97 |
+
json.dumps({"selected_stems": selected_instruments}, ensure_ascii=False) + "\n"
|
| 98 |
+
)
|
| 99 |
sys.stdout.flush()
|
| 100 |
+
sys.stdout.write(
|
| 101 |
+
json.dumps({"stems": list(instruments)}, ensure_ascii=False) + "\n"
|
| 102 |
+
)
|
| 103 |
sys.stdout.flush()
|
| 104 |
|
| 105 |
if config.training.target_instrument is not None:
|
| 106 |
+
sys.stdout.write(
|
| 107 |
+
json.dumps(
|
| 108 |
+
{"target_instrument": config.training.target_instrument},
|
| 109 |
+
ensure_ascii=False,
|
| 110 |
+
)
|
| 111 |
+
+ "\n"
|
| 112 |
+
)
|
| 113 |
sys.stdout.flush()
|
| 114 |
|
| 115 |
try:
|
| 116 |
mix, sr, _ = audio.read(i=path, sr=sample_rate, mono=False)
|
| 117 |
except Exception as e:
|
| 118 |
error_msg = f"Не удалось прочитать аудио: {path}\nОшибка: {e}"
|
| 119 |
+
sys.stdout.write(json.dumps({"error": error_msg}, ensure_ascii=False) + "\n")
|
| 120 |
sys.stdout.flush()
|
| 121 |
return results
|
| 122 |
|
|
|
|
| 142 |
|
| 143 |
full_result.append(waveforms)
|
| 144 |
except Exception as e:
|
| 145 |
+
sys.stdout.write(
|
| 146 |
+
json.dumps({"error": f"Ошибка при демиксе: {e}"}, ensure_ascii=False)
|
| 147 |
+
+ "\n"
|
| 148 |
+
)
|
| 149 |
sys.stdout.flush()
|
| 150 |
del m
|
| 151 |
gc.collect()
|
| 152 |
|
| 153 |
if not full_result:
|
| 154 |
+
sys.stdout.write(
|
| 155 |
+
json.dumps({"error": "Пустой результат демикса."}, ensure_ascii=False)
|
| 156 |
+
+ "\n"
|
| 157 |
+
)
|
| 158 |
sys.stdout.flush()
|
| 159 |
return results
|
| 160 |
|
|
|
|
| 171 |
for el in waveforms:
|
| 172 |
waveforms[el] /= len(full_result)
|
| 173 |
|
| 174 |
+
if extract_instrumental and config.training.target_instrument is not None:
|
|
|
|
|
|
|
| 175 |
second_stem = [
|
| 176 |
s
|
| 177 |
for s in config.training.instruments
|
|
|
|
| 187 |
extract_instrumental
|
| 188 |
and selected_instruments
|
| 189 |
and config.training.target_instrument is None
|
| 190 |
+
):
|
|
|
|
| 191 |
|
| 192 |
all_instruments = config.training.instruments
|
| 193 |
if len(all_instruments) > 2:
|
|
|
|
| 195 |
waveforms["inverted -"] = mix_orig.copy()
|
| 196 |
for instr in instruments:
|
| 197 |
if instr in waveforms:
|
| 198 |
+
waveforms["inverted -"] -= waveforms[instr]
|
|
|
|
|
|
|
| 199 |
|
| 200 |
if "inverted -" not in instruments:
|
| 201 |
instruments.append("inverted -")
|
| 202 |
|
| 203 |
+
unselected_stems = [
|
| 204 |
+
s for s in all_instruments if s not in selected_instruments
|
| 205 |
+
]
|
| 206 |
if unselected_stems:
|
| 207 |
waveforms["inverted +"] = np.zeros_like(mix_orig)
|
| 208 |
for stem in unselected_stems:
|
| 209 |
if stem in waveforms:
|
| 210 |
+
waveforms["inverted +"] += waveforms[stem]
|
|
|
|
|
|
|
| 211 |
if "inverted +" not in instruments:
|
| 212 |
instruments.append("inverted +")
|
| 213 |
|
|
|
|
| 231 |
):
|
| 232 |
|
| 233 |
waveforms["instrumental -"] = mix_orig.copy()
|
| 234 |
+
waveforms["instrumental -"] -= waveforms["vocals"]
|
|
|
|
|
|
|
| 235 |
|
| 236 |
if "instrumental -" not in instruments:
|
| 237 |
instruments.append("instrumental -")
|
|
|
|
| 242 |
waveforms["instrumental +"] = np.zeros_like(mix_orig)
|
| 243 |
for stem in non_vocal_stems:
|
| 244 |
if stem in waveforms:
|
| 245 |
+
waveforms["instrumental +"] += waveforms[stem]
|
|
|
|
|
|
|
| 246 |
if "instrumental +" not in instruments:
|
| 247 |
instruments.append("instrumental +")
|
| 248 |
|
|
|
|
| 260 |
estimates = estimates * std + mean
|
| 261 |
|
| 262 |
file_name = os.path.splitext(os.path.basename(path))[0]
|
| 263 |
+
file_name_shorted = namer.short_input_name_template(
|
| 264 |
+
template, STEM=instr, MODEL=model_name, ID=model_id, NAME=file_name
|
| 265 |
+
)
|
| 266 |
custom_name = namer.template(
|
| 267 |
+
template,
|
| 268 |
+
STEM=instr,
|
| 269 |
+
MODEL=model_name,
|
| 270 |
+
ID=model_id,
|
| 271 |
+
NAME=file_name_shorted,
|
| 272 |
)
|
| 273 |
output_path = os.path.join(store_dir, f"{custom_name}.{output_format}")
|
| 274 |
|
| 275 |
+
sys.stdout.write(
|
| 276 |
+
json.dumps({"writing": output_path}, ensure_ascii=False) + "\n"
|
| 277 |
+
)
|
| 278 |
sys.stdout.flush()
|
| 279 |
|
| 280 |
output_path = audio.write(
|
| 281 |
+
o=output_path,
|
| 282 |
+
array=estimates,
|
| 283 |
+
sr=sr,
|
| 284 |
+
of=output_format,
|
| 285 |
+
br=output_bitrate,
|
| 286 |
+
)
|
| 287 |
|
| 288 |
+
results.append((instr, output_path))
|
|
|
|
|
|
|
| 289 |
del estimates
|
| 290 |
except Exception as e:
|
| 291 |
+
sys.stdout.write(
|
| 292 |
+
json.dumps(
|
| 293 |
+
{"error": f"Ошибка при обработке {instr}: {e}"}, ensure_ascii=False
|
| 294 |
+
)
|
| 295 |
+
+ "\n"
|
| 296 |
+
)
|
| 297 |
sys.stdout.flush()
|
| 298 |
gc.collect()
|
| 299 |
|
|
|
|
| 333 |
instruments = prefer_target_instrument(config)
|
| 334 |
|
| 335 |
if config.training.target_instrument is not None:
|
| 336 |
+
sys.stdout.write(
|
| 337 |
+
json.dumps(
|
| 338 |
+
{
|
| 339 |
+
"info": "Целевой инструмент найден в конфигурации модели. Выбранные стемы будут проигнорированы."
|
| 340 |
+
},
|
| 341 |
+
ensure_ascii=False,
|
| 342 |
+
)
|
| 343 |
+
+ "\n"
|
| 344 |
+
)
|
| 345 |
sys.stdout.flush()
|
| 346 |
else:
|
| 347 |
if selected_instruments is not None and selected_instruments != []:
|
|
|
|
| 349 |
instr for instr in instruments if instr in selected_instruments
|
| 350 |
]
|
| 351 |
if verbose:
|
| 352 |
+
sys.stdout.write(
|
| 353 |
+
json.dumps({"selected_stems": instruments}, ensure_ascii=False)
|
| 354 |
+
+ "\n"
|
| 355 |
+
)
|
| 356 |
sys.stdout.flush()
|
| 357 |
|
| 358 |
os.makedirs(store_dir, exist_ok=True)
|
|
|
|
| 382 |
|
| 383 |
time.sleep(1)
|
| 384 |
time_taken = time.time() - start_time
|
| 385 |
+
sys.stdout.write(
|
| 386 |
+
json.dumps({"time": f"{time_taken:.2f} сек."}, ensure_ascii=False) + "\n"
|
| 387 |
+
)
|
| 388 |
sys.stdout.flush()
|
| 389 |
+
sys.stdout.write(json.dumps({"done": results}, ensure_ascii=False) + "\n")
|
| 390 |
sys.stdout.flush()
|
| 391 |
return results
|
| 392 |
|
|
|
|
| 396 |
if force_cpu:
|
| 397 |
device = "cpu"
|
| 398 |
elif torch.cuda.is_available():
|
| 399 |
+
sys.stdout.write(
|
| 400 |
+
json.dumps(
|
| 401 |
+
{
|
| 402 |
+
"info": "Разделение выполняется на ядрах CUDA. Для выполнения на процессоре установите force_cpu=True."
|
| 403 |
+
},
|
| 404 |
+
ensure_ascii=False,
|
| 405 |
+
)
|
| 406 |
+
+ "\n"
|
| 407 |
+
)
|
| 408 |
sys.stdout.flush()
|
| 409 |
device = "cuda"
|
| 410 |
|
|
|
|
| 419 |
elif torch.backends.mps.is_available():
|
| 420 |
device = "mps"
|
| 421 |
|
| 422 |
+
sys.stdout.write(json.dumps({"device": device}, ensure_ascii=False) + "\n")
|
| 423 |
sys.stdout.flush()
|
| 424 |
|
| 425 |
model_load_start_time = time.time()
|
|
|
|
| 429 |
|
| 430 |
if model_type == "vr":
|
| 431 |
model.load_checkpoint(start_check_point, device)
|
| 432 |
+
model.settings(
|
| 433 |
+
enable_post_process=False,
|
| 434 |
+
post_process_threshold=config.inference.post_process_threshold,
|
| 435 |
+
batch_size=config.inference.batch_size,
|
| 436 |
+
window_size=config.inference.window_size,
|
| 437 |
+
high_end_process=config.inference.high_end_process,
|
| 438 |
primary_stem=config.training.instruments[0],
|
| 439 |
+
secondary_stem=config.training.instruments[1],
|
| 440 |
)
|
| 441 |
return model, config, device
|
| 442 |
|
| 443 |
elif model_type == "mdxnet":
|
| 444 |
if start_check_point != "":
|
| 445 |
+
sys.stdout.write(json.dumps({"checkpoint": start_check_point}) + "\n")
|
| 446 |
sys.stdout.flush()
|
| 447 |
model.init_onnx_session(start_check_point, device)
|
| 448 |
|
| 449 |
return model, config, device
|
| 450 |
+
|
| 451 |
else:
|
| 452 |
if start_check_point != "":
|
| 453 |
+
sys.stdout.write(json.dumps({"checkpoint": start_check_point}) + "\n")
|
| 454 |
sys.stdout.flush()
|
| 455 |
+
|
| 456 |
if model_type in ["htdemucs", "apollo"]:
|
| 457 |
state_dict = torch.load(
|
| 458 |
start_check_point, map_location=device, weights_only=False
|
|
|
|
| 476 |
except RuntimeError:
|
| 477 |
model.load_state_dict(state_dict, strict=False)
|
| 478 |
|
| 479 |
+
sys.stdout.write(
|
| 480 |
+
json.dumps({"stems": list(config.training.instruments)}, ensure_ascii=False)
|
| 481 |
+
+ "\n"
|
| 482 |
+
)
|
| 483 |
sys.stdout.flush()
|
| 484 |
|
| 485 |
if (
|
|
|
|
| 494 |
|
| 495 |
load_time = time.time() - model_load_start_time
|
| 496 |
|
| 497 |
+
sys.stdout.write(
|
| 498 |
+
json.dumps({"model_load_time": f"{load_time:.2f} сек."}, ensure_ascii=False)
|
| 499 |
+
+ "\n"
|
| 500 |
+
)
|
| 501 |
sys.stdout.flush()
|
| 502 |
|
| 503 |
return model, config, device
|
|
|
|
| 557 |
description="Модифицированный Music-Source-Separation-Training для разделения аудио на источники"
|
| 558 |
)
|
| 559 |
|
|
|
|
| 560 |
parser.add_argument("--input", type=str, help="Путь к входному файлу или папке")
|
| 561 |
parser.add_argument(
|
| 562 |
"--store_dir", type=str, required=True, help="Путь для сохранения результатов"
|
| 563 |
)
|
| 564 |
|
|
|
|
| 565 |
parser.add_argument(
|
| 566 |
"--model_type",
|
| 567 |
type=str,
|
|
|
|
| 575 |
"bandit",
|
| 576 |
"bandit_v2",
|
| 577 |
"mdxnet",
|
| 578 |
+
"vr",
|
| 579 |
],
|
| 580 |
help="Тип модели (по умолчанию: htdemucs)",
|
| 581 |
)
|
|
|
|
| 589 |
"--start_check_point", type=str, required=True, help="Путь к чекпоинту модели"
|
| 590 |
)
|
| 591 |
|
|
|
|
| 592 |
parser.add_argument(
|
| 593 |
"--output_format",
|
| 594 |
type=str,
|
|
|
|
| 671 |
|
| 672 |
|
| 673 |
if __name__ == "__main__":
|
| 674 |
+
main()
|
mvsepless/infer_utils.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
# coding: utf-8
|
| 2 |
__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
|
| 3 |
|
| 4 |
import sys
|
|
@@ -15,9 +14,6 @@ from typing import Dict, List, Tuple, Any, List, Optional
|
|
| 15 |
|
| 16 |
|
| 17 |
def load_config(model_type: str, config_path: str) -> Any:
|
| 18 |
-
"""
|
| 19 |
-
Load the configuration from the specified path based on the model type.
|
| 20 |
-
"""
|
| 21 |
try:
|
| 22 |
with open(config_path, "r") as f:
|
| 23 |
if model_type == "htdemucs":
|
|
@@ -32,9 +28,6 @@ def load_config(model_type: str, config_path: str) -> Any:
|
|
| 32 |
|
| 33 |
|
| 34 |
def get_model_from_config(model_type: str, config_path: str) -> Tuple:
|
| 35 |
-
"""
|
| 36 |
-
Load the model specified by the model type and configuration file.
|
| 37 |
-
"""
|
| 38 |
config = load_config(model_type, config_path)
|
| 39 |
|
| 40 |
if model_type == "mdx23c":
|
|
@@ -47,11 +40,9 @@ def get_model_from_config(model_type: str, config_path: str) -> Tuple:
|
|
| 47 |
|
| 48 |
model = MDXNet(**dict(config.model))
|
| 49 |
|
| 50 |
-
# В функции get_model_from_config добавьте:
|
| 51 |
-
|
| 52 |
elif model_type == "vr":
|
| 53 |
from models.vr_arch import VRNet
|
| 54 |
-
|
| 55 |
model = VRNet(**dict(config.model))
|
| 56 |
|
| 57 |
elif model_type == "htdemucs":
|
|
@@ -60,7 +51,7 @@ def get_model_from_config(model_type: str, config_path: str) -> Tuple:
|
|
| 60 |
model = get_model(config)
|
| 61 |
|
| 62 |
elif model_type == "mel_band_roformer":
|
| 63 |
-
if hasattr(config, "windowed"):
|
| 64 |
from models.windowed_roformer.model import MelBandRoformerWSA
|
| 65 |
|
| 66 |
model = MelBandRoformerWSA(**dict(config.model))
|
|
@@ -114,10 +105,8 @@ def get_model_from_config(model_type: str, config_path: str) -> Tuple:
|
|
| 114 |
|
| 115 |
return model, config
|
| 116 |
|
|
|
|
| 117 |
def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor:
|
| 118 |
-
"""
|
| 119 |
-
Generate a windowing array with a linear fade-in at the beginning and a fade-out at the end.
|
| 120 |
-
"""
|
| 121 |
fadein = torch.linspace(0, 1, fade_size)
|
| 122 |
fadeout = torch.linspace(1, 0, fade_size)
|
| 123 |
|
|
@@ -126,6 +115,7 @@ def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor:
|
|
| 126 |
window[:fade_size] = fadein
|
| 127 |
return window
|
| 128 |
|
|
|
|
| 129 |
def demix_mdxnet(
|
| 130 |
config: Any,
|
| 131 |
model: Any,
|
|
@@ -133,18 +123,17 @@ def demix_mdxnet(
|
|
| 133 |
device: torch.device,
|
| 134 |
pbar: bool = False,
|
| 135 |
) -> Dict[str, np.ndarray]:
|
| 136 |
-
"""
|
| 137 |
-
MDX-Net specific demixing function с поддержкой overlap
|
| 138 |
-
"""
|
| 139 |
mix_tensor = torch.tensor(mix, dtype=torch.float32)
|
| 140 |
inv_mix_tensor = torch.tensor(-mix, dtype=torch.float32)
|
| 141 |
-
|
| 142 |
num_overlap = config.inference.num_overlap
|
| 143 |
denoise = config.inference.denoise
|
| 144 |
stem_name = model.primary_stem
|
| 145 |
if denoise:
|
| 146 |
processed_wav = model.process_wave(mix_tensor, device, num_overlap, pbar=pbar)
|
| 147 |
-
inv_processed_wav = model.process_wave(
|
|
|
|
|
|
|
| 148 |
result = processed_wav.cpu().numpy()
|
| 149 |
inv_result = inv_processed_wav.cpu().numpy()
|
| 150 |
result_separation = (result + -inv_result) * 0.5
|
|
@@ -152,9 +141,12 @@ def demix_mdxnet(
|
|
| 152 |
processed_wav = model.process_wave(mix_tensor, device, num_overlap, pbar=pbar)
|
| 153 |
result_separation = processed_wav.cpu().numpy()
|
| 154 |
|
| 155 |
-
result_separation = np.nan_to_num(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
-
return {stem_name: result_separation} # Перемещаем на CPU для возврата
|
| 158 |
|
| 159 |
def demix_vr(
|
| 160 |
config: Any,
|
|
@@ -163,12 +155,10 @@ def demix_vr(
|
|
| 163 |
device: torch.device,
|
| 164 |
pbar: bool = False,
|
| 165 |
) -> Dict[str, np.ndarray]:
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
# Convert to tensor and add batch dimension
|
| 171 |
-
return model.demix(mix, config.audio.sample_rate, device, config.inference.aggression)
|
| 172 |
|
| 173 |
def demix_demucs(config, model, mix, device, pbar=False):
|
| 174 |
mix = torch.tensor(mix, dtype=torch.float32)
|
|
@@ -176,8 +166,8 @@ def demix_demucs(config, model, mix, device, pbar=False):
|
|
| 176 |
num_instruments = len(config.training.instruments)
|
| 177 |
num_overlap = config.inference.num_overlap
|
| 178 |
step = chunk_size // num_overlap
|
| 179 |
-
fade_size = chunk_size // 10
|
| 180 |
-
windowing_array = _getWindowingArray(chunk_size, fade_size)
|
| 181 |
|
| 182 |
batch_size = config.inference.batch_size
|
| 183 |
use_amp = getattr(config.training, "use_amp", True)
|
|
@@ -209,9 +199,9 @@ def demix_demucs(config, model, mix, device, pbar=False):
|
|
| 209 |
x = model(arr)
|
| 210 |
|
| 211 |
window = windowing_array.clone()
|
| 212 |
-
if i - step == 0:
|
| 213 |
window[:fade_size] = 1
|
| 214 |
-
elif i >= mix.shape[1]:
|
| 215 |
window[-fade_size:] = 1
|
| 216 |
|
| 217 |
for j, (start, seg_len) in enumerate(batch_locations):
|
|
@@ -220,10 +210,14 @@ def demix_demucs(config, model, mix, device, pbar=False):
|
|
| 220 |
)
|
| 221 |
counter[..., start : start + seg_len] += window[..., :seg_len]
|
| 222 |
|
| 223 |
-
# Output progress
|
| 224 |
processed = min(i, mix.shape[1])
|
| 225 |
total = mix.shape[1]
|
| 226 |
-
sys.stdout.write(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
sys.stdout.flush()
|
| 228 |
|
| 229 |
batch_data.clear()
|
|
@@ -239,6 +233,7 @@ def demix_demucs(config, model, mix, device, pbar=False):
|
|
| 239 |
instruments = config.training.instruments
|
| 240 |
return {k: v for k, v in zip(instruments, estimated_sources)}
|
| 241 |
|
|
|
|
| 242 |
def demix_generic(
|
| 243 |
config: ConfigDict,
|
| 244 |
model: torch.nn.Module,
|
|
@@ -246,9 +241,6 @@ def demix_generic(
|
|
| 246 |
device: torch.device,
|
| 247 |
pbar: bool = False,
|
| 248 |
) -> Dict[str, np.ndarray]:
|
| 249 |
-
"""
|
| 250 |
-
Generic demixing function for models that support chunk-based processing
|
| 251 |
-
"""
|
| 252 |
mix = torch.tensor(mix, dtype=torch.float32)
|
| 253 |
chunk_size = config.audio.chunk_size
|
| 254 |
instruments = prefer_target_instrument(config)
|
|
@@ -260,8 +252,7 @@ def demix_generic(
|
|
| 260 |
border = chunk_size - step
|
| 261 |
length_init = mix.shape[-1]
|
| 262 |
windowing_array = _getWindowingArray(chunk_size, fade_size)
|
| 263 |
-
|
| 264 |
-
# Add padding to handle edge artifacts
|
| 265 |
if length_init > 2 * border and border > 0:
|
| 266 |
mix = nn.functional.pad(mix, (border, border), mode="reflect")
|
| 267 |
|
|
@@ -270,7 +261,6 @@ def demix_generic(
|
|
| 270 |
|
| 271 |
with torch.cuda.amp.autocast(enabled=use_amp):
|
| 272 |
with torch.inference_mode():
|
| 273 |
-
# Initialize result and counter tensors
|
| 274 |
req_shape = (num_instruments,) + mix.shape
|
| 275 |
result = torch.zeros(req_shape, dtype=torch.float32)
|
| 276 |
counter = torch.zeros(req_shape, dtype=torch.float32)
|
|
@@ -280,10 +270,9 @@ def demix_generic(
|
|
| 280 |
batch_locations = []
|
| 281 |
|
| 282 |
while i < mix.shape[1]:
|
| 283 |
-
# Extract chunk and apply padding if necessary
|
| 284 |
part = mix[:, i : i + chunk_size].to(device)
|
| 285 |
chunk_len = part.shape[-1]
|
| 286 |
-
|
| 287 |
pad_mode = "reflect" if chunk_len > chunk_size // 2 else "constant"
|
| 288 |
part = nn.functional.pad(
|
| 289 |
part, (0, chunk_size - chunk_len), mode=pad_mode, value=0
|
|
@@ -293,15 +282,14 @@ def demix_generic(
|
|
| 293 |
batch_locations.append((i, chunk_len))
|
| 294 |
i += step
|
| 295 |
|
| 296 |
-
# Process batch if it's full or the end is reached
|
| 297 |
if len(batch_data) >= batch_size or i >= mix.shape[1]:
|
| 298 |
arr = torch.stack(batch_data, dim=0)
|
| 299 |
x = model(arr)
|
| 300 |
|
| 301 |
window = windowing_array.clone()
|
| 302 |
-
if i - step == 0:
|
| 303 |
window[:fade_size] = 1
|
| 304 |
-
elif i >= mix.shape[1]:
|
| 305 |
window[-fade_size:] = 1
|
| 306 |
|
| 307 |
for j, (start, seg_len) in enumerate(batch_locations):
|
|
@@ -310,27 +298,30 @@ def demix_generic(
|
|
| 310 |
)
|
| 311 |
counter[..., start : start + seg_len] += window[..., :seg_len]
|
| 312 |
|
| 313 |
-
# Output progress
|
| 314 |
processed = min(i, mix.shape[1])
|
| 315 |
total = mix.shape[1]
|
| 316 |
-
sys.stdout.write(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
sys.stdout.flush()
|
| 318 |
|
| 319 |
batch_data.clear()
|
| 320 |
batch_locations.clear()
|
| 321 |
|
| 322 |
-
# Compute final estimated sources
|
| 323 |
estimated_sources = result / counter
|
| 324 |
estimated_sources = estimated_sources.cpu().numpy()
|
| 325 |
np.nan_to_num(estimated_sources, copy=False, nan=0.0)
|
| 326 |
|
| 327 |
-
# Remove padding
|
| 328 |
if length_init > 2 * border and border > 0:
|
| 329 |
estimated_sources = estimated_sources[..., border:-border]
|
| 330 |
|
| 331 |
-
# Return the result as a dictionary
|
| 332 |
return {k: v for k, v in zip(instruments, estimated_sources)}
|
| 333 |
|
|
|
|
| 334 |
def demix(
|
| 335 |
config: ConfigDict,
|
| 336 |
model: torch.nn.Module,
|
|
@@ -339,28 +330,17 @@ def demix(
|
|
| 339 |
model_type: str,
|
| 340 |
pbar: bool = False,
|
| 341 |
) -> Dict[str, np.ndarray]:
|
| 342 |
-
"""
|
| 343 |
-
Unified function for audio source separation with support for multiple processing modes.
|
| 344 |
-
"""
|
| 345 |
-
# Handle different model types
|
| 346 |
if model_type == "vr":
|
| 347 |
return demix_vr(config, model, mix, device, pbar)
|
| 348 |
elif model_type == "mdxnet":
|
| 349 |
return demix_mdxnet(config, model, mix, device, pbar)
|
| 350 |
elif model_type == "htdemucs":
|
| 351 |
-
# HTDemucs uses its own processing
|
| 352 |
return demix_demucs(config, model, mix, device, pbar)
|
| 353 |
else:
|
| 354 |
-
# Generic processing for other models
|
| 355 |
return demix_generic(config, model, mix, device, pbar)
|
| 356 |
|
| 357 |
|
| 358 |
def prefer_target_instrument(config: ConfigDict) -> List[str]:
|
| 359 |
-
"""
|
| 360 |
-
Return the list of target instruments based on the configuration.
|
| 361 |
-
If a specific target instrument is specified in the configuration,
|
| 362 |
-
it returns a list with that instrument. Otherwise, it returns the list of instruments.
|
| 363 |
-
"""
|
| 364 |
if config.training.get("target_instrument"):
|
| 365 |
return [config.training.target_instrument]
|
| 366 |
else:
|
|
@@ -370,21 +350,13 @@ def prefer_target_instrument(config: ConfigDict) -> List[str]:
|
|
| 370 |
def prefer_target_instrument_test(
|
| 371 |
config: ConfigDict, selected_instruments: Optional[List[str]] = None
|
| 372 |
) -> List[str]:
|
| 373 |
-
"""
|
| 374 |
-
Return the list of target instruments based on the configuration and selected instruments.
|
| 375 |
-
If selected_instruments is specified, returns the intersection with available instruments.
|
| 376 |
-
Otherwise, if a target instrument is specified, returns it, else returns all instruments.
|
| 377 |
-
"""
|
| 378 |
available_instruments = config.training.instruments
|
| 379 |
|
| 380 |
if selected_instruments is not None:
|
| 381 |
-
# Return only selected instruments that are available
|
| 382 |
return [
|
| 383 |
instr for instr in selected_instruments if instr in available_instruments
|
| 384 |
]
|
| 385 |
elif config.training.get("target_instrument"):
|
| 386 |
-
# Default behavior if no selection - return target instrument
|
| 387 |
return [config.training.target_instrument]
|
| 388 |
else:
|
| 389 |
-
|
| 390 |
-
return available_instruments
|
|
|
|
|
|
|
| 1 |
__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
|
| 2 |
|
| 3 |
import sys
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
def load_config(model_type: str, config_path: str) -> Any:
|
|
|
|
|
|
|
|
|
|
| 17 |
try:
|
| 18 |
with open(config_path, "r") as f:
|
| 19 |
if model_type == "htdemucs":
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
def get_model_from_config(model_type: str, config_path: str) -> Tuple:
|
|
|
|
|
|
|
|
|
|
| 31 |
config = load_config(model_type, config_path)
|
| 32 |
|
| 33 |
if model_type == "mdx23c":
|
|
|
|
| 40 |
|
| 41 |
model = MDXNet(**dict(config.model))
|
| 42 |
|
|
|
|
|
|
|
| 43 |
elif model_type == "vr":
|
| 44 |
from models.vr_arch import VRNet
|
| 45 |
+
|
| 46 |
model = VRNet(**dict(config.model))
|
| 47 |
|
| 48 |
elif model_type == "htdemucs":
|
|
|
|
| 51 |
model = get_model(config)
|
| 52 |
|
| 53 |
elif model_type == "mel_band_roformer":
|
| 54 |
+
if hasattr(config, "windowed"):
|
| 55 |
from models.windowed_roformer.model import MelBandRoformerWSA
|
| 56 |
|
| 57 |
model = MelBandRoformerWSA(**dict(config.model))
|
|
|
|
| 105 |
|
| 106 |
return model, config
|
| 107 |
|
| 108 |
+
|
| 109 |
def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
| 110 |
fadein = torch.linspace(0, 1, fade_size)
|
| 111 |
fadeout = torch.linspace(1, 0, fade_size)
|
| 112 |
|
|
|
|
| 115 |
window[:fade_size] = fadein
|
| 116 |
return window
|
| 117 |
|
| 118 |
+
|
| 119 |
def demix_mdxnet(
|
| 120 |
config: Any,
|
| 121 |
model: Any,
|
|
|
|
| 123 |
device: torch.device,
|
| 124 |
pbar: bool = False,
|
| 125 |
) -> Dict[str, np.ndarray]:
|
|
|
|
|
|
|
|
|
|
| 126 |
mix_tensor = torch.tensor(mix, dtype=torch.float32)
|
| 127 |
inv_mix_tensor = torch.tensor(-mix, dtype=torch.float32)
|
| 128 |
+
|
| 129 |
num_overlap = config.inference.num_overlap
|
| 130 |
denoise = config.inference.denoise
|
| 131 |
stem_name = model.primary_stem
|
| 132 |
if denoise:
|
| 133 |
processed_wav = model.process_wave(mix_tensor, device, num_overlap, pbar=pbar)
|
| 134 |
+
inv_processed_wav = model.process_wave(
|
| 135 |
+
inv_mix_tensor, device, num_overlap, pbar=pbar
|
| 136 |
+
)
|
| 137 |
result = processed_wav.cpu().numpy()
|
| 138 |
inv_result = inv_processed_wav.cpu().numpy()
|
| 139 |
result_separation = (result + -inv_result) * 0.5
|
|
|
|
| 141 |
processed_wav = model.process_wave(mix_tensor, device, num_overlap, pbar=pbar)
|
| 142 |
result_separation = processed_wav.cpu().numpy()
|
| 143 |
|
| 144 |
+
result_separation = np.nan_to_num(
|
| 145 |
+
result_separation, nan=0.0, posinf=0.0, neginf=0.0
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
return {stem_name: result_separation}
|
| 149 |
|
|
|
|
| 150 |
|
| 151 |
def demix_vr(
|
| 152 |
config: Any,
|
|
|
|
| 155 |
device: torch.device,
|
| 156 |
pbar: bool = False,
|
| 157 |
) -> Dict[str, np.ndarray]:
|
| 158 |
+
return model.demix(
|
| 159 |
+
mix, config.audio.sample_rate, device, config.inference.aggression
|
| 160 |
+
)
|
| 161 |
+
|
|
|
|
|
|
|
| 162 |
|
| 163 |
def demix_demucs(config, model, mix, device, pbar=False):
|
| 164 |
mix = torch.tensor(mix, dtype=torch.float32)
|
|
|
|
| 166 |
num_instruments = len(config.training.instruments)
|
| 167 |
num_overlap = config.inference.num_overlap
|
| 168 |
step = chunk_size // num_overlap
|
| 169 |
+
fade_size = chunk_size // 10
|
| 170 |
+
windowing_array = _getWindowingArray(chunk_size, fade_size)
|
| 171 |
|
| 172 |
batch_size = config.inference.batch_size
|
| 173 |
use_amp = getattr(config.training, "use_amp", True)
|
|
|
|
| 199 |
x = model(arr)
|
| 200 |
|
| 201 |
window = windowing_array.clone()
|
| 202 |
+
if i - step == 0:
|
| 203 |
window[:fade_size] = 1
|
| 204 |
+
elif i >= mix.shape[1]:
|
| 205 |
window[-fade_size:] = 1
|
| 206 |
|
| 207 |
for j, (start, seg_len) in enumerate(batch_locations):
|
|
|
|
| 210 |
)
|
| 211 |
counter[..., start : start + seg_len] += window[..., :seg_len]
|
| 212 |
|
|
|
|
| 213 |
processed = min(i, mix.shape[1])
|
| 214 |
total = mix.shape[1]
|
| 215 |
+
sys.stdout.write(
|
| 216 |
+
json.dumps(
|
| 217 |
+
{"processing": {"processed": processed, "total": total}}
|
| 218 |
+
)
|
| 219 |
+
+ "\n"
|
| 220 |
+
)
|
| 221 |
sys.stdout.flush()
|
| 222 |
|
| 223 |
batch_data.clear()
|
|
|
|
| 233 |
instruments = config.training.instruments
|
| 234 |
return {k: v for k, v in zip(instruments, estimated_sources)}
|
| 235 |
|
| 236 |
+
|
| 237 |
def demix_generic(
|
| 238 |
config: ConfigDict,
|
| 239 |
model: torch.nn.Module,
|
|
|
|
| 241 |
device: torch.device,
|
| 242 |
pbar: bool = False,
|
| 243 |
) -> Dict[str, np.ndarray]:
|
|
|
|
|
|
|
|
|
|
| 244 |
mix = torch.tensor(mix, dtype=torch.float32)
|
| 245 |
chunk_size = config.audio.chunk_size
|
| 246 |
instruments = prefer_target_instrument(config)
|
|
|
|
| 252 |
border = chunk_size - step
|
| 253 |
length_init = mix.shape[-1]
|
| 254 |
windowing_array = _getWindowingArray(chunk_size, fade_size)
|
| 255 |
+
|
|
|
|
| 256 |
if length_init > 2 * border and border > 0:
|
| 257 |
mix = nn.functional.pad(mix, (border, border), mode="reflect")
|
| 258 |
|
|
|
|
| 261 |
|
| 262 |
with torch.cuda.amp.autocast(enabled=use_amp):
|
| 263 |
with torch.inference_mode():
|
|
|
|
| 264 |
req_shape = (num_instruments,) + mix.shape
|
| 265 |
result = torch.zeros(req_shape, dtype=torch.float32)
|
| 266 |
counter = torch.zeros(req_shape, dtype=torch.float32)
|
|
|
|
| 270 |
batch_locations = []
|
| 271 |
|
| 272 |
while i < mix.shape[1]:
|
|
|
|
| 273 |
part = mix[:, i : i + chunk_size].to(device)
|
| 274 |
chunk_len = part.shape[-1]
|
| 275 |
+
|
| 276 |
pad_mode = "reflect" if chunk_len > chunk_size // 2 else "constant"
|
| 277 |
part = nn.functional.pad(
|
| 278 |
part, (0, chunk_size - chunk_len), mode=pad_mode, value=0
|
|
|
|
| 282 |
batch_locations.append((i, chunk_len))
|
| 283 |
i += step
|
| 284 |
|
|
|
|
| 285 |
if len(batch_data) >= batch_size or i >= mix.shape[1]:
|
| 286 |
arr = torch.stack(batch_data, dim=0)
|
| 287 |
x = model(arr)
|
| 288 |
|
| 289 |
window = windowing_array.clone()
|
| 290 |
+
if i - step == 0:
|
| 291 |
window[:fade_size] = 1
|
| 292 |
+
elif i >= mix.shape[1]:
|
| 293 |
window[-fade_size:] = 1
|
| 294 |
|
| 295 |
for j, (start, seg_len) in enumerate(batch_locations):
|
|
|
|
| 298 |
)
|
| 299 |
counter[..., start : start + seg_len] += window[..., :seg_len]
|
| 300 |
|
|
|
|
| 301 |
processed = min(i, mix.shape[1])
|
| 302 |
total = mix.shape[1]
|
| 303 |
+
sys.stdout.write(
|
| 304 |
+
json.dumps(
|
| 305 |
+
{"processing": {"processed": processed, "total": total}},
|
| 306 |
+
ensure_ascii=False,
|
| 307 |
+
)
|
| 308 |
+
+ "\n"
|
| 309 |
+
)
|
| 310 |
sys.stdout.flush()
|
| 311 |
|
| 312 |
batch_data.clear()
|
| 313 |
batch_locations.clear()
|
| 314 |
|
|
|
|
| 315 |
estimated_sources = result / counter
|
| 316 |
estimated_sources = estimated_sources.cpu().numpy()
|
| 317 |
np.nan_to_num(estimated_sources, copy=False, nan=0.0)
|
| 318 |
|
|
|
|
| 319 |
if length_init > 2 * border and border > 0:
|
| 320 |
estimated_sources = estimated_sources[..., border:-border]
|
| 321 |
|
|
|
|
| 322 |
return {k: v for k, v in zip(instruments, estimated_sources)}
|
| 323 |
|
| 324 |
+
|
| 325 |
def demix(
|
| 326 |
config: ConfigDict,
|
| 327 |
model: torch.nn.Module,
|
|
|
|
| 330 |
model_type: str,
|
| 331 |
pbar: bool = False,
|
| 332 |
) -> Dict[str, np.ndarray]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
if model_type == "vr":
|
| 334 |
return demix_vr(config, model, mix, device, pbar)
|
| 335 |
elif model_type == "mdxnet":
|
| 336 |
return demix_mdxnet(config, model, mix, device, pbar)
|
| 337 |
elif model_type == "htdemucs":
|
|
|
|
| 338 |
return demix_demucs(config, model, mix, device, pbar)
|
| 339 |
else:
|
|
|
|
| 340 |
return demix_generic(config, model, mix, device, pbar)
|
| 341 |
|
| 342 |
|
| 343 |
def prefer_target_instrument(config: ConfigDict) -> List[str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
if config.training.get("target_instrument"):
|
| 345 |
return [config.training.target_instrument]
|
| 346 |
else:
|
|
|
|
| 350 |
def prefer_target_instrument_test(
|
| 351 |
config: ConfigDict, selected_instruments: Optional[List[str]] = None
|
| 352 |
) -> List[str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
available_instruments = config.training.instruments
|
| 354 |
|
| 355 |
if selected_instruments is not None:
|
|
|
|
| 356 |
return [
|
| 357 |
instr for instr in selected_instruments if instr in available_instruments
|
| 358 |
]
|
| 359 |
elif config.training.get("target_instrument"):
|
|
|
|
| 360 |
return [config.training.target_instrument]
|
| 361 |
else:
|
| 362 |
+
return available_instruments
|
|
|
mvsepless/model_manager.py
CHANGED
|
@@ -1,609 +1,682 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
import json
|
| 4 |
-
import yaml
|
| 5 |
-
from tabulate import tabulate
|
| 6 |
-
import shutil
|
| 7 |
-
from tqdm import tqdm
|
| 8 |
-
import urllib.request
|
| 9 |
-
import gdown
|
| 10 |
-
import requests
|
| 11 |
-
import zipfile
|
| 12 |
-
import tempfile
|
| 13 |
-
import secrets
|
| 14 |
-
import string
|
| 15 |
-
import argparse
|
| 16 |
-
from typing import Dict, Any
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
self.
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
"
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
"
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
"
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
"
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
"
|
| 302 |
-
"
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
"
|
| 318 |
-
"
|
| 319 |
-
"
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
"
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
def
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
import yaml
|
| 5 |
+
from tabulate import tabulate
|
| 6 |
+
import shutil
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import urllib.request
|
| 9 |
+
import gdown
|
| 10 |
+
import requests
|
| 11 |
+
import zipfile
|
| 12 |
+
import tempfile
|
| 13 |
+
import secrets
|
| 14 |
+
import string
|
| 15 |
+
import argparse
|
| 16 |
+
from typing import Dict, Any
|
| 17 |
+
|
| 18 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 19 |
+
if not __package__:
|
| 20 |
+
from downloader import dw_file
|
| 21 |
+
else:
|
| 22 |
+
from .downloader import dw_file
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def generate_secure_random(length=10):
|
| 26 |
+
characters = string.ascii_letters + string.digits
|
| 27 |
+
return "".join(secrets.choice(characters) for _ in range(length))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class MvseplessModelManager:
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
models_info_path=os.path.join(script_dir, "models.json"),
|
| 34 |
+
cache_dir=os.path.join(script_dir, "mvsepless_models_cache"),
|
| 35 |
+
):
|
| 36 |
+
self.models_cache_dir = cache_dir
|
| 37 |
+
self.models_info_path = models_info_path
|
| 38 |
+
with open(self.models_info_path, "r", encoding="utf-8") as f:
|
| 39 |
+
models_info = json.load(f)
|
| 40 |
+
self.models_info = models_info
|
| 41 |
+
|
| 42 |
+
def get_mt(self):
|
| 43 |
+
return [mt for mt in self.models_info]
|
| 44 |
+
|
| 45 |
+
def get_mn(self, model_type):
|
| 46 |
+
return [mn for mn in self.models_info.get(model_type, [])]
|
| 47 |
+
|
| 48 |
+
def get_stems(self, model_type, model_name):
|
| 49 |
+
return [
|
| 50 |
+
stem
|
| 51 |
+
for stem in self.models_info.get(model_type)
|
| 52 |
+
.get(model_name)
|
| 53 |
+
.get("stems", [])
|
| 54 |
+
]
|
| 55 |
+
|
| 56 |
+
def get_id(self, model_type, model_name):
|
| 57 |
+
return self.models_info.get(model_type).get(model_name).get("id", 0)
|
| 58 |
+
|
| 59 |
+
def get_tgt_inst(self, model_type, model_name):
|
| 60 |
+
return (
|
| 61 |
+
self.models_info.get(model_type)
|
| 62 |
+
.get(model_name)
|
| 63 |
+
.get("target_instrument", None)
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
def get_category(self, model_type, model_name):
|
| 67 |
+
return self.models_info.get(model_type).get(model_name).get("category", "")
|
| 68 |
+
|
| 69 |
+
def display_models_info(self, filter: str = None):
|
| 70 |
+
table_data = []
|
| 71 |
+
headers = [
|
| 72 |
+
"Тип модели",
|
| 73 |
+
"ID",
|
| 74 |
+
"Имя модели",
|
| 75 |
+
"Стемы",
|
| 76 |
+
"Целевой инструмент",
|
| 77 |
+
]
|
| 78 |
+
|
| 79 |
+
for model_type, models in self.models_info.items():
|
| 80 |
+
for model_name, model_info in models.items():
|
| 81 |
+
try:
|
| 82 |
+
stems_list = model_info.get("stems", [])
|
| 83 |
+
id = model_info.get("id", "н/д")
|
| 84 |
+
if filter:
|
| 85 |
+
filter_lower = filter.lower()
|
| 86 |
+
if not any(filter_lower == s.lower() for s in stems_list):
|
| 87 |
+
continue
|
| 88 |
+
|
| 89 |
+
row = [
|
| 90 |
+
model_type,
|
| 91 |
+
id,
|
| 92 |
+
model_name,
|
| 93 |
+
", ".join(stems_list) or "н/д",
|
| 94 |
+
model_info.get("target_instrument", "н/д"),
|
| 95 |
+
]
|
| 96 |
+
table_data.append(row)
|
| 97 |
+
except (KeyError, TypeError, AttributeError) as e:
|
| 98 |
+
print(f"Ошибка при обработке модели {model_type}/{model_name}: {e}")
|
| 99 |
+
continue
|
| 100 |
+
|
| 101 |
+
if table_data:
|
| 102 |
+
print(tabulate(table_data, headers=headers, tablefmt="grid"))
|
| 103 |
+
else:
|
| 104 |
+
print("Нет моделей, которые содержат указанный стем")
|
| 105 |
+
|
| 106 |
+
def download_model(self, model_paths, model_name, model_type, ckpt_url, conf_url):
|
| 107 |
+
model_dir = os.path.join(model_paths, model_type)
|
| 108 |
+
os.makedirs(model_dir, exist_ok=True)
|
| 109 |
+
|
| 110 |
+
config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
|
| 111 |
+
checkpoint_path = os.path.join(
|
| 112 |
+
model_dir,
|
| 113 |
+
f"{model_name}.onnx" if model_type == "mdxnet" else f"{model_name}.ckpt",
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
if config_path is None or checkpoint_path is None:
|
| 117 |
+
raise RuntimeError()
|
| 118 |
+
|
| 119 |
+
if os.path.exists(checkpoint_path) and os.path.exists(config_path):
|
| 120 |
+
if (
|
| 121 |
+
os.path.getsize(checkpoint_path) == 0
|
| 122 |
+
or os.path.getsize(checkpoint_path) == 0
|
| 123 |
+
):
|
| 124 |
+
for local_path, url_model in [
|
| 125 |
+
(checkpoint_path, ckpt_url),
|
| 126 |
+
(config_path, conf_url),
|
| 127 |
+
]:
|
| 128 |
+
if not os.path.exists(local_path):
|
| 129 |
+
|
| 130 |
+
dw_file(url_model, local_path)
|
| 131 |
+
else:
|
| 132 |
+
pass
|
| 133 |
+
else:
|
| 134 |
+
for local_path, url_model in [
|
| 135 |
+
(checkpoint_path, ckpt_url),
|
| 136 |
+
(config_path, conf_url),
|
| 137 |
+
]:
|
| 138 |
+
if not os.path.exists(local_path):
|
| 139 |
+
|
| 140 |
+
dw_file(url_model, local_path)
|
| 141 |
+
|
| 142 |
+
return config_path, checkpoint_path
|
| 143 |
+
|
| 144 |
+
def conf_editor(self, config_path, mdx_denoise, vr_aggr, model_type):
|
| 145 |
+
|
| 146 |
+
class IndentDumper(yaml.Dumper):
|
| 147 |
+
def increase_indent(self, flow=False, indentless=False):
|
| 148 |
+
return super(IndentDumper, self).increase_indent(flow, False)
|
| 149 |
+
|
| 150 |
+
def tuple_constructor(loader, node):
|
| 151 |
+
values = loader.construct_sequence(node)
|
| 152 |
+
return tuple(values)
|
| 153 |
+
|
| 154 |
+
yaml.SafeLoader.add_constructor(
|
| 155 |
+
"tag:yaml.org,2002:python/tuple", tuple_constructor
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
def conf_edit(config_path, mdx_denoise, vr_aggr, model_type):
|
| 159 |
+
with open(config_path, "r") as f:
|
| 160 |
+
data = yaml.load(f, Loader=yaml.SafeLoader)
|
| 161 |
+
|
| 162 |
+
if "use_amp" not in data.keys():
|
| 163 |
+
data["training"]["use_amp"] = True
|
| 164 |
+
|
| 165 |
+
if model_type != "vr":
|
| 166 |
+
if data["inference"]["num_overlap"] != 2:
|
| 167 |
+
data["inference"]["num_overlap"] = 2
|
| 168 |
+
|
| 169 |
+
if data["inference"]["batch_size"] != 1:
|
| 170 |
+
data["inference"]["batch_size"] = 1
|
| 171 |
+
|
| 172 |
+
if model_type == "mdxnet":
|
| 173 |
+
data["inference"]["denoise"] = mdx_denoise
|
| 174 |
+
|
| 175 |
+
elif model_type == "vr":
|
| 176 |
+
data["inference"]["aggression"] = vr_aggr
|
| 177 |
+
|
| 178 |
+
with open(config_path, "w") as f:
|
| 179 |
+
yaml.dump(
|
| 180 |
+
data,
|
| 181 |
+
f,
|
| 182 |
+
default_flow_style=False,
|
| 183 |
+
sort_keys=False,
|
| 184 |
+
Dumper=IndentDumper,
|
| 185 |
+
allow_unicode=True,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
conf_edit(config_path, mdx_denoise, vr_aggr, model_type)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class VbachModelManager:
|
| 192 |
+
def __init__(self):
|
| 193 |
+
self.rmvpe_path = os.path.join(script_dir, "predictors", "rmvpe.pt")
|
| 194 |
+
self.fcpe_path = os.path.join(script_dir, "predictors", "fcpe.pt")
|
| 195 |
+
self.custom_fairseq_huberts_dir = os.path.join(
|
| 196 |
+
script_dir, "custom_fairseq_embedders"
|
| 197 |
+
)
|
| 198 |
+
self.custom_transformers_huberts_dir = os.path.join(
|
| 199 |
+
script_dir, "custom_transformers_embedders"
|
| 200 |
+
)
|
| 201 |
+
self.huberts_fairseq_dict = {
|
| 202 |
+
"hubert_base": {
|
| 203 |
+
"url": "https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/hubert_base.pt",
|
| 204 |
+
"local_path": os.path.join(
|
| 205 |
+
self.custom_fairseq_huberts_dir, "hubert_base.pt"
|
| 206 |
+
),
|
| 207 |
+
},
|
| 208 |
+
"contentvec_base": {
|
| 209 |
+
"url": "https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/contentvec_base.pt",
|
| 210 |
+
"local_path": os.path.join(
|
| 211 |
+
self.custom_fairseq_huberts_dir, "contentvec_base.pt"
|
| 212 |
+
),
|
| 213 |
+
},
|
| 214 |
+
"korean_hubert_base": {
|
| 215 |
+
"url": "https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/korean_hubert_base.pt",
|
| 216 |
+
"local_path": os.path.join(
|
| 217 |
+
self.custom_fairseq_huberts_dir, "korean_hubert_base.pt"
|
| 218 |
+
),
|
| 219 |
+
},
|
| 220 |
+
"chinese_hubert_base": {
|
| 221 |
+
"url": "https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/chinese_hubert_base.pt",
|
| 222 |
+
"local_path": os.path.join(
|
| 223 |
+
self.custom_fairseq_huberts_dir, "chinese_hubert_base.pt"
|
| 224 |
+
),
|
| 225 |
+
},
|
| 226 |
+
"portuguese_hubert_base": {
|
| 227 |
+
"url": "https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/portuguese_hubert_base.pt",
|
| 228 |
+
"local_path": os.path.join(
|
| 229 |
+
self.custom_fairseq_huberts_dir, "portuguese_hubert_base.pt"
|
| 230 |
+
),
|
| 231 |
+
},
|
| 232 |
+
"japanese_hubert_base": {
|
| 233 |
+
"url": "https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/japanese_hubert_base.pt",
|
| 234 |
+
"local_path": os.path.join(
|
| 235 |
+
self.custom_fairseq_huberts_dir, "japanese_hubert_base.pt"
|
| 236 |
+
),
|
| 237 |
+
},
|
| 238 |
+
}
|
| 239 |
+
self.huberts_transformers_dict = {
|
| 240 |
+
"contentvec": {
|
| 241 |
+
"base_dir": os.path.join(
|
| 242 |
+
self.custom_transformers_huberts_dir, "contentvec"
|
| 243 |
+
),
|
| 244 |
+
"url_bin": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/contentvec/pytorch_model.bin",
|
| 245 |
+
"url_json": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/contentvec/config.json",
|
| 246 |
+
"local_bin": os.path.join(
|
| 247 |
+
self.custom_transformers_huberts_dir,
|
| 248 |
+
"contentvec",
|
| 249 |
+
"pytorch_model.bin",
|
| 250 |
+
),
|
| 251 |
+
"local_json": os.path.join(
|
| 252 |
+
self.custom_transformers_huberts_dir, "contentvec", "config.json"
|
| 253 |
+
),
|
| 254 |
+
},
|
| 255 |
+
"spin": {
|
| 256 |
+
"base_dir": os.path.join(self.custom_transformers_huberts_dir, "spin"),
|
| 257 |
+
"url_bin": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/spin/pytorch_model.bin",
|
| 258 |
+
"url_json": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/spin/config.json",
|
| 259 |
+
"local_bin": os.path.join(
|
| 260 |
+
self.custom_transformers_huberts_dir, "spin", "pytorch_model.bin"
|
| 261 |
+
),
|
| 262 |
+
"local_json": os.path.join(
|
| 263 |
+
self.custom_transformers_huberts_dir, "spin", "config.json"
|
| 264 |
+
),
|
| 265 |
+
},
|
| 266 |
+
"spin-v2": {
|
| 267 |
+
"base_dir": os.path.join(
|
| 268 |
+
self.custom_transformers_huberts_dir, "spinv2"
|
| 269 |
+
),
|
| 270 |
+
"url_bin": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/spin-v2/pytorch_model.bin",
|
| 271 |
+
"url_json": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/spin-v2/config.json",
|
| 272 |
+
"local_bin": os.path.join(
|
| 273 |
+
self.custom_transformers_huberts_dir, "spinv2", "pytorch_model.bin"
|
| 274 |
+
),
|
| 275 |
+
"local_json": os.path.join(
|
| 276 |
+
self.custom_transformers_huberts_dir, "spinv2", "config.json"
|
| 277 |
+
),
|
| 278 |
+
},
|
| 279 |
+
"chinese-hubert-base": {
|
| 280 |
+
"base_dir": os.path.join(
|
| 281 |
+
self.custom_transformers_huberts_dir, "chinese_hubert_base"
|
| 282 |
+
),
|
| 283 |
+
"url_bin": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/chinese_hubert_base/pytorch_model.bin",
|
| 284 |
+
"url_json": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/chinese_hubert_base/config.json",
|
| 285 |
+
"local_bin": os.path.join(
|
| 286 |
+
self.custom_transformers_huberts_dir,
|
| 287 |
+
"chinese_hubert_base",
|
| 288 |
+
"pytorch_model.bin",
|
| 289 |
+
),
|
| 290 |
+
"local_json": os.path.join(
|
| 291 |
+
self.custom_transformers_huberts_dir,
|
| 292 |
+
"chinese_hubert_base",
|
| 293 |
+
"config.json",
|
| 294 |
+
),
|
| 295 |
+
},
|
| 296 |
+
"japanese-hubert-base": {
|
| 297 |
+
"base_dir": os.path.join(
|
| 298 |
+
self.custom_transformers_huberts_dir, "japanese_hubert_base"
|
| 299 |
+
),
|
| 300 |
+
"url_bin": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/japanese_hubert_base/pytorch_model.bin",
|
| 301 |
+
"url_json": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/japanese_hubert_base/config.json",
|
| 302 |
+
"local_bin": os.path.join(
|
| 303 |
+
self.custom_transformers_huberts_dir,
|
| 304 |
+
"japanese_hubert_base",
|
| 305 |
+
"pytorch_model.bin",
|
| 306 |
+
),
|
| 307 |
+
"local_json": os.path.join(
|
| 308 |
+
self.custom_transformers_huberts_dir,
|
| 309 |
+
"japanese_hubert_base",
|
| 310 |
+
"config.json",
|
| 311 |
+
),
|
| 312 |
+
},
|
| 313 |
+
"korean-hubert-base": {
|
| 314 |
+
"base_dir": os.path.join(
|
| 315 |
+
self.custom_transformers_huberts_dir, "korean_hubert_base"
|
| 316 |
+
),
|
| 317 |
+
"url_bin": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/korean_hubert_base/pytorch_model.bin",
|
| 318 |
+
"url_json": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/korean_hubert_base/config.json",
|
| 319 |
+
"local_bin": os.path.join(
|
| 320 |
+
self.custom_transformers_huberts_dir,
|
| 321 |
+
"korean_hubert_base",
|
| 322 |
+
"pytorch_model.bin",
|
| 323 |
+
),
|
| 324 |
+
"local_json": os.path.join(
|
| 325 |
+
self.custom_transformers_huberts_dir,
|
| 326 |
+
"korean_hubert_base",
|
| 327 |
+
"config.json",
|
| 328 |
+
),
|
| 329 |
+
},
|
| 330 |
+
}
|
| 331 |
+
self.requirements = [
|
| 332 |
+
[
|
| 333 |
+
"https://huggingface.co/Politrees/RVC_resources/resolve/main/predictors/rmvpe.pt",
|
| 334 |
+
self.rmvpe_path,
|
| 335 |
+
],
|
| 336 |
+
[
|
| 337 |
+
"https://huggingface.co/Politrees/RVC_resources/resolve/main/predictors/fcpe.pt",
|
| 338 |
+
self.fcpe_path,
|
| 339 |
+
],
|
| 340 |
+
]
|
| 341 |
+
self.voicemodels_dir = os.path.join(script_dir, "vbach_models_cache")
|
| 342 |
+
os.makedirs(self.voicemodels_dir, exist_ok=True)
|
| 343 |
+
self.voicemodels_info = os.path.join(self.voicemodels_dir, "vbach_models.json")
|
| 344 |
+
self.voicemodels: Dict[str, Dict[str, str]] = {}
|
| 345 |
+
self.download_requirements()
|
| 346 |
+
self.check_hubert("hubert_base")
|
| 347 |
+
self.check_and_load()
|
| 348 |
+
pass
|
| 349 |
+
|
| 350 |
+
def check_hubert(self, embedder_name):
|
| 351 |
+
if embedder_name in self.huberts_fairseq_dict:
|
| 352 |
+
if not os.path.exists(
|
| 353 |
+
self.huberts_fairseq_dict[embedder_name]["local_path"]
|
| 354 |
+
):
|
| 355 |
+
dw_file(
|
| 356 |
+
self.huberts_fairseq_dict[embedder_name]["url"],
|
| 357 |
+
self.huberts_fairseq_dict[embedder_name]["local_path"],
|
| 358 |
+
)
|
| 359 |
+
return self.huberts_fairseq_dict[embedder_name]["local_path"]
|
| 360 |
+
else:
|
| 361 |
+
return None
|
| 362 |
+
|
| 363 |
+
def check_hubert_transformers(self, embedder_name):
|
| 364 |
+
if embedder_name in self.huberts_transformers_dict:
|
| 365 |
+
os.makedirs(
|
| 366 |
+
self.huberts_transformers_dict[embedder_name]["base_dir"], exist_ok=True
|
| 367 |
+
)
|
| 368 |
+
if not os.path.exists(
|
| 369 |
+
self.huberts_transformers_dict[embedder_name]["local_bin"]
|
| 370 |
+
) and not os.path.exists(
|
| 371 |
+
self.huberts_transformers_dict[embedder_name]["local_json"]
|
| 372 |
+
):
|
| 373 |
+
dw_file(
|
| 374 |
+
self.huberts_transformers_dict[embedder_name]["url_bin"],
|
| 375 |
+
self.huberts_transformers_dict[embedder_name]["local_bin"],
|
| 376 |
+
)
|
| 377 |
+
dw_file(
|
| 378 |
+
self.huberts_transformers_dict[embedder_name]["url_json"],
|
| 379 |
+
self.huberts_transformers_dict[embedder_name]["local_json"],
|
| 380 |
+
)
|
| 381 |
+
return self.huberts_transformers_dict[embedder_name]["base_dir"]
|
| 382 |
+
else:
|
| 383 |
+
return None
|
| 384 |
+
|
| 385 |
+
def write_voicemodels_info(self):
|
| 386 |
+
with open(self.voicemodels_info, "w") as f:
|
| 387 |
+
json.dump(self.voicemodels, f, indent=4)
|
| 388 |
+
|
| 389 |
+
def load_voicemodels_info(self):
|
| 390 |
+
with open(self.voicemodels_info, "r") as f:
|
| 391 |
+
return json.load(f)
|
| 392 |
+
|
| 393 |
+
def add_voice_model(
|
| 394 |
+
self,
|
| 395 |
+
name,
|
| 396 |
+
pth_path,
|
| 397 |
+
index_path,
|
| 398 |
+
):
|
| 399 |
+
self.voicemodels[name] = {"pth": pth_path, "index": index_path}
|
| 400 |
+
self.write_voicemodels_info()
|
| 401 |
+
|
| 402 |
+
def del_voice_model(self, name):
|
| 403 |
+
if name in self.parse_voice_models():
|
| 404 |
+
pth = self.voicemodels[name].get("pth", None)
|
| 405 |
+
index = self.voicemodels[name].get("index", None)
|
| 406 |
+
if index:
|
| 407 |
+
os.remove(index)
|
| 408 |
+
if pth:
|
| 409 |
+
os.remove(pth)
|
| 410 |
+
del self.voicemodels[name]
|
| 411 |
+
self.write_voicemodels_info()
|
| 412 |
+
return f"Модель {name} удалена"
|
| 413 |
+
else:
|
| 414 |
+
return f"Модель не была удалена, как так её не существует"
|
| 415 |
+
|
| 416 |
+
def parse_voice_models(self):
|
| 417 |
+
list_models = list(self.voicemodels.keys())
|
| 418 |
+
return list_models
|
| 419 |
+
|
| 420 |
+
def parse_pth_and_index(self, name):
|
| 421 |
+
pth = self.voicemodels[name].get("pth", None)
|
| 422 |
+
index = self.voicemodels[name].get("index", None)
|
| 423 |
+
return pth, index
|
| 424 |
+
|
| 425 |
+
def check_and_load(self):
|
| 426 |
+
if os.path.exists(self.voicemodels_info):
|
| 427 |
+
self.voicemodels = self.load_voicemodels_info()
|
| 428 |
+
else:
|
| 429 |
+
self.write_voicemodels_info()
|
| 430 |
+
|
| 431 |
+
def clear_voicemodels_info(self):
|
| 432 |
+
self.voicemodels: Dict[str, Dict[str, str]] = {}
|
| 433 |
+
self.write_voicemodels_info()
|
| 434 |
+
|
| 435 |
+
def download_requirements(self):
|
| 436 |
+
for url, file in self.requirements:
|
| 437 |
+
if not os.path.exists(file):
|
| 438 |
+
dw_file(url, file)
|
| 439 |
+
|
| 440 |
+
def download_voice_model_file(self, url, zip_name):
|
| 441 |
+
try:
|
| 442 |
+
if "drive.google.com" in url:
|
| 443 |
+
self.download_from_google_drive(url, zip_name)
|
| 444 |
+
elif "pixeldrain.com" in url:
|
| 445 |
+
self.download_from_pixeldrain(url, zip_name)
|
| 446 |
+
elif "disk.yandex.ru" in url or "yadi.sk" in url:
|
| 447 |
+
self.download_from_yandex(url, zip_name)
|
| 448 |
+
else:
|
| 449 |
+
dw_file(url, zip_name)
|
| 450 |
+
except Exception as e:
|
| 451 |
+
print(e)
|
| 452 |
+
|
| 453 |
+
def download_from_google_drive(self, url, zip_name):
    """Download a Google Drive file given a share URL.

    Supports both ``.../file/d/<id>/...`` and ``...?id=<id>`` URL forms.

    Raises:
        ValueError: if no file id can be extracted from the URL
            (previously this raised a bare IndexError from ``split``).
    """
    if "file/d/" in url:
        file_id = url.split("file/d/")[1].split("/")[0]
    elif "id=" in url:
        file_id = url.split("id=")[1].split("&")[0]
    else:
        # Callers wrap this in a broad try/except, so a clear error is
        # backward-compatible with the old IndexError.
        raise ValueError(f"Не удалось извлечь id файла из ссылки: {url}")
    gdown.download(id=file_id, output=str(zip_name), quiet=False)
|
| 460 |
+
|
| 461 |
+
def download_from_pixeldrain(self, url, zip_name):
    """Download a pixeldrain file via its public API.

    Raises:
        requests.HTTPError: on a non-2xx API response. Previously the
            error body was silently written to ``zip_name`` as if it
            were the model archive; the caller's try/except handles the
            raised error instead.
    """
    file_id = url.split("pixeldrain.com/u/")[1]
    response = requests.get(f"https://pixeldrain.com/api/file/{file_id}")
    # Fail loudly instead of saving an error page as the model file.
    response.raise_for_status()
    with open(zip_name, "wb") as f:
        f.write(response.content)
|
| 466 |
+
|
| 467 |
+
def download_from_yandex(self, url, zip_name):
    """Download a public Yandex.Disk file via the cloud API.

    Resolves the public URL to a direct download link, then fetches it.
    On a non-200 API response the status code is printed (best-effort,
    matching the other downloader methods).
    """
    from urllib.parse import quote

    # The public key is itself a URL, so it must be percent-encoded
    # before being embedded as a query-string value; the unencoded form
    # produced a malformed API request.
    public_key = quote(url, safe="")
    yandex_api_url = (
        "https://cloud-api.yandex.net/v1/disk/public/resources/"
        f"download?public_key={public_key}"
    )
    response = requests.get(yandex_api_url)
    if response.status_code == 200:
        download_link = response.json().get("href")
        urllib.request.urlretrieve(download_link, zip_name)
    else:
        print(response.status_code)
|
| 478 |
+
|
| 479 |
+
def extract_zip(self, zip_name, model_name):
    # Extract a downloaded voice-model archive and register its contents.
    #
    # The archive is unpacked into a uniquely named directory under
    # self.voicemodels_dir, then scanned for one .index file and one or
    # more .pth checkpoints; each checkpoint is registered as a model.
    # Returns a (Russian) status string; errors are reported via the
    # return value rather than raised.
    model_dir = os.path.join(
        self.voicemodels_dir, f"{model_name}_{generate_secure_random(17)}"
    )
    os.makedirs(model_dir, exist_ok=True)
    try:
        with zipfile.ZipFile(zip_name, "r") as zip_ref:
            zip_ref.extractall(model_dir)
        # The archive is no longer needed once extracted.
        os.remove(zip_name)

        added_voice_models = []

        index_filepath, model_filepaths = None, []
        for root, _, files in os.walk(model_dir):
            for name in files:
                file_path = os.path.join(root, name)
                # Index files below ~100 KB are treated as placeholders
                # and ignored; only the last qualifying one is kept.
                if (
                    name.endswith(".index")
                    and os.stat(file_path).st_size > 1024 * 100
                ):
                    index_filepath = file_path
                # Checkpoints must be at least ~20 MB to be considered
                # real model weights.
                if (
                    name.endswith(".pth")
                    and os.stat(file_path).st_size > 1024 * 1024 * 20
                ):
                    model_filepaths.append(file_path)

        if len(model_filepaths) == 1:
            self.add_voice_model(model_name, model_filepaths[0], index_filepath)
            added_voice_models.append(model_name)
        else:
            # Several checkpoints in one archive: register each under a
            # numbered name; they all share the same index file.
            for i, pth in enumerate(model_filepaths):
                self.add_voice_model(f"{model_name}_{i + 1}", pth, index_filepath)
                added_voice_models.append(f"{model_name}_{i + 1}")
        list_models_str = "\n".join(added_voice_models)
        return f"Добавленные модели:\n{list_models_str}"
    except Exception as e:
        return f"Произошла ошибка при загрузке модели: {e}"
|
| 517 |
+
|
| 518 |
+
def install_model_zip(self, zip, model_name, mode="url"):
|
| 519 |
+
if model_name in self.parse_voice_models():
|
| 520 |
+
print(
|
| 521 |
+
"Эта модель уже есть в списке установленных моделей. Она будут перезаписана"
|
| 522 |
+
)
|
| 523 |
+
if mode == "url":
|
| 524 |
+
with tempfile.TemporaryDirectory(
|
| 525 |
+
prefix="vbach_temp_model", ignore_cleanup_errors=True
|
| 526 |
+
) as tmp:
|
| 527 |
+
zip_path = os.path.join(tmp, "model.zip")
|
| 528 |
+
self.download_voice_model_file(zip, zip_path)
|
| 529 |
+
status = self.extract_zip(zip_path, model_name)
|
| 530 |
+
if mode == "local":
|
| 531 |
+
status = self.extract_zip(zip, model_name)
|
| 532 |
+
return status
|
| 533 |
+
|
| 534 |
+
def install_model_files(self, index, pth, model_name, mode="url"):
    """Install a voice model from separate ``.pth`` / ``.index`` files.

    Args:
        index: URL or local path of the .index file (may be falsy).
        pth: URL or local path of the .pth checkpoint (may be falsy).
        model_name: name to register the model under.
        mode: "url" to download the files, "local" to copy them.

    Returns:
        A (Russian) status message; errors are returned, not raised.
    """
    if model_name in self.parse_voice_models():
        # Fixed grammar ("будут" -> "будет") in the informational message.
        print(
            "Эта модель уже есть в списке установленных моделей. Она будет перезаписана"
        )
    model_dir = os.path.join(
        self.voicemodels_dir, f"{model_name}_{generate_secure_random(17)}"
    )
    os.makedirs(model_dir, exist_ok=True)
    local_index_path = None
    local_pth_path = None
    try:
        if mode == "url":
            if index:
                local_index_path = os.path.join(model_dir, "model.index")
                self.download_voice_model_file(index, local_index_path)
            if pth:
                local_pth_path = os.path.join(model_dir, "model.pth")
                self.download_voice_model_file(pth, local_pth_path)

        if mode == "local":
            # Paths that do not exist are silently skipped; the model is
            # then registered without that file.
            if index and os.path.exists(index):
                local_index_path = os.path.join(model_dir, os.path.basename(index))
                shutil.copy(index, local_index_path)
            if pth and os.path.exists(pth):
                local_pth_path = os.path.join(model_dir, os.path.basename(pth))
                shutil.copy(pth, local_pth_path)

        self.add_voice_model(model_name, local_pth_path, local_index_path)
        return f"Модель {model_name} добавлена"
    except Exception as e:
        return f"Произошла ошибка при загрузке модели: {e}"
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
if __name__ == "__main__":
    # CLI entry point: "mvsepless" downloads separator models,
    # "vbach" manages RVC voice models.
    parser = argparse.ArgumentParser(description="Менеджер моделей")
    subparsers = parser.add_subparsers(
        title="subcommands", dest="command", required=True
    )

    mvsepless_parser = subparsers.add_parser(
        "mvsepless", help="Скачивание моделей в MVSepLess"
    )
    mvsepless_parser.add_argument("--model_type", required=True, help="Тип модели")
    mvsepless_parser.add_argument("--model_name", required=True, help="Имя модели")

    vbach_parser = subparsers.add_parser(
        "vbach", help="Установка голосовых моделей в Vbach"
    )
    vbach_subparsers = vbach_parser.add_subparsers(
        title="vbach_commands", dest="vbach_command", required=True
    )

    install_local_parser = vbach_subparsers.add_parser(
        "install_local", help="Установка голосовой модели по локальным файлам"
    )
    install_local_parser.add_argument(
        "--model_name", required=True, help="Имя голосовой модели"
    )
    install_local_parser.add_argument("--pth", required=True, help="Путь к *.pth файлу")
    install_local_parser.add_argument(
        "--index", required=False, help="Путь к *.index файлу"
    )

    install_url_zip_parser = vbach_subparsers.add_parser(
        "install_url_zip", help="Установка голосовой модели по URL (архив с файлами)"
    )
    install_url_zip_parser.add_argument(
        "--model_name", required=True, help="Имя голосовой модели"
    )
    install_url_zip_parser.add_argument("--url", required=True, help="URL *.zip файла")

    install_url_files_parser = vbach_subparsers.add_parser(
        "install_url_files", help="Установка голосовой модели по URL (отдельные файлы)"
    )
    install_url_files_parser.add_argument(
        "--model_name", required=True, help="Имя голосовой модели"
    )
    install_url_files_parser.add_argument(
        "--pth_url", required=True, help="URL *.pth файла"
    )
    install_url_files_parser.add_argument(
        "--index_url", required=False, help="URL *.index файла"
    )

    list_parser = vbach_subparsers.add_parser(
        "list", help="Список установленных моделей"
    )

    remove_voice_model = vbach_subparsers.add_parser("remove", help="Удаление модели")
    remove_voice_model.add_argument(
        "--model_name", required=True, help="Имя голосовой модели"
    )

    args = parser.parse_args()

    if args.command == "mvsepless":
        _model_manager = MvseplessModelManager()
        # .get(..., {}) prevents an AttributeError on an unknown model
        # type; unknown type and unknown name now both hit the same
        # clear ValueError below.
        info = _model_manager.models_info.get(args.model_type, {}).get(
            args.model_name, None
        )
        if not info:
            raise ValueError(
                f"Модель {args.model_name} не найдена для типа {args.model_type}"
            )
        conf, ckpt = _model_manager.download_model(
            _model_manager.models_cache_dir,
            args.model_name,
            args.model_type,
            info["checkpoint_url"],
            info["config_url"],
        )

    elif args.command == "vbach":
        model_manager = VbachModelManager()

        if args.vbach_command == "install_local":
            status = model_manager.install_model_files(
                args.index, args.pth, args.model_name, mode="local"
            )
            print(status)

        elif args.vbach_command == "install_url_zip":
            status = model_manager.install_model_zip(
                args.url, args.model_name, mode="url"
            )
            print(status)

        elif args.vbach_command == "install_url_files":
            status = model_manager.install_model_files(
                args.index_url, args.pth_url, args.model_name, mode="url"
            )
            print(status)

        elif args.vbach_command == "list":
            models = model_manager.parse_voice_models()
            if models:
                print("Установленные модели:")
                for model in models:
                    print(f" - {model}")
            else:
                print("Нет установленных моделей")

        elif args.vbach_command == "remove":
            status = model_manager.del_voice_model(args.model_name)
            print(status)
|
mvsepless/models.json
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
mvsepless/models/bandit/core/__init__.py
CHANGED
|
@@ -1,691 +1,669 @@
|
|
| 1 |
-
import os.path
|
| 2 |
-
from collections import defaultdict
|
| 3 |
-
from itertools import chain, combinations
|
| 4 |
-
from typing import Any, Dict, Iterator, Mapping, Optional, Tuple, Type, TypedDict
|
| 5 |
-
|
| 6 |
-
import pytorch_lightning as pl
|
| 7 |
-
import torch
|
| 8 |
-
import torchaudio as ta
|
| 9 |
-
import torchmetrics as tm
|
| 10 |
-
from asteroid import losses as asteroid_losses
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
from
|
| 15 |
-
from torch import
|
| 16 |
-
|
| 17 |
-
from
|
| 18 |
-
|
| 19 |
-
from . import
|
| 20 |
-
from .
|
| 21 |
-
from .
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
)
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
self
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
self.
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
self.
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
self.
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
def
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
"audio"
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
[
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
self.
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
if
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
[
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
self.attach_fader()
|
| 672 |
-
|
| 673 |
-
def attach_fader(self, force_reattach=False) -> None:
|
| 674 |
-
if self.fader is None or force_reattach:
|
| 675 |
-
self.fader = parse_fader_config(self.fader_config)
|
| 676 |
-
self.fader.to(self.device)
|
| 677 |
-
|
| 678 |
-
def log_dict_with_prefix(
|
| 679 |
-
self,
|
| 680 |
-
dict_: Dict[str, torch.Tensor],
|
| 681 |
-
prefix: str,
|
| 682 |
-
batch_size: Optional[int] = None,
|
| 683 |
-
**kwargs: Any,
|
| 684 |
-
) -> None:
|
| 685 |
-
self.log_dict(
|
| 686 |
-
{f"{prefix}/{k}": v for k, v in dict_.items()},
|
| 687 |
-
batch_size=batch_size,
|
| 688 |
-
logger=True,
|
| 689 |
-
sync_dist=True,
|
| 690 |
-
**kwargs,
|
| 691 |
-
)
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from itertools import chain, combinations
|
| 4 |
+
from typing import Any, Dict, Iterator, Mapping, Optional, Tuple, Type, TypedDict
|
| 5 |
+
|
| 6 |
+
import pytorch_lightning as pl
|
| 7 |
+
import torch
|
| 8 |
+
import torchaudio as ta
|
| 9 |
+
import torchmetrics as tm
|
| 10 |
+
from asteroid import losses as asteroid_losses
|
| 11 |
+
|
| 12 |
+
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
| 13 |
+
from torch import nn, optim
|
| 14 |
+
from torch.optim import lr_scheduler
|
| 15 |
+
from torch.optim.lr_scheduler import LRScheduler
|
| 16 |
+
|
| 17 |
+
from . import loss, metrics as metrics_, model
|
| 18 |
+
from .data._types import BatchedDataDict
|
| 19 |
+
from .data.augmentation import BaseAugmentor, StemAugmentor
|
| 20 |
+
from .utils import audio as audio_
|
| 21 |
+
from .utils.audio import BaseFader
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Generic "name + constructor kwargs" config entry used throughout this
# module's parse_* factories.
ConfigDict = TypedDict("ConfigDict", {"name": str, "kwargs": Dict[str, Any]})


class SchedulerConfigDict(ConfigDict):
    # Metric name watched by schedulers such as ReduceLROnPlateau.
    monitor: str


# Optimizer config with an optional scheduler section (total=False).
OptimizerSchedulerConfigDict = TypedDict(
    "OptimizerSchedulerConfigDict",
    {"optimizer": ConfigDict, "scheduler": SchedulerConfigDict},
    total=False,
)


class LRSchedulerReturnDict(TypedDict, total=False):
    # Shape of the "lr_scheduler" entry returned by configure_optimizers.
    scheduler: LRScheduler
    monitor: str


class ConfigureOptimizerReturnDict(TypedDict, total=False):
    # Return shape of LightningModule.configure_optimizers.
    optimizer: torch.optim.Optimizer
    lr_scheduler: LRSchedulerReturnDict


# Loose aliases for model step outputs and metric dictionaries.
OutputType = Dict[str, Any]
MetricsType = Dict[str, torch.Tensor]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_optimizer_class(name: str) -> Type[optim.Optimizer]:
    # Resolve an optimizer class by name.
    #
    # NOTE(review): neither ``DeepSpeedCPUAdam`` nor ``gooptim`` appears
    # in this module's visible imports — confirm they are imported
    # elsewhere in the file, otherwise both lookups below raise
    # NameError at call time.

    if name == "DeepSpeedCPUAdam":
        return DeepSpeedCPUAdam

    # Search torch.optim first, then the extra optimizer module.
    for module in [optim, gooptim]:
        if name in module.__dict__:
            return module.__dict__[name]

    # Unknown optimizer name.
    raise NameError
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def parse_optimizer_config(
    config: OptimizerSchedulerConfigDict, parameters: Iterator[nn.Parameter]
) -> ConfigureOptimizerReturnDict:
    """Build the dict returned by ``configure_optimizers``.

    Always contains an ``optimizer`` entry; adds ``lr_scheduler`` when
    the config declares a scheduler, with a ``monitor`` key for
    ReduceLROnPlateau.
    """
    optimizer_cfg = config["optimizer"]
    optimizer = get_optimizer_class(optimizer_cfg["name"])(
        parameters, **optimizer_cfg["kwargs"]
    )

    result: ConfigureOptimizerReturnDict = {"optimizer": optimizer}

    if "scheduler" in config:
        scheduler_cfg = config["scheduler"]
        scheduler_name = scheduler_cfg["name"]
        scheduler = lr_scheduler.__dict__[scheduler_name](
            optimizer, **scheduler_cfg["kwargs"]
        )
        sched_entry: LRSchedulerReturnDict = {"scheduler": scheduler}

        if scheduler_name == "ReduceLROnPlateau":
            # Plateau scheduling needs a metric to watch.
            sched_entry["monitor"] = scheduler_cfg["monitor"]

        result["lr_scheduler"] = sched_entry

    return result
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def parse_model_config(config: ConfigDict) -> Any:
    """Instantiate the model class named in *config* with its kwargs."""
    target = config["name"]

    for namespace in [model]:
        members = namespace.__dict__
        if target in members:
            return members[target](**config["kwargs"])

    # No module exposes a class with that name.
    raise NameError
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# Loss names kept only for backward compatibility with old configs.
_LEGACY_LOSS_NAMES = ["HybridL1Loss"]


def _parse_legacy_loss_config(config: ConfigDict) -> nn.Module:
    # Translate a legacy loss name to its current implementation.
    name = config["name"]

    # "HybridL1Loss" is the old name of TimeFreqL1Loss.
    if name == "HybridL1Loss":
        return loss.TimeFreqL1Loss(**config["kwargs"])

    raise NameError
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def parse_loss_config(config: ConfigDict) -> nn.Module:
    """Build a loss module from *config*.

    Legacy names are translated first; otherwise the local ``loss``
    module, ``torch.nn`` losses and asteroid losses are searched in
    that order.
    """
    name = config["name"]

    if name in _LEGACY_LOSS_NAMES:
        return _parse_legacy_loss_config(config)

    for namespace in (loss, nn.modules.loss, asteroid_losses):
        members = namespace.__dict__
        if name in members:
            return members[name](**config["kwargs"])

    raise NameError
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def get_metric(config: ConfigDict) -> tm.Metric:
    """Instantiate a metric by name from torchmetrics or local metrics."""
    name = config["name"]

    for namespace in (tm, metrics_):
        members = namespace.__dict__
        if name in members:
            return members[name](**config["kwargs"])

    raise NameError
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def parse_metric_config(config: Dict[str, ConfigDict]) -> tm.MetricCollection:
    """Build a MetricCollection from a name -> metric-config mapping."""
    return tm.MetricCollection(
        {metric_name: get_metric(config[metric_name]) for metric_name in config}
    )
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def parse_fader_config(config: ConfigDict) -> BaseFader:
    """Instantiate an inference fader by name from the utils.audio module."""
    name = config["name"]

    for namespace in [audio_]:
        members = namespace.__dict__
        if name in members:
            return members[name](**config["kwargs"])

    raise NameError
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class LightningSystem(pl.LightningModule):
|
| 154 |
+
_VOX_STEMS = ["speech", "vocals"]
|
| 155 |
+
_BG_STEMS = ["background", "effects", "mne"]
|
| 156 |
+
|
| 157 |
+
def __init__(
|
| 158 |
+
self, config: Dict, loss_adjustment: float = 1.0, attach_fader: bool = False
|
| 159 |
+
) -> None:
|
| 160 |
+
super().__init__()
|
| 161 |
+
self.optimizer_config = config["optimizer"]
|
| 162 |
+
self.model = parse_model_config(config["model"])
|
| 163 |
+
self.loss = parse_loss_config(config["loss"])
|
| 164 |
+
self.metrics = nn.ModuleDict(
|
| 165 |
+
{
|
| 166 |
+
stem: parse_metric_config(config["metrics"]["dev"])
|
| 167 |
+
for stem in self.model.stems
|
| 168 |
+
}
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
self.metrics.disallow_fsdp = True
|
| 172 |
+
|
| 173 |
+
self.test_metrics = nn.ModuleDict(
|
| 174 |
+
{
|
| 175 |
+
stem: parse_metric_config(config["metrics"]["test"])
|
| 176 |
+
for stem in self.model.stems
|
| 177 |
+
}
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
self.test_metrics.disallow_fsdp = True
|
| 181 |
+
|
| 182 |
+
self.fs = config["model"]["kwargs"]["fs"]
|
| 183 |
+
|
| 184 |
+
self.fader_config = config["inference"]["fader"]
|
| 185 |
+
if attach_fader:
|
| 186 |
+
self.fader = parse_fader_config(config["inference"]["fader"])
|
| 187 |
+
else:
|
| 188 |
+
self.fader = None
|
| 189 |
+
|
| 190 |
+
self.augmentation: Optional[BaseAugmentor]
|
| 191 |
+
if config.get("augmentation", None) is not None:
|
| 192 |
+
self.augmentation = StemAugmentor(**config["augmentation"])
|
| 193 |
+
else:
|
| 194 |
+
self.augmentation = None
|
| 195 |
+
|
| 196 |
+
self.predict_output_path: Optional[str] = None
|
| 197 |
+
self.loss_adjustment = loss_adjustment
|
| 198 |
+
|
| 199 |
+
self.val_prefix = None
|
| 200 |
+
self.test_prefix = None
|
| 201 |
+
|
| 202 |
+
def configure_optimizers(self) -> Any:
|
| 203 |
+
return parse_optimizer_config(
|
| 204 |
+
self.optimizer_config, self.trainer.model.parameters()
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
def compute_loss(
|
| 208 |
+
self, batch: BatchedDataDict, output: OutputType
|
| 209 |
+
) -> Dict[str, torch.Tensor]:
|
| 210 |
+
return {"loss": self.loss(output, batch)}
|
| 211 |
+
|
| 212 |
+
def update_metrics(
|
| 213 |
+
self, batch: BatchedDataDict, output: OutputType, mode: str
|
| 214 |
+
) -> None:
|
| 215 |
+
|
| 216 |
+
if mode == "test":
|
| 217 |
+
metrics = self.test_metrics
|
| 218 |
+
else:
|
| 219 |
+
metrics = self.metrics
|
| 220 |
+
|
| 221 |
+
for stem, metric in metrics.items():
|
| 222 |
+
|
| 223 |
+
if stem == "mne:+":
|
| 224 |
+
stem = "mne"
|
| 225 |
+
|
| 226 |
+
if mode == "train":
|
| 227 |
+
metric.update(
|
| 228 |
+
output["audio"][stem],
|
| 229 |
+
batch["audio"][stem],
|
| 230 |
+
)
|
| 231 |
+
else:
|
| 232 |
+
if stem not in batch["audio"]:
|
| 233 |
+
matched = False
|
| 234 |
+
if stem in self._VOX_STEMS:
|
| 235 |
+
for bstem in self._VOX_STEMS:
|
| 236 |
+
if bstem in batch["audio"]:
|
| 237 |
+
batch["audio"][stem] = batch["audio"][bstem]
|
| 238 |
+
matched = True
|
| 239 |
+
break
|
| 240 |
+
elif stem in self._BG_STEMS:
|
| 241 |
+
for bstem in self._BG_STEMS:
|
| 242 |
+
if bstem in batch["audio"]:
|
| 243 |
+
batch["audio"][stem] = batch["audio"][bstem]
|
| 244 |
+
matched = True
|
| 245 |
+
break
|
| 246 |
+
else:
|
| 247 |
+
matched = True
|
| 248 |
+
|
| 249 |
+
if matched:
|
| 250 |
+
if stem == "mne" and "mne" not in output["audio"]:
|
| 251 |
+
output["audio"]["mne"] = (
|
| 252 |
+
output["audio"]["music"] + output["audio"]["effects"]
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
metric.update(
|
| 256 |
+
output["audio"][stem],
|
| 257 |
+
batch["audio"][stem],
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
def compute_metrics(self, mode: str = "dev") -> Dict[str, torch.Tensor]:
|
| 261 |
+
|
| 262 |
+
if mode == "test":
|
| 263 |
+
metrics = self.test_metrics
|
| 264 |
+
else:
|
| 265 |
+
metrics = self.metrics
|
| 266 |
+
|
| 267 |
+
metric_dict = {}
|
| 268 |
+
|
| 269 |
+
for stem, metric in metrics.items():
|
| 270 |
+
md = metric.compute()
|
| 271 |
+
metric_dict.update({f"{stem}/{k}": v for k, v in md.items()})
|
| 272 |
+
|
| 273 |
+
self.log_dict(metric_dict, prog_bar=True, logger=False)
|
| 274 |
+
|
| 275 |
+
return metric_dict
|
| 276 |
+
|
| 277 |
+
def reset_metrics(self, test_mode: bool = False) -> None:
|
| 278 |
+
|
| 279 |
+
if test_mode:
|
| 280 |
+
metrics = self.test_metrics
|
| 281 |
+
else:
|
| 282 |
+
metrics = self.metrics
|
| 283 |
+
|
| 284 |
+
for _, metric in metrics.items():
|
| 285 |
+
metric.reset()
|
| 286 |
+
|
| 287 |
+
def forward(self, batch: BatchedDataDict) -> Any:
|
| 288 |
+
batch, output = self.model(batch)
|
| 289 |
+
|
| 290 |
+
return batch, output
|
| 291 |
+
|
| 292 |
+
def common_step(self, batch: BatchedDataDict, mode: str) -> Any:
|
| 293 |
+
batch, output = self.forward(batch)
|
| 294 |
+
loss_dict = self.compute_loss(batch, output)
|
| 295 |
+
|
| 296 |
+
with torch.no_grad():
|
| 297 |
+
self.update_metrics(batch, output, mode=mode)
|
| 298 |
+
|
| 299 |
+
if mode == "train":
|
| 300 |
+
self.log("loss", loss_dict["loss"], prog_bar=True)
|
| 301 |
+
|
| 302 |
+
return output, loss_dict
|
| 303 |
+
|
| 304 |
+
def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]:
|
| 305 |
+
|
| 306 |
+
if self.augmentation is not None:
|
| 307 |
+
with torch.no_grad():
|
| 308 |
+
batch = self.augmentation(batch)
|
| 309 |
+
|
| 310 |
+
_, loss_dict = self.common_step(batch, mode="train")
|
| 311 |
+
|
| 312 |
+
with torch.inference_mode():
|
| 313 |
+
self.log_dict_with_prefix(
|
| 314 |
+
loss_dict, "train", batch_size=batch["audio"]["mixture"].shape[0]
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
loss_dict["loss"] *= self.loss_adjustment
|
| 318 |
+
|
| 319 |
+
return loss_dict
|
| 320 |
+
|
| 321 |
+
def on_train_batch_end(
|
| 322 |
+
self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int
|
| 323 |
+
) -> None:
|
| 324 |
+
|
| 325 |
+
metric_dict = self.compute_metrics()
|
| 326 |
+
self.log_dict_with_prefix(metric_dict, "train")
|
| 327 |
+
self.reset_metrics()
|
| 328 |
+
|
| 329 |
+
def validation_step(
|
| 330 |
+
self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
|
| 331 |
+
) -> Dict[str, Any]:
|
| 332 |
+
|
| 333 |
+
with torch.inference_mode():
|
| 334 |
+
curr_val_prefix = f"val{dataloader_idx}" if dataloader_idx > 0 else "val"
|
| 335 |
+
|
| 336 |
+
if curr_val_prefix != self.val_prefix:
|
| 337 |
+
if self.val_prefix is not None:
|
| 338 |
+
self._on_validation_epoch_end()
|
| 339 |
+
self.val_prefix = curr_val_prefix
|
| 340 |
+
_, loss_dict = self.common_step(batch, mode="val")
|
| 341 |
+
|
| 342 |
+
self.log_dict_with_prefix(
|
| 343 |
+
loss_dict,
|
| 344 |
+
self.val_prefix,
|
| 345 |
+
batch_size=batch["audio"]["mixture"].shape[0],
|
| 346 |
+
prog_bar=True,
|
| 347 |
+
add_dataloader_idx=False,
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
return loss_dict
|
| 351 |
+
|
| 352 |
+
def on_validation_epoch_end(self) -> None:
|
| 353 |
+
self._on_validation_epoch_end()
|
| 354 |
+
|
| 355 |
+
def _on_validation_epoch_end(self) -> None:
|
| 356 |
+
metric_dict = self.compute_metrics()
|
| 357 |
+
self.log_dict_with_prefix(
|
| 358 |
+
metric_dict, self.val_prefix, prog_bar=True, add_dataloader_idx=False
|
| 359 |
+
)
|
| 360 |
+
self.reset_metrics()
|
| 361 |
+
|
| 362 |
+
def old_predtest_step(
|
| 363 |
+
self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
|
| 364 |
+
) -> Tuple[BatchedDataDict, OutputType]:
|
| 365 |
+
|
| 366 |
+
audio_batch = batch["audio"]["mixture"]
|
| 367 |
+
track_batch = batch.get("track", ["" for _ in range(len(audio_batch))])
|
| 368 |
+
|
| 369 |
+
output_list_of_dicts = [
|
| 370 |
+
self.fader(audio[None, ...], lambda a: self.test_forward(a, track))
|
| 371 |
+
for audio, track in zip(audio_batch, track_batch)
|
| 372 |
+
]
|
| 373 |
+
|
| 374 |
+
output_dict_of_lists = defaultdict(list)
|
| 375 |
+
|
| 376 |
+
for output_dict in output_list_of_dicts:
|
| 377 |
+
for stem, audio in output_dict.items():
|
| 378 |
+
output_dict_of_lists[stem].append(audio)
|
| 379 |
+
|
| 380 |
+
output = {
|
| 381 |
+
"audio": {
|
| 382 |
+
stem: torch.concat(output_list, dim=0)
|
| 383 |
+
for stem, output_list in output_dict_of_lists.items()
|
| 384 |
+
}
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
return batch, output
|
| 388 |
+
|
| 389 |
+
def predtest_step(
|
| 390 |
+
self, batch: BatchedDataDict, batch_idx: int = -1, dataloader_idx: int = 0
|
| 391 |
+
) -> Tuple[BatchedDataDict, OutputType]:
|
| 392 |
+
|
| 393 |
+
if getattr(self.model, "bypass_fader", False):
|
| 394 |
+
batch, output = self.model(batch)
|
| 395 |
+
else:
|
| 396 |
+
audio_batch = batch["audio"]["mixture"]
|
| 397 |
+
output = self.fader(
|
| 398 |
+
audio_batch, lambda a: self.test_forward(a, "", batch=batch)
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
return batch, output
|
| 402 |
+
|
| 403 |
+
def test_forward(
|
| 404 |
+
self, audio: torch.Tensor, track: str = "", batch: BatchedDataDict = None
|
| 405 |
+
) -> torch.Tensor:
|
| 406 |
+
|
| 407 |
+
if self.fader is None:
|
| 408 |
+
self.attach_fader()
|
| 409 |
+
|
| 410 |
+
cond = batch.get("condition", None)
|
| 411 |
+
|
| 412 |
+
if cond is not None and cond.shape[0] == 1:
|
| 413 |
+
cond = cond.repeat(audio.shape[0], 1)
|
| 414 |
+
|
| 415 |
+
_, output = self.forward(
|
| 416 |
+
{
|
| 417 |
+
"audio": {"mixture": audio},
|
| 418 |
+
"track": track,
|
| 419 |
+
"condition": cond,
|
| 420 |
+
}
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
return output["audio"]
|
| 424 |
+
|
| 425 |
+
def on_test_epoch_start(self) -> None:
|
| 426 |
+
self.attach_fader(force_reattach=True)
|
| 427 |
+
|
| 428 |
+
def test_step(
|
| 429 |
+
self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
|
| 430 |
+
) -> Any:
|
| 431 |
+
curr_test_prefix = f"test{dataloader_idx}"
|
| 432 |
+
|
| 433 |
+
if curr_test_prefix != self.test_prefix:
|
| 434 |
+
if self.test_prefix is not None:
|
| 435 |
+
self._on_test_epoch_end()
|
| 436 |
+
self.test_prefix = curr_test_prefix
|
| 437 |
+
|
| 438 |
+
with torch.inference_mode():
|
| 439 |
+
_, output = self.predtest_step(batch, batch_idx, dataloader_idx)
|
| 440 |
+
self.update_metrics(batch, output, mode="test")
|
| 441 |
+
|
| 442 |
+
return output
|
| 443 |
+
|
| 444 |
+
def on_test_epoch_end(self) -> None:
|
| 445 |
+
self._on_test_epoch_end()
|
| 446 |
+
|
| 447 |
+
def _on_test_epoch_end(self) -> None:
|
| 448 |
+
metric_dict = self.compute_metrics(mode="test")
|
| 449 |
+
self.log_dict_with_prefix(
|
| 450 |
+
metric_dict, self.test_prefix, prog_bar=True, add_dataloader_idx=False
|
| 451 |
+
)
|
| 452 |
+
self.reset_metrics()
|
| 453 |
+
|
| 454 |
+
def predict_step(
|
| 455 |
+
self,
|
| 456 |
+
batch: BatchedDataDict,
|
| 457 |
+
batch_idx: int = 0,
|
| 458 |
+
dataloader_idx: int = 0,
|
| 459 |
+
include_track_name: Optional[bool] = None,
|
| 460 |
+
get_no_vox_combinations: bool = True,
|
| 461 |
+
get_residual: bool = False,
|
| 462 |
+
treat_batch_as_channels: bool = False,
|
| 463 |
+
fs: Optional[int] = None,
|
| 464 |
+
) -> Any:
|
| 465 |
+
assert self.predict_output_path is not None
|
| 466 |
+
|
| 467 |
+
batch_size = batch["audio"]["mixture"].shape[0]
|
| 468 |
+
|
| 469 |
+
if include_track_name is None:
|
| 470 |
+
include_track_name = batch_size > 1
|
| 471 |
+
|
| 472 |
+
with torch.inference_mode():
|
| 473 |
+
batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
|
| 474 |
+
print("Pred test finished...")
|
| 475 |
+
torch.cuda.empty_cache()
|
| 476 |
+
metric_dict = {}
|
| 477 |
+
|
| 478 |
+
if get_residual:
|
| 479 |
+
mixture = batch["audio"]["mixture"]
|
| 480 |
+
extracted = sum([output["audio"][stem] for stem in output["audio"]])
|
| 481 |
+
residual = mixture - extracted
|
| 482 |
+
print(extracted.shape, mixture.shape, residual.shape)
|
| 483 |
+
|
| 484 |
+
output["audio"]["residual"] = residual
|
| 485 |
+
|
| 486 |
+
if get_no_vox_combinations:
|
| 487 |
+
no_vox_stems = [
|
| 488 |
+
stem for stem in output["audio"] if stem not in self._VOX_STEMS
|
| 489 |
+
]
|
| 490 |
+
no_vox_combinations = chain.from_iterable(
|
| 491 |
+
combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1)
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
for combination in no_vox_combinations:
|
| 495 |
+
combination_ = list(combination)
|
| 496 |
+
output["audio"]["+".join(combination_)] = sum(
|
| 497 |
+
[output["audio"][stem] for stem in combination_]
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
if treat_batch_as_channels:
|
| 501 |
+
for stem in output["audio"]:
|
| 502 |
+
output["audio"][stem] = output["audio"][stem].reshape(
|
| 503 |
+
1, -1, output["audio"][stem].shape[-1]
|
| 504 |
+
)
|
| 505 |
+
batch_size = 1
|
| 506 |
+
|
| 507 |
+
for b in range(batch_size):
|
| 508 |
+
print("!!", b)
|
| 509 |
+
for stem in output["audio"]:
|
| 510 |
+
print(f"Saving audio for {stem} to {self.predict_output_path}")
|
| 511 |
+
track_name = batch["track"][b].split("/")[-1]
|
| 512 |
+
|
| 513 |
+
if batch.get("audio", {}).get(stem, None) is not None:
|
| 514 |
+
self.test_metrics[stem].reset()
|
| 515 |
+
metrics = self.test_metrics[stem](
|
| 516 |
+
batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...]
|
| 517 |
+
)
|
| 518 |
+
snr = metrics["snr"]
|
| 519 |
+
sisnr = metrics["sisnr"]
|
| 520 |
+
sdr = metrics["sdr"]
|
| 521 |
+
metric_dict[stem] = metrics
|
| 522 |
+
print(
|
| 523 |
+
track_name,
|
| 524 |
+
f"snr={snr:2.2f} dB",
|
| 525 |
+
f"sisnr={sisnr:2.2f}",
|
| 526 |
+
f"sdr={sdr:2.2f} dB",
|
| 527 |
+
)
|
| 528 |
+
filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
|
| 529 |
+
else:
|
| 530 |
+
filename = f"{stem}.wav"
|
| 531 |
+
|
| 532 |
+
if include_track_name:
|
| 533 |
+
output_dir = os.path.join(self.predict_output_path, track_name)
|
| 534 |
+
else:
|
| 535 |
+
output_dir = self.predict_output_path
|
| 536 |
+
|
| 537 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 538 |
+
|
| 539 |
+
if fs is None:
|
| 540 |
+
fs = self.fs
|
| 541 |
+
|
| 542 |
+
ta.save(
|
| 543 |
+
os.path.join(output_dir, filename),
|
| 544 |
+
output["audio"][stem][b, ...].cpu(),
|
| 545 |
+
fs,
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
return metric_dict
|
| 549 |
+
|
| 550 |
+
def get_stems(
|
| 551 |
+
self,
|
| 552 |
+
batch: BatchedDataDict,
|
| 553 |
+
batch_idx: int = 0,
|
| 554 |
+
dataloader_idx: int = 0,
|
| 555 |
+
include_track_name: Optional[bool] = None,
|
| 556 |
+
get_no_vox_combinations: bool = True,
|
| 557 |
+
get_residual: bool = False,
|
| 558 |
+
treat_batch_as_channels: bool = False,
|
| 559 |
+
fs: Optional[int] = None,
|
| 560 |
+
) -> Any:
|
| 561 |
+
assert self.predict_output_path is not None
|
| 562 |
+
|
| 563 |
+
batch_size = batch["audio"]["mixture"].shape[0]
|
| 564 |
+
|
| 565 |
+
if include_track_name is None:
|
| 566 |
+
include_track_name = batch_size > 1
|
| 567 |
+
|
| 568 |
+
with torch.inference_mode():
|
| 569 |
+
batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
|
| 570 |
+
torch.cuda.empty_cache()
|
| 571 |
+
metric_dict = {}
|
| 572 |
+
|
| 573 |
+
if get_residual:
|
| 574 |
+
mixture = batch["audio"]["mixture"]
|
| 575 |
+
extracted = sum([output["audio"][stem] for stem in output["audio"]])
|
| 576 |
+
residual = mixture - extracted
|
| 577 |
+
|
| 578 |
+
output["audio"]["residual"] = residual
|
| 579 |
+
|
| 580 |
+
if get_no_vox_combinations:
|
| 581 |
+
no_vox_stems = [
|
| 582 |
+
stem for stem in output["audio"] if stem not in self._VOX_STEMS
|
| 583 |
+
]
|
| 584 |
+
no_vox_combinations = chain.from_iterable(
|
| 585 |
+
combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1)
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
for combination in no_vox_combinations:
|
| 589 |
+
combination_ = list(combination)
|
| 590 |
+
output["audio"]["+".join(combination_)] = sum(
|
| 591 |
+
[output["audio"][stem] for stem in combination_]
|
| 592 |
+
)
|
| 593 |
+
|
| 594 |
+
if treat_batch_as_channels:
|
| 595 |
+
for stem in output["audio"]:
|
| 596 |
+
output["audio"][stem] = output["audio"][stem].reshape(
|
| 597 |
+
1, -1, output["audio"][stem].shape[-1]
|
| 598 |
+
)
|
| 599 |
+
batch_size = 1
|
| 600 |
+
|
| 601 |
+
result = {}
|
| 602 |
+
for b in range(batch_size):
|
| 603 |
+
for stem in output["audio"]:
|
| 604 |
+
track_name = batch["track"][b].split("/")[-1]
|
| 605 |
+
|
| 606 |
+
if batch.get("audio", {}).get(stem, None) is not None:
|
| 607 |
+
self.test_metrics[stem].reset()
|
| 608 |
+
metrics = self.test_metrics[stem](
|
| 609 |
+
batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...]
|
| 610 |
+
)
|
| 611 |
+
snr = metrics["snr"]
|
| 612 |
+
sisnr = metrics["sisnr"]
|
| 613 |
+
sdr = metrics["sdr"]
|
| 614 |
+
metric_dict[stem] = metrics
|
| 615 |
+
print(
|
| 616 |
+
track_name,
|
| 617 |
+
f"snr={snr:2.2f} dB",
|
| 618 |
+
f"sisnr={sisnr:2.2f}",
|
| 619 |
+
f"sdr={sdr:2.2f} dB",
|
| 620 |
+
)
|
| 621 |
+
filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
|
| 622 |
+
else:
|
| 623 |
+
filename = f"{stem}.wav"
|
| 624 |
+
|
| 625 |
+
if include_track_name:
|
| 626 |
+
output_dir = os.path.join(self.predict_output_path, track_name)
|
| 627 |
+
else:
|
| 628 |
+
output_dir = self.predict_output_path
|
| 629 |
+
|
| 630 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 631 |
+
|
| 632 |
+
if fs is None:
|
| 633 |
+
fs = self.fs
|
| 634 |
+
|
| 635 |
+
result[stem] = output["audio"][stem][b, ...].cpu().numpy()
|
| 636 |
+
|
| 637 |
+
return result
|
| 638 |
+
|
| 639 |
+
def load_state_dict(
|
| 640 |
+
self, state_dict: Mapping[str, Any], strict: bool = False
|
| 641 |
+
) -> Any:
|
| 642 |
+
|
| 643 |
+
return super().load_state_dict(state_dict, strict=False)
|
| 644 |
+
|
| 645 |
+
def set_predict_output_path(self, path: str) -> None:
|
| 646 |
+
self.predict_output_path = path
|
| 647 |
+
os.makedirs(self.predict_output_path, exist_ok=True)
|
| 648 |
+
|
| 649 |
+
self.attach_fader()
|
| 650 |
+
|
| 651 |
+
def attach_fader(self, force_reattach=False) -> None:
|
| 652 |
+
if self.fader is None or force_reattach:
|
| 653 |
+
self.fader = parse_fader_config(self.fader_config)
|
| 654 |
+
self.fader.to(self.device)
|
| 655 |
+
|
| 656 |
+
def log_dict_with_prefix(
|
| 657 |
+
self,
|
| 658 |
+
dict_: Dict[str, torch.Tensor],
|
| 659 |
+
prefix: str,
|
| 660 |
+
batch_size: Optional[int] = None,
|
| 661 |
+
**kwargs: Any,
|
| 662 |
+
) -> None:
|
| 663 |
+
self.log_dict(
|
| 664 |
+
{f"{prefix}/{k}": v for k, v in dict_.items()},
|
| 665 |
+
batch_size=batch_size,
|
| 666 |
+
logger=True,
|
| 667 |
+
sync_dist=True,
|
| 668 |
+
**kwargs,
|
| 669 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mvsepless/models/bandit/core/data/__init__.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
| 1 |
-
from .dnr.datamodule import DivideAndRemasterDataModule
|
| 2 |
-
from .musdb.datamodule import MUSDB18DataModule
|
|
|
|
| 1 |
+
from .dnr.datamodule import DivideAndRemasterDataModule
|
| 2 |
+
from .musdb.datamodule import MUSDB18DataModule
|
mvsepless/models/bandit/core/data/_types.py
CHANGED
|
@@ -1,17 +1,17 @@
|
|
| 1 |
-
from typing import Dict, Sequence, TypedDict
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
|
| 5 |
-
AudioDict = Dict[str, torch.Tensor]
|
| 6 |
-
|
| 7 |
-
DataDict = TypedDict("DataDict", {"audio": AudioDict, "track": str})
|
| 8 |
-
|
| 9 |
-
BatchedDataDict = TypedDict(
|
| 10 |
-
"BatchedDataDict", {"audio": AudioDict, "track": Sequence[str]}
|
| 11 |
-
)
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class DataDictWithLanguage(TypedDict):
|
| 15 |
-
audio: AudioDict
|
| 16 |
-
track: str
|
| 17 |
-
language: str
|
|
|
|
| 1 |
+
from typing import Dict, Sequence, TypedDict
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
AudioDict = Dict[str, torch.Tensor]
|
| 6 |
+
|
| 7 |
+
DataDict = TypedDict("DataDict", {"audio": AudioDict, "track": str})
|
| 8 |
+
|
| 9 |
+
BatchedDataDict = TypedDict(
|
| 10 |
+
"BatchedDataDict", {"audio": AudioDict, "track": Sequence[str]}
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DataDictWithLanguage(TypedDict):
|
| 15 |
+
audio: AudioDict
|
| 16 |
+
track: str
|
| 17 |
+
language: str
|
mvsepless/models/bandit/core/data/augmentation.py
CHANGED
|
@@ -1,102 +1,102 @@
|
|
| 1 |
-
from abc import ABC
|
| 2 |
-
from typing import Any, Dict, Union
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
import torch_audiomentations as tam
|
| 6 |
-
from torch import nn
|
| 7 |
-
|
| 8 |
-
from ._types import BatchedDataDict, DataDict
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class BaseAugmentor(nn.Module, ABC):
|
| 12 |
-
def forward(
|
| 13 |
-
self, item: Union[DataDict, BatchedDataDict]
|
| 14 |
-
) -> Union[DataDict, BatchedDataDict]:
|
| 15 |
-
raise NotImplementedError
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class StemAugmentor(BaseAugmentor):
|
| 19 |
-
def __init__(
|
| 20 |
-
self,
|
| 21 |
-
audiomentations: Dict[str, Dict[str, Any]],
|
| 22 |
-
fix_clipping: bool = True,
|
| 23 |
-
scaler_margin: float = 0.5,
|
| 24 |
-
apply_both_default_and_common: bool = False,
|
| 25 |
-
) -> None:
|
| 26 |
-
super().__init__()
|
| 27 |
-
|
| 28 |
-
augmentations = {}
|
| 29 |
-
|
| 30 |
-
self.has_default = "[default]" in audiomentations
|
| 31 |
-
self.has_common = "[common]" in audiomentations
|
| 32 |
-
self.apply_both_default_and_common = apply_both_default_and_common
|
| 33 |
-
|
| 34 |
-
for stem in audiomentations:
|
| 35 |
-
if audiomentations[stem]["name"] == "Compose":
|
| 36 |
-
augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
|
| 37 |
-
[
|
| 38 |
-
getattr(tam, aug["name"])(**aug["kwargs"])
|
| 39 |
-
for aug in audiomentations[stem]["kwargs"]["transforms"]
|
| 40 |
-
],
|
| 41 |
-
**audiomentations[stem]["kwargs"]["kwargs"],
|
| 42 |
-
)
|
| 43 |
-
else:
|
| 44 |
-
augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
|
| 45 |
-
**audiomentations[stem]["kwargs"]
|
| 46 |
-
)
|
| 47 |
-
|
| 48 |
-
self.augmentations = nn.ModuleDict(augmentations)
|
| 49 |
-
self.fix_clipping = fix_clipping
|
| 50 |
-
self.scaler_margin = scaler_margin
|
| 51 |
-
|
| 52 |
-
def check_and_fix_clipping(
|
| 53 |
-
self, item: Union[DataDict, BatchedDataDict]
|
| 54 |
-
) -> Union[DataDict, BatchedDataDict]:
|
| 55 |
-
max_abs = []
|
| 56 |
-
|
| 57 |
-
for stem in item["audio"]:
|
| 58 |
-
max_abs.append(item["audio"][stem].abs().max().item())
|
| 59 |
-
|
| 60 |
-
if max(max_abs) > 1.0:
|
| 61 |
-
scaler = 1.0 / (
|
| 62 |
-
max(max_abs)
|
| 63 |
-
+ torch.rand((1,), device=item["audio"]["mixture"].device)
|
| 64 |
-
* self.scaler_margin
|
| 65 |
-
)
|
| 66 |
-
|
| 67 |
-
for stem in item["audio"]:
|
| 68 |
-
item["audio"][stem] *= scaler
|
| 69 |
-
|
| 70 |
-
return item
|
| 71 |
-
|
| 72 |
-
def forward(
|
| 73 |
-
self, item: Union[DataDict, BatchedDataDict]
|
| 74 |
-
) -> Union[DataDict, BatchedDataDict]:
|
| 75 |
-
|
| 76 |
-
for stem in item["audio"]:
|
| 77 |
-
if stem == "mixture":
|
| 78 |
-
continue
|
| 79 |
-
|
| 80 |
-
if self.has_common:
|
| 81 |
-
item["audio"][stem] = self.augmentations["[common]"](
|
| 82 |
-
item["audio"][stem]
|
| 83 |
-
).samples
|
| 84 |
-
|
| 85 |
-
if stem in self.augmentations:
|
| 86 |
-
item["audio"][stem] = self.augmentations[stem](
|
| 87 |
-
item["audio"][stem]
|
| 88 |
-
).samples
|
| 89 |
-
elif self.has_default:
|
| 90 |
-
if not self.has_common or self.apply_both_default_and_common:
|
| 91 |
-
item["audio"][stem] = self.augmentations["[default]"](
|
| 92 |
-
item["audio"][stem]
|
| 93 |
-
).samples
|
| 94 |
-
|
| 95 |
-
item["audio"]["mixture"] = sum(
|
| 96 |
-
[item["audio"][stem] for stem in item["audio"] if stem != "mixture"]
|
| 97 |
-
) # type: ignore[call-overload, assignment]
|
| 98 |
-
|
| 99 |
-
if self.fix_clipping:
|
| 100 |
-
item = self.check_and_fix_clipping(item)
|
| 101 |
-
|
| 102 |
-
return item
|
|
|
|
| 1 |
+
from abc import ABC
|
| 2 |
+
from typing import Any, Dict, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch_audiomentations as tam
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
from ._types import BatchedDataDict, DataDict
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class BaseAugmentor(nn.Module, ABC):
|
| 12 |
+
def forward(
|
| 13 |
+
self, item: Union[DataDict, BatchedDataDict]
|
| 14 |
+
) -> Union[DataDict, BatchedDataDict]:
|
| 15 |
+
raise NotImplementedError
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class StemAugmentor(BaseAugmentor):
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
audiomentations: Dict[str, Dict[str, Any]],
|
| 22 |
+
fix_clipping: bool = True,
|
| 23 |
+
scaler_margin: float = 0.5,
|
| 24 |
+
apply_both_default_and_common: bool = False,
|
| 25 |
+
) -> None:
|
| 26 |
+
super().__init__()
|
| 27 |
+
|
| 28 |
+
augmentations = {}
|
| 29 |
+
|
| 30 |
+
self.has_default = "[default]" in audiomentations
|
| 31 |
+
self.has_common = "[common]" in audiomentations
|
| 32 |
+
self.apply_both_default_and_common = apply_both_default_and_common
|
| 33 |
+
|
| 34 |
+
for stem in audiomentations:
|
| 35 |
+
if audiomentations[stem]["name"] == "Compose":
|
| 36 |
+
augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
|
| 37 |
+
[
|
| 38 |
+
getattr(tam, aug["name"])(**aug["kwargs"])
|
| 39 |
+
for aug in audiomentations[stem]["kwargs"]["transforms"]
|
| 40 |
+
],
|
| 41 |
+
**audiomentations[stem]["kwargs"]["kwargs"],
|
| 42 |
+
)
|
| 43 |
+
else:
|
| 44 |
+
augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
|
| 45 |
+
**audiomentations[stem]["kwargs"]
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
self.augmentations = nn.ModuleDict(augmentations)
|
| 49 |
+
self.fix_clipping = fix_clipping
|
| 50 |
+
self.scaler_margin = scaler_margin
|
| 51 |
+
|
| 52 |
+
def check_and_fix_clipping(
|
| 53 |
+
self, item: Union[DataDict, BatchedDataDict]
|
| 54 |
+
) -> Union[DataDict, BatchedDataDict]:
|
| 55 |
+
max_abs = []
|
| 56 |
+
|
| 57 |
+
for stem in item["audio"]:
|
| 58 |
+
max_abs.append(item["audio"][stem].abs().max().item())
|
| 59 |
+
|
| 60 |
+
if max(max_abs) > 1.0:
|
| 61 |
+
scaler = 1.0 / (
|
| 62 |
+
max(max_abs)
|
| 63 |
+
+ torch.rand((1,), device=item["audio"]["mixture"].device)
|
| 64 |
+
* self.scaler_margin
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
for stem in item["audio"]:
|
| 68 |
+
item["audio"][stem] *= scaler
|
| 69 |
+
|
| 70 |
+
return item
|
| 71 |
+
|
| 72 |
+
def forward(
|
| 73 |
+
self, item: Union[DataDict, BatchedDataDict]
|
| 74 |
+
) -> Union[DataDict, BatchedDataDict]:
|
| 75 |
+
|
| 76 |
+
for stem in item["audio"]:
|
| 77 |
+
if stem == "mixture":
|
| 78 |
+
continue
|
| 79 |
+
|
| 80 |
+
if self.has_common:
|
| 81 |
+
item["audio"][stem] = self.augmentations["[common]"](
|
| 82 |
+
item["audio"][stem]
|
| 83 |
+
).samples
|
| 84 |
+
|
| 85 |
+
if stem in self.augmentations:
|
| 86 |
+
item["audio"][stem] = self.augmentations[stem](
|
| 87 |
+
item["audio"][stem]
|
| 88 |
+
).samples
|
| 89 |
+
elif self.has_default:
|
| 90 |
+
if not self.has_common or self.apply_both_default_and_common:
|
| 91 |
+
item["audio"][stem] = self.augmentations["[default]"](
|
| 92 |
+
item["audio"][stem]
|
| 93 |
+
).samples
|
| 94 |
+
|
| 95 |
+
item["audio"]["mixture"] = sum(
|
| 96 |
+
[item["audio"][stem] for stem in item["audio"] if stem != "mixture"]
|
| 97 |
+
) # type: ignore[call-overload, assignment]
|
| 98 |
+
|
| 99 |
+
if self.fix_clipping:
|
| 100 |
+
item = self.check_and_fix_clipping(item)
|
| 101 |
+
|
| 102 |
+
return item
|
mvsepless/models/bandit/core/data/augmented.py
CHANGED
|
@@ -1,34 +1,34 @@
|
|
| 1 |
-
import warnings
|
| 2 |
-
from typing import Dict, Optional, Union
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
from torch import nn
|
| 6 |
-
from torch.utils import data
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class AugmentedDataset(data.Dataset):
|
| 10 |
-
def __init__(
|
| 11 |
-
self,
|
| 12 |
-
dataset: data.Dataset,
|
| 13 |
-
augmentation: nn.Module = nn.Identity(),
|
| 14 |
-
target_length: Optional[int] = None,
|
| 15 |
-
) -> None:
|
| 16 |
-
warnings.warn(
|
| 17 |
-
"This class is no longer used. Attach augmentation to "
|
| 18 |
-
"the LightningSystem instead.",
|
| 19 |
-
DeprecationWarning,
|
| 20 |
-
)
|
| 21 |
-
|
| 22 |
-
self.dataset = dataset
|
| 23 |
-
self.augmentation = augmentation
|
| 24 |
-
|
| 25 |
-
self.ds_length: int = len(dataset) # type: ignore[arg-type]
|
| 26 |
-
self.length = target_length if target_length is not None else self.ds_length
|
| 27 |
-
|
| 28 |
-
def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str, torch.Tensor]]]:
|
| 29 |
-
item = self.dataset[index % self.ds_length]
|
| 30 |
-
item = self.augmentation(item)
|
| 31 |
-
return item
|
| 32 |
-
|
| 33 |
-
def __len__(self) -> int:
|
| 34 |
-
return self.length
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from typing import Dict, Optional, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.utils import data
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AugmentedDataset(data.Dataset):
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
dataset: data.Dataset,
|
| 13 |
+
augmentation: nn.Module = nn.Identity(),
|
| 14 |
+
target_length: Optional[int] = None,
|
| 15 |
+
) -> None:
|
| 16 |
+
warnings.warn(
|
| 17 |
+
"This class is no longer used. Attach augmentation to "
|
| 18 |
+
"the LightningSystem instead.",
|
| 19 |
+
DeprecationWarning,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
self.dataset = dataset
|
| 23 |
+
self.augmentation = augmentation
|
| 24 |
+
|
| 25 |
+
self.ds_length: int = len(dataset) # type: ignore[arg-type]
|
| 26 |
+
self.length = target_length if target_length is not None else self.ds_length
|
| 27 |
+
|
| 28 |
+
def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str, torch.Tensor]]]:
|
| 29 |
+
item = self.dataset[index % self.ds_length]
|
| 30 |
+
item = self.augmentation(item)
|
| 31 |
+
return item
|
| 32 |
+
|
| 33 |
+
def __len__(self) -> int:
|
| 34 |
+
return self.length
|
mvsepless/models/bandit/core/data/base.py
CHANGED
|
@@ -1,60 +1,60 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from abc import ABC, abstractmethod
|
| 3 |
-
from typing import Any, Dict, List, Optional
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import pedalboard as pb
|
| 7 |
-
import torch
|
| 8 |
-
import torchaudio as ta
|
| 9 |
-
from torch.utils import data
|
| 10 |
-
|
| 11 |
-
from ._types import AudioDict, DataDict
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class BaseSourceSeparationDataset(data.Dataset, ABC):
|
| 15 |
-
def __init__(
|
| 16 |
-
self,
|
| 17 |
-
split: str,
|
| 18 |
-
stems: List[str],
|
| 19 |
-
files: List[str],
|
| 20 |
-
data_path: str,
|
| 21 |
-
fs: int,
|
| 22 |
-
npy_memmap: bool,
|
| 23 |
-
recompute_mixture: bool,
|
| 24 |
-
):
|
| 25 |
-
self.split = split
|
| 26 |
-
self.stems = stems
|
| 27 |
-
self.stems_no_mixture = [s for s in stems if s != "mixture"]
|
| 28 |
-
self.files = files
|
| 29 |
-
self.data_path = data_path
|
| 30 |
-
self.fs = fs
|
| 31 |
-
self.npy_memmap = npy_memmap
|
| 32 |
-
self.recompute_mixture = recompute_mixture
|
| 33 |
-
|
| 34 |
-
@abstractmethod
|
| 35 |
-
def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
|
| 36 |
-
raise NotImplementedError
|
| 37 |
-
|
| 38 |
-
def _get_audio(self, stems, identifier: Dict[str, Any]):
|
| 39 |
-
audio = {}
|
| 40 |
-
for stem in stems:
|
| 41 |
-
audio[stem] = self.get_stem(stem=stem, identifier=identifier)
|
| 42 |
-
|
| 43 |
-
return audio
|
| 44 |
-
|
| 45 |
-
def get_audio(self, identifier: Dict[str, Any]) -> AudioDict:
|
| 46 |
-
|
| 47 |
-
if self.recompute_mixture:
|
| 48 |
-
audio = self._get_audio(self.stems_no_mixture, identifier=identifier)
|
| 49 |
-
audio["mixture"] = self.compute_mixture(audio)
|
| 50 |
-
return audio
|
| 51 |
-
else:
|
| 52 |
-
return self._get_audio(self.stems, identifier=identifier)
|
| 53 |
-
|
| 54 |
-
@abstractmethod
|
| 55 |
-
def get_identifier(self, index: int) -> Dict[str, Any]:
|
| 56 |
-
pass
|
| 57 |
-
|
| 58 |
-
def compute_mixture(self, audio: AudioDict) -> torch.Tensor:
|
| 59 |
-
|
| 60 |
-
return sum(audio[stem] for stem in audio if stem != "mixture")
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
from typing import Any, Dict, List, Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pedalboard as pb
|
| 7 |
+
import torch
|
| 8 |
+
import torchaudio as ta
|
| 9 |
+
from torch.utils import data
|
| 10 |
+
|
| 11 |
+
from ._types import AudioDict, DataDict
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class BaseSourceSeparationDataset(data.Dataset, ABC):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
split: str,
|
| 18 |
+
stems: List[str],
|
| 19 |
+
files: List[str],
|
| 20 |
+
data_path: str,
|
| 21 |
+
fs: int,
|
| 22 |
+
npy_memmap: bool,
|
| 23 |
+
recompute_mixture: bool,
|
| 24 |
+
):
|
| 25 |
+
self.split = split
|
| 26 |
+
self.stems = stems
|
| 27 |
+
self.stems_no_mixture = [s for s in stems if s != "mixture"]
|
| 28 |
+
self.files = files
|
| 29 |
+
self.data_path = data_path
|
| 30 |
+
self.fs = fs
|
| 31 |
+
self.npy_memmap = npy_memmap
|
| 32 |
+
self.recompute_mixture = recompute_mixture
|
| 33 |
+
|
| 34 |
+
@abstractmethod
|
| 35 |
+
def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
|
| 36 |
+
raise NotImplementedError
|
| 37 |
+
|
| 38 |
+
def _get_audio(self, stems, identifier: Dict[str, Any]):
|
| 39 |
+
audio = {}
|
| 40 |
+
for stem in stems:
|
| 41 |
+
audio[stem] = self.get_stem(stem=stem, identifier=identifier)
|
| 42 |
+
|
| 43 |
+
return audio
|
| 44 |
+
|
| 45 |
+
def get_audio(self, identifier: Dict[str, Any]) -> AudioDict:
|
| 46 |
+
|
| 47 |
+
if self.recompute_mixture:
|
| 48 |
+
audio = self._get_audio(self.stems_no_mixture, identifier=identifier)
|
| 49 |
+
audio["mixture"] = self.compute_mixture(audio)
|
| 50 |
+
return audio
|
| 51 |
+
else:
|
| 52 |
+
return self._get_audio(self.stems, identifier=identifier)
|
| 53 |
+
|
| 54 |
+
@abstractmethod
|
| 55 |
+
def get_identifier(self, index: int) -> Dict[str, Any]:
|
| 56 |
+
pass
|
| 57 |
+
|
| 58 |
+
def compute_mixture(self, audio: AudioDict) -> torch.Tensor:
|
| 59 |
+
|
| 60 |
+
return sum(audio[stem] for stem in audio if stem != "mixture")
|
mvsepless/models/bandit/core/data/dnr/datamodule.py
CHANGED
|
@@ -1,68 +1,64 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from typing import Mapping, Optional
|
| 3 |
-
|
| 4 |
-
import pytorch_lightning as pl
|
| 5 |
-
|
| 6 |
-
from .dataset import (
|
| 7 |
-
DivideAndRemasterDataset,
|
| 8 |
-
DivideAndRemasterDeterministicChunkDataset,
|
| 9 |
-
DivideAndRemasterRandomChunkDataset,
|
| 10 |
-
DivideAndRemasterRandomChunkDatasetWithSpeechReverb,
|
| 11 |
-
)
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def DivideAndRemasterDataModule(
|
| 15 |
-
data_root: str = "$DATA_ROOT/DnR/v2",
|
| 16 |
-
batch_size: int = 2,
|
| 17 |
-
num_workers: int = 8,
|
| 18 |
-
train_kwargs: Optional[Mapping] = None,
|
| 19 |
-
val_kwargs: Optional[Mapping] = None,
|
| 20 |
-
test_kwargs: Optional[Mapping] = None,
|
| 21 |
-
datamodule_kwargs: Optional[Mapping] = None,
|
| 22 |
-
use_speech_reverb: bool = False,
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
datamodule.predict_dataloader = datamodule.test_dataloader # type: ignore[method-assign]
|
| 67 |
-
|
| 68 |
-
return datamodule
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Mapping, Optional
|
| 3 |
+
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
|
| 6 |
+
from .dataset import (
|
| 7 |
+
DivideAndRemasterDataset,
|
| 8 |
+
DivideAndRemasterDeterministicChunkDataset,
|
| 9 |
+
DivideAndRemasterRandomChunkDataset,
|
| 10 |
+
DivideAndRemasterRandomChunkDatasetWithSpeechReverb,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def DivideAndRemasterDataModule(
|
| 15 |
+
data_root: str = "$DATA_ROOT/DnR/v2",
|
| 16 |
+
batch_size: int = 2,
|
| 17 |
+
num_workers: int = 8,
|
| 18 |
+
train_kwargs: Optional[Mapping] = None,
|
| 19 |
+
val_kwargs: Optional[Mapping] = None,
|
| 20 |
+
test_kwargs: Optional[Mapping] = None,
|
| 21 |
+
datamodule_kwargs: Optional[Mapping] = None,
|
| 22 |
+
use_speech_reverb: bool = False,
|
| 23 |
+
) -> pl.LightningDataModule:
|
| 24 |
+
if train_kwargs is None:
|
| 25 |
+
train_kwargs = {}
|
| 26 |
+
|
| 27 |
+
if val_kwargs is None:
|
| 28 |
+
val_kwargs = {}
|
| 29 |
+
|
| 30 |
+
if test_kwargs is None:
|
| 31 |
+
test_kwargs = {}
|
| 32 |
+
|
| 33 |
+
if datamodule_kwargs is None:
|
| 34 |
+
datamodule_kwargs = {}
|
| 35 |
+
|
| 36 |
+
if num_workers is None:
|
| 37 |
+
num_workers = os.cpu_count()
|
| 38 |
+
|
| 39 |
+
if num_workers is None:
|
| 40 |
+
num_workers = 32
|
| 41 |
+
|
| 42 |
+
num_workers = min(num_workers, 64)
|
| 43 |
+
|
| 44 |
+
if use_speech_reverb:
|
| 45 |
+
train_cls = DivideAndRemasterRandomChunkDatasetWithSpeechReverb
|
| 46 |
+
else:
|
| 47 |
+
train_cls = DivideAndRemasterRandomChunkDataset
|
| 48 |
+
|
| 49 |
+
train_dataset = train_cls(data_root, "train", **train_kwargs)
|
| 50 |
+
|
| 51 |
+
datamodule = pl.LightningDataModule.from_datasets(
|
| 52 |
+
train_dataset=train_dataset,
|
| 53 |
+
val_dataset=DivideAndRemasterDeterministicChunkDataset(
|
| 54 |
+
data_root, "val", **val_kwargs
|
| 55 |
+
),
|
| 56 |
+
test_dataset=DivideAndRemasterDataset(data_root, "test", **test_kwargs),
|
| 57 |
+
batch_size=batch_size,
|
| 58 |
+
num_workers=num_workers,
|
| 59 |
+
**datamodule_kwargs,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
datamodule.predict_dataloader = datamodule.test_dataloader # type: ignore[method-assign]
|
| 63 |
+
|
| 64 |
+
return datamodule
|
|
|
|
|
|
|
|
|
|
|
|
mvsepless/models/bandit/core/data/dnr/dataset.py
CHANGED
|
@@ -1,366 +1,360 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from abc import ABC
|
| 3 |
-
from typing import Any, Dict, List, Optional
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import pedalboard as pb
|
| 7 |
-
import torch
|
| 8 |
-
import torchaudio as ta
|
| 9 |
-
from torch.utils import data
|
| 10 |
-
|
| 11 |
-
from .._types import AudioDict, DataDict
|
| 12 |
-
from ..base import BaseSourceSeparationDataset
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC):
|
| 16 |
-
ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"]
|
| 17 |
-
STEM_NAME_MAP = {
|
| 18 |
-
"mixture": "mix",
|
| 19 |
-
"speech": "speech",
|
| 20 |
-
"music": "music",
|
| 21 |
-
"effects": "sfx",
|
| 22 |
-
}
|
| 23 |
-
SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"}
|
| 24 |
-
|
| 25 |
-
FULL_TRACK_LENGTH_SECOND = 60
|
| 26 |
-
FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100
|
| 27 |
-
|
| 28 |
-
def __init__(
|
| 29 |
-
self,
|
| 30 |
-
split: str,
|
| 31 |
-
stems: List[str],
|
| 32 |
-
files: List[str],
|
| 33 |
-
data_path: str,
|
| 34 |
-
fs: int = 44100,
|
| 35 |
-
npy_memmap: bool = True,
|
| 36 |
-
recompute_mixture: bool = False,
|
| 37 |
-
) -> None:
|
| 38 |
-
super().__init__(
|
| 39 |
-
split=split,
|
| 40 |
-
stems=stems,
|
| 41 |
-
files=files,
|
| 42 |
-
data_path=data_path,
|
| 43 |
-
fs=fs,
|
| 44 |
-
npy_memmap=npy_memmap,
|
| 45 |
-
recompute_mixture=recompute_mixture,
|
| 46 |
-
)
|
| 47 |
-
|
| 48 |
-
def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
|
| 49 |
-
|
| 50 |
-
if stem == "mne":
|
| 51 |
-
return self.get_stem(stem="music", identifier=identifier) + self.get_stem(
|
| 52 |
-
stem="effects", identifier=identifier
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
track = identifier["track"]
|
| 56 |
-
path = os.path.join(self.data_path, track)
|
| 57 |
-
|
| 58 |
-
if self.npy_memmap:
|
| 59 |
-
audio = np.load(
|
| 60 |
-
os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"), mmap_mode="r"
|
| 61 |
-
)
|
| 62 |
-
else:
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
files =
|
| 95 |
-
|
| 96 |
-
f
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
assert len(files) ==
|
| 103 |
-
elif split == "
|
| 104 |
-
assert len(files) ==
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
f
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
assert len(files) ==
|
| 150 |
-
elif split == "
|
| 151 |
-
assert len(files) ==
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
self.
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
self.
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
stems
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
)
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
pprint(track_)
|
| 362 |
-
track_["audio"] = {k: v.shape for k, v in track_["audio"].items()}
|
| 363 |
-
pprint(track_)
|
| 364 |
-
# break
|
| 365 |
-
|
| 366 |
-
break
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from abc import ABC
|
| 3 |
+
from typing import Any, Dict, List, Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pedalboard as pb
|
| 7 |
+
import torch
|
| 8 |
+
import torchaudio as ta
|
| 9 |
+
from torch.utils import data
|
| 10 |
+
|
| 11 |
+
from .._types import AudioDict, DataDict
|
| 12 |
+
from ..base import BaseSourceSeparationDataset
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC):
    """Base class for Divide-and-Remaster (DnR) source-separation datasets.

    Maps logical stem names onto the on-disk DnR file layout and loads
    per-stem audio either from ``.npy`` memmaps or from ``.wav`` files.
    """

    # Stems callers may request; "mne" (music + effects) is derived, not on disk.
    ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"]
    # Logical stem name -> on-disk file basename.
    STEM_NAME_MAP = {
        "mixture": "mix",
        "speech": "speech",
        "music": "music",
        "effects": "sfx",
    }
    # Split name -> DnR directory name.
    SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"}

    # Every DnR track is exactly 60 seconds long at 44.1 kHz.
    FULL_TRACK_LENGTH_SECOND = 60
    FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100

    def __init__(
        self,
        split: str,
        stems: List[str],
        files: List[str],
        data_path: str,
        fs: int = 44100,
        npy_memmap: bool = True,
        recompute_mixture: bool = False,
    ) -> None:
        """Forward all configuration to BaseSourceSeparationDataset unchanged."""
        super().__init__(
            split=split,
            stems=stems,
            files=files,
            data_path=data_path,
            fs=fs,
            npy_memmap=npy_memmap,
            recompute_mixture=recompute_mixture,
        )

    def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
        """Load one stem for the track named in ``identifier``.

        "mne" is synthesized recursively as music + effects; the other stems
        are read from disk.  NOTE(review): the npy branch returns a read-only
        numpy memmap rather than a torch.Tensor — callers presumably tolerate
        both; confirm before relying on tensor-only operations.
        """
        if stem == "mne":
            return self.get_stem(stem="music", identifier=identifier) + self.get_stem(
                stem="effects", identifier=identifier
            )

        track = identifier["track"]
        path = os.path.join(self.data_path, track)

        if self.npy_memmap:
            # Lazy memmap: only the slices actually read are pulled from disk.
            audio = np.load(
                os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"), mmap_mode="r"
            )
        else:
            audio, _ = ta.load(os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav"))

        return audio

    def get_identifier(self, index):
        # A track is identified solely by its directory name.
        return dict(track=self.files[index])

    def __getitem__(self, index: int) -> DataDict:
        identifier = self.get_identifier(index)
        audio = self.get_audio(identifier)

        return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class DivideAndRemasterDataset(DivideAndRemasterBaseDataset):
    """Full-track DnR dataset: one item per complete 60-second track."""

    def __init__(
        self,
        data_root: str,
        split: str,
        stems: Optional[List[str]] = None,
        fs: int = 44100,
        npy_memmap: bool = True,
    ) -> None:
        """Index the track directories of one DnR split under ``data_root``."""
        if stems is None:
            stems = self.ALLOWED_STEMS
        self.stems = stems

        data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])

        # Keep only visible track directories; skip hidden entries and files.
        track_dirs = []
        for entry in sorted(os.listdir(data_path)):
            if entry.startswith("."):
                continue
            if os.path.isdir(os.path.join(data_path, entry)):
                track_dirs.append(entry)

        # Sanity-check against the canonical DnR split sizes.
        expected_counts = {"train": 3406, "val": 487, "test": 973}
        if split in expected_counts:
            assert len(track_dirs) == expected_counts[split], len(track_dirs)

        self.n_tracks = len(track_dirs)

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=track_dirs,
            fs=fs,
            npy_memmap=npy_memmap,
        )

    def __len__(self) -> int:
        return self.n_tracks
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class DivideAndRemasterRandomChunkDataset(DivideAndRemasterBaseDataset):
    """DnR dataset yielding randomly positioned fixed-length chunks.

    The epoch length is decoupled from the number of tracks: ``target_length``
    items are produced per epoch, cycling over the tracks, each cropped at a
    fresh random offset.
    """

    def __init__(
        self,
        data_root: str,
        split: str,
        target_length: int,
        chunk_size_second: float,
        stems: Optional[List[str]] = None,
        fs: int = 44100,
        npy_memmap: bool = True,
    ) -> None:
        """Index one DnR split and configure random chunking.

        Args:
            target_length: number of items reported by ``__len__`` (epoch size).
            chunk_size_second: chunk duration in seconds; converted to samples
                with ``fs``.
        """
        if stems is None:
            stems = self.ALLOWED_STEMS
        self.stems = stems

        data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])

        files = sorted(os.listdir(data_path))
        files = [
            f
            for f in files
            if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
        ]

        # Canonical DnR split sizes.
        if split == "train":
            assert len(files) == 3406, len(files)
        elif split == "val":
            assert len(files) == 487, len(files)
        elif split == "test":
            assert len(files) == 973, len(files)

        self.n_tracks = len(files)

        self.target_length = target_length
        self.chunk_size = int(chunk_size_second * fs)

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=files,
            fs=fs,
            npy_memmap=npy_memmap,
        )

    def __len__(self) -> int:
        # Epoch size is the configured target, not the number of tracks.
        return self.target_length

    def get_identifier(self, index):
        # Wrap around so any index < target_length maps onto a real track.
        return super().get_identifier(index % self.n_tracks)

    def get_stem(
        self,
        *,
        stem: str,
        identifier: Dict[str, Any],
        chunk_here: bool = False,
    ) -> torch.Tensor:
        """Load a stem; optionally crop it independently of the other stems.

        With ``chunk_here=True`` each call draws its own random offset, so
        stems are NOT aligned with each other — ``__getitem__`` instead crops
        all stems with one shared offset.
        """
        stem = super().get_stem(stem=stem, identifier=identifier)

        if chunk_here:
            # randint upper bound is exclusive, so end <= track length.
            start = np.random.randint(
                0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
            )
            end = start + self.chunk_size

            stem = stem[:, start:end]

        return stem

    def __getitem__(self, index: int) -> DataDict:
        identifier = self.get_identifier(index)
        audio = self.get_audio(identifier)

        # One shared random offset keeps all stems time-aligned.
        start = np.random.randint(0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size)
        end = start + self.chunk_size

        audio = {k: v[:, start:end] for k, v in audio.items()}

        return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class DivideAndRemasterDeterministicChunkDataset(DivideAndRemasterBaseDataset):
    """DnR dataset yielding a deterministic grid of overlapping chunks.

    Each track contributes ``n_chunks_per_track`` chunks laid out on a fixed
    hop grid, so evaluation is reproducible across runs.
    """

    def __init__(
        self,
        data_root: str,
        split: str,
        chunk_size_second: float,
        hop_size_second: float,
        stems: Optional[List[str]] = None,
        fs: int = 44100,
        npy_memmap: bool = True,
    ) -> None:
        """Index one DnR split and precompute the chunk grid.

        Args:
            chunk_size_second: chunk duration in seconds.
            hop_size_second: stride between consecutive chunk starts, seconds.
        """
        if stems is None:
            stems = self.ALLOWED_STEMS
        self.stems = stems

        data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])

        files = sorted(os.listdir(data_path))
        files = [
            f
            for f in files
            if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
        ]
        # Canonical DnR split sizes.
        if split == "train":
            assert len(files) == 3406, len(files)
        elif split == "val":
            assert len(files) == 487, len(files)
        elif split == "test":
            assert len(files) == 973, len(files)

        self.n_tracks = len(files)

        self.chunk_size = int(chunk_size_second * fs)
        self.hop_size = int(hop_size_second * fs)
        # Number of full chunks that fit on the hop grid of a 60 s track
        # (truncated, so the final partial chunk is dropped).
        self.n_chunks_per_track = int(
            (self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second
        )

        self.length = self.n_tracks * self.n_chunks_per_track

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=files,
            fs=fs,
            npy_memmap=npy_memmap,
        )

    def get_identifier(self, index):
        # Indices iterate track-major: same track recurs every n_tracks items.
        return super().get_identifier(index % self.n_tracks)

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, item: int) -> DataDict:
        # Decompose the flat index into (track, chunk-on-grid).
        index = item % self.n_tracks
        chunk = item // self.n_tracks

        data_ = super().__getitem__(index)

        audio = data_["audio"]

        start = chunk * self.hop_size
        end = start + self.chunk_size

        # Crop every requested stem with the same deterministic window.
        for stem in self.stems:
            data_["audio"][stem] = audio[stem][:, start:end]

        return data_
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
    DivideAndRemasterRandomChunkDataset
):
    """Random-chunk DnR dataset that reverberates the speech stem on the fly.

    The speech stem is passed through a pedalboard Reverb with randomized
    parameters, and the mixture is recomputed from the augmented stems so it
    stays consistent with them.
    """

    def __init__(
        self,
        data_root: str,
        split: str,
        target_length: int,
        chunk_size_second: float,
        stems: Optional[List[str]] = None,
        fs: int = 44100,
        npy_memmap: bool = True,
    ) -> None:
        """Same configuration as the parent; the mixture stem is always
        recomputed here rather than loaded from disk."""
        if stems is None:
            stems = self.ALLOWED_STEMS

        # Load everything except the mixture; it is rebuilt after augmentation.
        stems_no_mixture = [s for s in stems if s != "mixture"]

        super().__init__(
            data_root=data_root,
            split=split,
            target_length=target_length,
            chunk_size_second=chunk_size_second,
            stems=stems_no_mixture,
            fs=fs,
            npy_memmap=npy_memmap,
        )

        # Publicly keep the full stem list (including "mixture").
        self.stems = stems
        self.stems_no_mixture = stems_no_mixture

    def __getitem__(self, index: int) -> DataDict:
        data_ = super().__getitem__(index)

        dry = data_["audio"]["speech"][:]
        n_samples = dry.shape[-1]

        # wet/dry levels are complementary so overall loudness stays stable.
        wet_level = np.random.rand()

        # Reverb can emit a tail longer than the input; trim back to n_samples.
        speech = pb.Reverb(
            room_size=np.random.rand(),
            damping=np.random.rand(),
            wet_level=wet_level,
            dry_level=(1 - wet_level),
            width=np.random.rand(),
        ).process(dry, self.fs, buffer_size=8192 * 4)[..., :n_samples]

        data_["audio"]["speech"] = speech

        # Rebuild the mixture from the (augmented) component stems.
        data_["audio"]["mixture"] = sum(
            [data_["audio"][s] for s in self.stems_no_mixture]
        )

        return data_

    def __len__(self) -> int:
        return super().__len__()
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
if __name__ == "__main__":
    # Manual smoke test: build each split and inspect the first produced item.
    from pprint import pprint
    from tqdm.auto import tqdm

    for split_name in ["train", "val", "test"]:
        dataset = DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
            data_root="$DATA_ROOT/DnR/v2np",
            split=split_name,
            target_length=100,
            chunk_size_second=6.0,
        )

        print(split_name, len(dataset))

        for item in tqdm(dataset):  # type: ignore
            pprint(item)
            # Second dump: replace waveforms with their shapes for readability.
            item["audio"] = {stem: wav.shape for stem, wav in item["audio"].items()}
            pprint(item)
            break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mvsepless/models/bandit/core/data/dnr/preprocess.py
CHANGED
|
@@ -1,51 +1,51 @@
|
|
| 1 |
-
import glob
|
| 2 |
-
import os
|
| 3 |
-
from typing import Tuple
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import torchaudio as ta
|
| 7 |
-
from tqdm.contrib.concurrent import process_map
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
def process_one(inputs: Tuple[str, str, int]) -> None:
|
| 11 |
-
infile, outfile, target_fs = inputs
|
| 12 |
-
|
| 13 |
-
dir = os.path.dirname(outfile)
|
| 14 |
-
os.makedirs(dir, exist_ok=True)
|
| 15 |
-
|
| 16 |
-
data, fs = ta.load(infile)
|
| 17 |
-
|
| 18 |
-
if fs != target_fs:
|
| 19 |
-
data = ta.functional.resample(
|
| 20 |
-
data, fs, target_fs, resampling_method="sinc_interp_kaiser"
|
| 21 |
-
)
|
| 22 |
-
fs = target_fs
|
| 23 |
-
|
| 24 |
-
data = data.numpy()
|
| 25 |
-
data = data.astype(np.float32)
|
| 26 |
-
|
| 27 |
-
if os.path.exists(outfile):
|
| 28 |
-
data_ = np.load(outfile)
|
| 29 |
-
if np.allclose(data, data_):
|
| 30 |
-
return
|
| 31 |
-
|
| 32 |
-
np.save(outfile, data)
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def preprocess(data_path: str, output_path: str, fs: int) -> None:
|
| 36 |
-
files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
|
| 37 |
-
print(files)
|
| 38 |
-
outfiles = [
|
| 39 |
-
f.replace(data_path, output_path).replace(".wav", ".npy") for f in files
|
| 40 |
-
]
|
| 41 |
-
|
| 42 |
-
os.makedirs(output_path, exist_ok=True)
|
| 43 |
-
inputs = list(zip(files, outfiles, [fs] * len(files)))
|
| 44 |
-
|
| 45 |
-
process_map(process_one, inputs, chunksize=32)
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
if __name__ == "__main__":
|
| 49 |
-
import fire
|
| 50 |
-
|
| 51 |
-
fire.Fire()
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torchaudio as ta
|
| 7 |
+
from tqdm.contrib.concurrent import process_map
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def process_one(inputs: Tuple[str, str, int]) -> None:
    """Resample one wav file to ``target_fs`` and cache it as a float32 .npy.

    Args:
        inputs: ``(infile, outfile, target_fs)`` packed in one tuple so the
            function can be dispatched through ``process_map``.

    The output is only rewritten when its content would change, so repeated
    preprocessing runs are cheap no-ops.
    """
    infile, outfile, target_fs = inputs

    # Renamed from ``dir`` to stop shadowing the builtin.
    out_dir = os.path.dirname(outfile)
    os.makedirs(out_dir, exist_ok=True)

    data, fs = ta.load(infile)

    if fs != target_fs:
        data = ta.functional.resample(
            data, fs, target_fs, resampling_method="sinc_interp_kaiser"
        )
        fs = target_fs

    data = data.numpy().astype(np.float32)

    # Skip the write when an identical cache already exists.
    if os.path.exists(outfile):
        cached = np.load(outfile)
        if np.allclose(data, cached):
            return

    np.save(outfile, data)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def preprocess(data_path: str, output_path: str, fs: int) -> None:
    """Convert every wav under ``data_path`` into an .npy cache mirrored
    under ``output_path``, resampled to ``fs``, in parallel."""
    files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
    print(files)

    # Mirror each input path into the output tree with an .npy suffix.
    outfiles = []
    for wav_path in files:
        npy_path = wav_path.replace(data_path, output_path).replace(".wav", ".npy")
        outfiles.append(npy_path)

    os.makedirs(output_path, exist_ok=True)
    jobs = [(wav, npy, fs) for wav, npy in zip(files, outfiles)]

    process_map(process_one, jobs, chunksize=32)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
if __name__ == "__main__":
    # Expose this module's functions (e.g. ``preprocess``) as a CLI via fire.
    import fire

    fire.Fire()
fire.Fire()
|
mvsepless/models/bandit/core/data/musdb/datamodule.py
CHANGED
|
@@ -1,75 +1,75 @@
|
|
| 1 |
-
import os.path
|
| 2 |
-
from typing import Mapping, Optional
|
| 3 |
-
|
| 4 |
-
import pytorch_lightning as pl
|
| 5 |
-
|
| 6 |
-
from .dataset import (
|
| 7 |
-
MUSDB18BaseDataset,
|
| 8 |
-
MUSDB18FullTrackDataset,
|
| 9 |
-
MUSDB18SadDataset,
|
| 10 |
-
MUSDB18SadOnTheFlyAugmentedDataset,
|
| 11 |
-
)
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def MUSDB18DataModule(
|
| 15 |
-
data_root: str = "$DATA_ROOT/MUSDB18/HQ",
|
| 16 |
-
target_stem: str = "vocals",
|
| 17 |
-
batch_size: int = 2,
|
| 18 |
-
num_workers: int = 8,
|
| 19 |
-
train_kwargs: Optional[Mapping] = None,
|
| 20 |
-
val_kwargs: Optional[Mapping] = None,
|
| 21 |
-
test_kwargs: Optional[Mapping] = None,
|
| 22 |
-
datamodule_kwargs: Optional[Mapping] = None,
|
| 23 |
-
use_on_the_fly: bool = True,
|
| 24 |
-
npy_memmap: bool = True,
|
| 25 |
-
) -> pl.LightningDataModule:
|
| 26 |
-
if train_kwargs is None:
|
| 27 |
-
train_kwargs = {}
|
| 28 |
-
|
| 29 |
-
if val_kwargs is None:
|
| 30 |
-
val_kwargs = {}
|
| 31 |
-
|
| 32 |
-
if test_kwargs is None:
|
| 33 |
-
test_kwargs = {}
|
| 34 |
-
|
| 35 |
-
if datamodule_kwargs is None:
|
| 36 |
-
datamodule_kwargs = {}
|
| 37 |
-
|
| 38 |
-
train_dataset: MUSDB18BaseDataset
|
| 39 |
-
|
| 40 |
-
if use_on_the_fly:
|
| 41 |
-
train_dataset = MUSDB18SadOnTheFlyAugmentedDataset(
|
| 42 |
-
data_root=os.path.join(data_root, "saded-np"),
|
| 43 |
-
split="train",
|
| 44 |
-
target_stem=target_stem,
|
| 45 |
-
**train_kwargs
|
| 46 |
-
)
|
| 47 |
-
else:
|
| 48 |
-
train_dataset = MUSDB18SadDataset(
|
| 49 |
-
data_root=os.path.join(data_root, "saded-np"),
|
| 50 |
-
split="train",
|
| 51 |
-
target_stem=target_stem,
|
| 52 |
-
**train_kwargs
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
datamodule = pl.LightningDataModule.from_datasets(
|
| 56 |
-
train_dataset=train_dataset,
|
| 57 |
-
val_dataset=MUSDB18SadDataset(
|
| 58 |
-
data_root=os.path.join(data_root, "saded-np"),
|
| 59 |
-
split="val",
|
| 60 |
-
target_stem=target_stem,
|
| 61 |
-
**val_kwargs
|
| 62 |
-
),
|
| 63 |
-
test_dataset=MUSDB18FullTrackDataset(
|
| 64 |
-
data_root=os.path.join(data_root, "canonical"), split="test", **test_kwargs
|
| 65 |
-
),
|
| 66 |
-
batch_size=batch_size,
|
| 67 |
-
num_workers=num_workers,
|
| 68 |
-
**datamodule_kwargs
|
| 69 |
-
)
|
| 70 |
-
|
| 71 |
-
datamodule.predict_dataloader = ( # type: ignore[method-assign]
|
| 72 |
-
datamodule.test_dataloader
|
| 73 |
-
)
|
| 74 |
-
|
| 75 |
-
return datamodule
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
from typing import Mapping, Optional
|
| 3 |
+
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
|
| 6 |
+
from .dataset import (
|
| 7 |
+
MUSDB18BaseDataset,
|
| 8 |
+
MUSDB18FullTrackDataset,
|
| 9 |
+
MUSDB18SadDataset,
|
| 10 |
+
MUSDB18SadOnTheFlyAugmentedDataset,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def MUSDB18DataModule(
    data_root: str = "$DATA_ROOT/MUSDB18/HQ",
    target_stem: str = "vocals",
    batch_size: int = 2,
    num_workers: int = 8,
    train_kwargs: Optional[Mapping] = None,
    val_kwargs: Optional[Mapping] = None,
    test_kwargs: Optional[Mapping] = None,
    datamodule_kwargs: Optional[Mapping] = None,
    use_on_the_fly: bool = True,
    npy_memmap: bool = True,
) -> pl.LightningDataModule:
    """Build a LightningDataModule over MUSDB18-HQ for one target stem.

    Train/val use the salient-activity ("saded-np") chunked datasets; test
    uses full canonical tracks. ``predict_dataloader`` is aliased to the test
    dataloader.

    NOTE(review): ``npy_memmap`` is accepted but never forwarded to any
    dataset below, so it currently has no effect — confirm whether it should
    be passed through to the MUSDB18SadDataset constructors.
    """
    if train_kwargs is None:
        train_kwargs = {}

    if val_kwargs is None:
        val_kwargs = {}

    if test_kwargs is None:
        test_kwargs = {}

    if datamodule_kwargs is None:
        datamodule_kwargs = {}

    train_dataset: MUSDB18BaseDataset

    # On-the-fly augmentation is optional for training only.
    if use_on_the_fly:
        train_dataset = MUSDB18SadOnTheFlyAugmentedDataset(
            data_root=os.path.join(data_root, "saded-np"),
            split="train",
            target_stem=target_stem,
            **train_kwargs,
        )
    else:
        train_dataset = MUSDB18SadDataset(
            data_root=os.path.join(data_root, "saded-np"),
            split="train",
            target_stem=target_stem,
            **train_kwargs,
        )

    datamodule = pl.LightningDataModule.from_datasets(
        train_dataset=train_dataset,
        val_dataset=MUSDB18SadDataset(
            data_root=os.path.join(data_root, "saded-np"),
            split="val",
            target_stem=target_stem,
            **val_kwargs,
        ),
        test_dataset=MUSDB18FullTrackDataset(
            data_root=os.path.join(data_root, "canonical"), split="test", **test_kwargs
        ),
        batch_size=batch_size,
        num_workers=num_workers,
        **datamodule_kwargs,
    )

    # Predictions run over the same full-track data as testing.
    datamodule.predict_dataloader = (  # type: ignore[method-assign]
        datamodule.test_dataloader
    )

    return datamodule
|
mvsepless/models/bandit/core/data/musdb/dataset.py
CHANGED
|
@@ -1,273 +1,241 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from abc import ABC
|
| 3 |
-
from typing import List, Optional, Tuple
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import torch
|
| 7 |
-
import torchaudio as ta
|
| 8 |
-
from torch.utils import data
|
| 9 |
-
|
| 10 |
-
from .._types import AudioDict, DataDict
|
| 11 |
-
from ..base import BaseSourceSeparationDataset
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC):
|
| 15 |
-
|
| 16 |
-
ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"]
|
| 17 |
-
|
| 18 |
-
def __init__(
|
| 19 |
-
self,
|
| 20 |
-
split: str,
|
| 21 |
-
stems: List[str],
|
| 22 |
-
files: List[str],
|
| 23 |
-
data_path: str,
|
| 24 |
-
fs: int = 44100,
|
| 25 |
-
npy_memmap=False,
|
| 26 |
-
) -> None:
|
| 27 |
-
super().__init__(
|
| 28 |
-
split=split,
|
| 29 |
-
stems=stems,
|
| 30 |
-
files=files,
|
| 31 |
-
data_path=data_path,
|
| 32 |
-
fs=fs,
|
| 33 |
-
npy_memmap=npy_memmap,
|
| 34 |
-
recompute_mixture=False,
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
def get_stem(self, *, stem: str, identifier) -> torch.Tensor:
|
| 38 |
-
track = identifier["track"]
|
| 39 |
-
path = os.path.join(self.data_path, track)
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
"
|
| 65 |
-
"
|
| 66 |
-
"
|
| 67 |
-
"
|
| 68 |
-
"
|
| 69 |
-
"
|
| 70 |
-
"
|
| 71 |
-
"
|
| 72 |
-
"
|
| 73 |
-
"
|
| 74 |
-
"
|
| 75 |
-
"
|
| 76 |
-
"
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
files =
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
self.
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
self.
|
| 185 |
-
self.
|
| 186 |
-
self.
|
| 187 |
-
|
| 188 |
-
self.
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
)
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
#
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
audio
|
| 242 |
-
|
| 243 |
-
if self.rescale:
|
| 244 |
-
max_abs_val = max(
|
| 245 |
-
[torch.max(torch.abs(audio[stem])) for stem in self.stems]
|
| 246 |
-
) # type: ignore[type-var]
|
| 247 |
-
if max_abs_val > 1:
|
| 248 |
-
audio = {k: v / max_abs_val for k, v in audio.items()}
|
| 249 |
-
|
| 250 |
-
track = identifier["track"]
|
| 251 |
-
|
| 252 |
-
return {"audio": audio, "track": f"{self.split}/{track}"}
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
# if __name__ == "__main__":
|
| 256 |
-
#
|
| 257 |
-
# from pprint import pprint
|
| 258 |
-
# from tqdm.auto import tqdm
|
| 259 |
-
#
|
| 260 |
-
# for split_ in ["train", "val", "test"]:
|
| 261 |
-
# ds = MUSDB18SadOnTheFlyAugmentedDataset(
|
| 262 |
-
# data_root="$DATA_ROOT/MUSDB18/HQ/saded",
|
| 263 |
-
# split=split_,
|
| 264 |
-
# target_stem="vocals"
|
| 265 |
-
# )
|
| 266 |
-
#
|
| 267 |
-
# print(split_, len(ds))
|
| 268 |
-
#
|
| 269 |
-
# for track_ in tqdm(ds):
|
| 270 |
-
# track_["audio"] = {
|
| 271 |
-
# k: v.shape for k, v in track_["audio"].items()
|
| 272 |
-
# }
|
| 273 |
-
# pprint(track_)
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from abc import ABC
|
| 3 |
+
from typing import List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torchaudio as ta
|
| 8 |
+
from torch.utils import data
|
| 9 |
+
|
| 10 |
+
from .._types import AudioDict, DataDict
|
| 11 |
+
from ..base import BaseSourceSeparationDataset
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC):
    """Common loading logic shared by the MUSDB18-HQ dataset variants."""

    ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"]

    def __init__(
        self,
        split: str,
        stems: List[str],
        files: List[str],
        data_path: str,
        fs: int = 44100,
        npy_memmap=False,
    ) -> None:
        """Forward configuration to the base; MUSDB ships a precomputed
        mixture stem, so it is never recomputed."""
        super().__init__(
            split=split,
            stems=stems,
            files=files,
            data_path=data_path,
            fs=fs,
            npy_memmap=npy_memmap,
            recompute_mixture=False,
        )

    def get_stem(self, *, stem: str, identifier) -> torch.Tensor:
        """Read one stem of one track, from a .npy memmap or a .wav file."""
        track_dir = os.path.join(self.data_path, identifier["track"])

        if not self.npy_memmap:
            waveform, _ = ta.load(os.path.join(track_dir, f"{stem}.wav"))
            return waveform

        # Lazy read-only memmap of the cached waveform.
        return np.load(os.path.join(track_dir, f"{stem}.wav.npy"), mmap_mode="r")

    def get_identifier(self, index):
        # Tracks are identified by their directory name.
        return {"track": self.files[index]}

    def __getitem__(self, index: int) -> DataDict:
        identifier = self.get_identifier(index)

        return {
            "audio": self.get_audio(identifier),
            "track": f"{self.split}/{identifier['track']}",
        }
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class MUSDB18FullTrackDataset(MUSDB18BaseDataset):
    """Whole-track MUSDB18-HQ dataset with the conventional 86/14
    train/validation carve-out of the on-disk "train" subset."""

    N_TRAIN_TRACKS = 100
    N_TEST_TRACKS = 50
    # The 14 canonical validation tracks (musdb convention).
    VALIDATION_FILES = [
        "Actions - One Minute Smile",
        "Clara Berry And Wooldog - Waltz For My Victims",
        "Johnny Lokke - Promises & Lies",
        "Patrick Talbot - A Reason To Leave",
        "Triviul - Angelsaint",
        "Alexander Ross - Goodbye Bolero",
        "Fergessen - Nos Palpitants",
        "Leaf - Summerghost",
        "Skelpolu - Human Mistakes",
        "Young Griffo - Pennies",
        "ANiMAL - Rockshow",
        "James May - On The Line",
        "Meaxic - Take A Step",
        "Traffic Experiment - Sirens",
    ]

    def __init__(
        self, data_root: str, split: str, stems: Optional[List[str]] = None
    ) -> None:
        """Index the full tracks of ``split`` ("train"/"val"/"test") under
        ``data_root``.

        Raises:
            NameError: if ``split`` is not one of the known split names
                (same exception type as before, now with a message).
        """
        if stems is None:
            stems = self.ALLOWED_STEMS
        self.stems = stems

        # "train" and "val" are both carved out of the on-disk "train" subset.
        if split == "test":
            subset = "test"
        elif split in ["train", "val"]:
            subset = "train"
        else:
            raise NameError(f"Unknown split: {split!r}")

        data_path = os.path.join(data_root, subset)

        files = sorted(os.listdir(data_path))
        files = [f for f in files if not f.startswith(".")]
        if subset == "train":
            # Use the class constants instead of repeating magic numbers.
            assert len(files) == self.N_TRAIN_TRACKS, len(files)
            if split == "train":
                files = [f for f in files if f not in self.VALIDATION_FILES]
                assert len(files) == self.N_TRAIN_TRACKS - len(self.VALIDATION_FILES)
            else:
                files = [f for f in files if f in self.VALIDATION_FILES]
                assert len(files) == len(self.VALIDATION_FILES)
        else:
            # subset == "test" implies split == "test" already, so the former
            # no-op reassignment of ``split`` was dropped.
            assert len(files) == self.N_TEST_TRACKS

        self.n_tracks = len(files)

        super().__init__(data_path=data_path, split=split, stems=stems, files=files)

    def __len__(self) -> int:
        return self.n_tracks
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class MUSDB18SadDataset(MUSDB18BaseDataset):
    """Dataset over pre-extracted salient ("SAD") MUSDB18 segments.

    Segments live under ``<data_root>/<target_stem>/<split>``.  The dataset
    can be virtually lengthened via *target_length*; indexing wraps around
    the actual number of segments.
    """

    def __init__(
        self,
        data_root: str,
        split: str,
        target_stem: str,
        stems: Optional[List[str]] = None,
        target_length: Optional[int] = None,
        npy_memmap=False,
    ) -> None:
        if stems is None:
            stems = self.ALLOWED_STEMS

        segment_dir = os.path.join(data_root, target_stem, split)

        # Hidden files (e.g. .DS_Store) are not segments.
        segment_names = [
            name
            for name in sorted(os.listdir(segment_dir))
            if not name.startswith(".")
        ]

        super().__init__(
            data_path=segment_dir,
            split=split,
            stems=stems,
            files=segment_names,
            npy_memmap=npy_memmap,
        )
        self.n_segments = len(segment_names)
        self.target_stem = target_stem
        if target_length is None:
            self.target_length = self.n_segments
        else:
            self.target_length = target_length

    def __len__(self) -> int:
        return self.target_length

    def __getitem__(self, index: int) -> DataDict:
        # Indices beyond the real segment count wrap around.
        return super().__getitem__(index % self.n_segments)

    def get_identifier(self, index):
        return super().get_identifier(index % self.n_segments)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset):
    """SAD-segment dataset with on-the-fly incoherent mixing augmentation.

    For each non-target stem, with probability ``apply_probability`` the stem
    is swapped for the same stem from a random other segment (incoherent mix);
    a random sub-chunk of every stem is then rescaled by a random gain (or
    dropped entirely), and the mixture is recomputed from the augmented stems.
    """

    def __init__(
        self,
        data_root: str,
        split: str,
        target_stem: str,
        stems: Optional[List[str]] = None,
        target_length: int = 20000,
        apply_probability: Optional[float] = None,
        chunk_size_second: float = 3.0,
        random_scale_range_db: Tuple[float, float] = (-10, 10),
        drop_probability: float = 0.1,
        rescale: bool = True,
    ) -> None:
        """
        Args:
            data_root: Root of the SAD-segment tree.
            split: Dataset split name.
            target_stem: Stem whose segment identity is always preserved.
            stems: Stems to load; defaults to ``ALLOWED_STEMS``.
            target_length: Virtual epoch length (indices wrap).
            apply_probability: Chance of swapping a non-target stem; defaults
                to the fraction of virtual samples beyond the real ones.
            chunk_size_second: Length of the randomly rescaled sub-chunk.
            random_scale_range_db: Uniform gain range (dB) for the sub-chunk.
            drop_probability: Chance of zeroing the sub-chunk instead.
            rescale: Normalize all stems when the peak exceeds 1.
        """
        super().__init__(data_root, split, target_stem, stems)

        if apply_probability is None:
            apply_probability = (target_length - self.n_segments) / target_length

        self.apply_probability = apply_probability
        self.drop_probability = drop_probability
        self.chunk_size_second = chunk_size_second
        self.random_scale_range_db = random_scale_range_db
        self.rescale = rescale

        self.chunk_size_sample = int(self.chunk_size_second * self.fs)
        self.target_length = target_length

    def __len__(self) -> int:
        return self.target_length

    def __getitem__(self, index: int) -> DataDict:
        """Return one augmented sample; *index* wraps modulo the segment count."""
        index = index % self.n_segments

        audio = {}
        identifier = self.get_identifier(index)

        for stem in self.stems_no_mixture:
            if stem == self.target_stem:
                # The target stem always comes from the indexed segment.
                identifier_ = identifier
            else:
                if np.random.rand() < self.apply_probability:
                    # Incoherent augmentation: pull this stem from a random
                    # other segment.
                    index_ = np.random.randint(self.n_segments)
                    identifier_ = self.get_identifier(index_)
                else:
                    identifier_ = identifier

            audio[stem] = self.get_stem(stem=stem, identifier=identifier_)

            # Pick a random chunk (the whole signal if it's shorter).
            if self.chunk_size_sample < audio[stem].shape[-1]:
                chunk_start = np.random.randint(
                    audio[stem].shape[-1] - self.chunk_size_sample
                )
            else:
                chunk_start = 0

            # Either silence the chunk, or apply a random dB gain to it.
            if np.random.rand() < self.drop_probability:
                linear_scale = 0.0
            else:
                db_scale = np.random.uniform(*self.random_scale_range_db)
                linear_scale = np.power(10, db_scale / 20)
            audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample] = (
                linear_scale
                * audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample]
            )

        # Mixture must be rebuilt since the stems were altered.
        audio["mixture"] = self.compute_mixture(audio)

        if self.rescale:
            # Peak-normalize all stems together so relative levels are kept.
            max_abs_val = max(
                [torch.max(torch.abs(audio[stem])) for stem in self.stems]
            )  # type: ignore[type-var]
            if max_abs_val > 1:
                audio = {k: v / max_abs_val for k, v in audio.items()}

        track = identifier["track"]

        return {"audio": audio, "track": f"{self.split}/{track}"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mvsepless/models/bandit/core/data/musdb/preprocess.py
CHANGED
|
@@ -1,226 +1,223 @@
|
|
| 1 |
-
import glob
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
import numpy as np
|
| 5 |
-
import torch
|
| 6 |
-
import torchaudio as ta
|
| 7 |
-
from torch import nn
|
| 8 |
-
from torch.nn import functional as F
|
| 9 |
-
from tqdm.contrib.concurrent import process_map
|
| 10 |
-
|
| 11 |
-
from .._types import DataDict
|
| 12 |
-
from .dataset import MUSDB18FullTrackDataset
|
| 13 |
-
import pyloudnorm as pyln
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class SourceActivityDetector(nn.Module):
|
| 17 |
-
def __init__(
|
| 18 |
-
self,
|
| 19 |
-
analysis_stem: str,
|
| 20 |
-
output_path: str,
|
| 21 |
-
fs: int = 44100,
|
| 22 |
-
segment_length_second: float = 6.0,
|
| 23 |
-
hop_length_second: float = 3.0,
|
| 24 |
-
n_chunks: int = 10,
|
| 25 |
-
chunk_epsilon: float = 1e-5,
|
| 26 |
-
energy_threshold_quantile: float = 0.15,
|
| 27 |
-
segment_epsilon: float = 1e-3,
|
| 28 |
-
salient_proportion_threshold: float = 0.5,
|
| 29 |
-
target_lufs: float = -24,
|
| 30 |
-
) -> None:
|
| 31 |
-
super().__init__()
|
| 32 |
-
|
| 33 |
-
self.fs = fs
|
| 34 |
-
self.segment_length = int(segment_length_second * self.fs)
|
| 35 |
-
self.hop_length = int(hop_length_second * self.fs)
|
| 36 |
-
self.n_chunks = n_chunks
|
| 37 |
-
assert self.segment_length % self.n_chunks == 0
|
| 38 |
-
self.chunk_size = self.segment_length // self.n_chunks
|
| 39 |
-
self.chunk_epsilon = chunk_epsilon
|
| 40 |
-
self.energy_threshold_quantile = energy_threshold_quantile
|
| 41 |
-
self.segment_epsilon = segment_epsilon
|
| 42 |
-
self.salient_proportion_threshold = salient_proportion_threshold
|
| 43 |
-
self.analysis_stem = analysis_stem
|
| 44 |
-
|
| 45 |
-
self.meter = pyln.Meter(self.fs)
|
| 46 |
-
self.target_lufs = target_lufs
|
| 47 |
-
|
| 48 |
-
self.output_path = output_path
|
| 49 |
-
|
| 50 |
-
def forward(self, data: DataDict) -> None:
|
| 51 |
-
|
| 52 |
-
stem_ = self.analysis_stem if (self.analysis_stem != "none") else "mixture"
|
| 53 |
-
|
| 54 |
-
x = data["audio"][stem_]
|
| 55 |
-
|
| 56 |
-
xnp = x.numpy()
|
| 57 |
-
loudness = self.meter.integrated_loudness(xnp.T)
|
| 58 |
-
|
| 59 |
-
for stem in data["audio"]:
|
| 60 |
-
s = data["audio"][stem]
|
| 61 |
-
s = pyln.normalize.loudness(s.numpy().T, loudness, self.target_lufs).T
|
| 62 |
-
s = torch.as_tensor(s)
|
| 63 |
-
data["audio"][stem] = s
|
| 64 |
-
|
| 65 |
-
if x.ndim == 3:
|
| 66 |
-
assert x.shape[0] == 1
|
| 67 |
-
x = x[0]
|
| 68 |
-
|
| 69 |
-
n_chan, n_samples = x.shape
|
| 70 |
-
|
| 71 |
-
n_segments = (
|
| 72 |
-
int(np.ceil((n_samples - self.segment_length) / self.hop_length)) + 1
|
| 73 |
-
)
|
| 74 |
-
|
| 75 |
-
segments = torch.zeros((n_segments, n_chan, self.segment_length))
|
| 76 |
-
for i in range(n_segments):
|
| 77 |
-
start = i * self.hop_length
|
| 78 |
-
end = start + self.segment_length
|
| 79 |
-
end = min(end, n_samples)
|
| 80 |
-
|
| 81 |
-
xseg = x[:, start:end]
|
| 82 |
-
|
| 83 |
-
if end - start < self.segment_length:
|
| 84 |
-
xseg = F.pad(
|
| 85 |
-
xseg, pad=(0, self.segment_length - (end - start)), value=torch.nan
|
| 86 |
-
)
|
| 87 |
-
|
| 88 |
-
segments[i, :, :] = xseg
|
| 89 |
-
|
| 90 |
-
chunks = segments.reshape((n_segments, n_chan, self.n_chunks, self.chunk_size))
|
| 91 |
-
|
| 92 |
-
if self.analysis_stem != "none":
|
| 93 |
-
chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3))
|
| 94 |
-
chunk_energies = torch.nan_to_num(chunk_energies, nan=0)
|
| 95 |
-
chunk_energies[chunk_energies == 0] = self.chunk_epsilon
|
| 96 |
-
|
| 97 |
-
energy_threshold = torch.nanquantile(
|
| 98 |
-
chunk_energies, q=self.energy_threshold_quantile
|
| 99 |
-
)
|
| 100 |
-
|
| 101 |
-
if energy_threshold < self.segment_epsilon:
|
| 102 |
-
energy_threshold = self.segment_epsilon # type: ignore[assignment]
|
| 103 |
-
|
| 104 |
-
chunks_above_threshold = chunk_energies > energy_threshold
|
| 105 |
-
n_chunks_above_threshold = torch.mean(
|
| 106 |
-
chunks_above_threshold.to(torch.float), dim=-1
|
| 107 |
-
)
|
| 108 |
-
|
| 109 |
-
segment_above_threshold = (
|
| 110 |
-
n_chunks_above_threshold > self.salient_proportion_threshold
|
| 111 |
-
)
|
| 112 |
-
|
| 113 |
-
if torch.sum(segment_above_threshold) == 0:
|
| 114 |
-
return
|
| 115 |
-
|
| 116 |
-
else:
|
| 117 |
-
segment_above_threshold = torch.ones((n_segments,))
|
| 118 |
-
|
| 119 |
-
for i in range(n_segments):
|
| 120 |
-
if not segment_above_threshold[i]:
|
| 121 |
-
continue
|
| 122 |
-
|
| 123 |
-
outpath = os.path.join(
|
| 124 |
-
self.output_path,
|
| 125 |
-
self.analysis_stem,
|
| 126 |
-
f"{data['track']} - {self.analysis_stem}{i:03d}",
|
| 127 |
-
)
|
| 128 |
-
os.makedirs(outpath, exist_ok=True)
|
| 129 |
-
|
| 130 |
-
for stem in data["audio"]:
|
| 131 |
-
if stem == self.analysis_stem:
|
| 132 |
-
segment = torch.nan_to_num(segments[i, :, :], nan=0)
|
| 133 |
-
else:
|
| 134 |
-
start = i * self.hop_length
|
| 135 |
-
end = start + self.segment_length
|
| 136 |
-
end = min(n_samples, end)
|
| 137 |
-
|
| 138 |
-
segment = data["audio"][stem][:, start:end]
|
| 139 |
-
|
| 140 |
-
if end - start < self.segment_length:
|
| 141 |
-
segment = F.pad(
|
| 142 |
-
segment, (0, self.segment_length - (end - start))
|
| 143 |
-
)
|
| 144 |
-
|
| 145 |
-
assert segment.shape[-1] == self.segment_length, segment.shape
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
)
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
import fire
|
| 225 |
-
|
| 226 |
-
fire.Fire()
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torchaudio as ta
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch.nn import functional as F
|
| 9 |
+
from tqdm.contrib.concurrent import process_map
|
| 10 |
+
|
| 11 |
+
from .._types import DataDict
|
| 12 |
+
from .dataset import MUSDB18FullTrackDataset
|
| 13 |
+
import pyloudnorm as pyln
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SourceActivityDetector(nn.Module):
    """Slice full tracks into fixed-length segments and save those in which
    the analysis stem is "salient" (active).

    A segment is divided into ``n_chunks`` chunks; a chunk counts as active
    when its mean energy exceeds a quantile-derived threshold, and a segment
    is kept when more than ``salient_proportion_threshold`` of its chunks are
    active.  With ``analysis_stem == "none"`` every segment is kept.  All
    stems are loudness-normalized (relative to the analysis stem's measured
    loudness) before slicing.
    """

    def __init__(
        self,
        analysis_stem: str,
        output_path: str,
        fs: int = 44100,
        segment_length_second: float = 6.0,
        hop_length_second: float = 3.0,
        n_chunks: int = 10,
        chunk_epsilon: float = 1e-5,
        energy_threshold_quantile: float = 0.15,
        segment_epsilon: float = 1e-3,
        salient_proportion_threshold: float = 0.5,
        target_lufs: float = -24,
    ) -> None:
        """
        Args:
            analysis_stem: Stem whose activity is analyzed ("none" = keep all).
            output_path: Root directory the segment folders are written to.
            fs: Sample rate used to convert seconds to samples.
            segment_length_second: Segment window length (seconds).
            hop_length_second: Hop between segment starts (seconds).
            n_chunks: Chunks per segment for the activity test.
            chunk_epsilon: Energy substituted for exactly-zero chunks.
            energy_threshold_quantile: Quantile defining the energy threshold.
            segment_epsilon: Lower bound on the energy threshold.
            salient_proportion_threshold: Minimum active-chunk fraction.
            target_lufs: Loudness target for normalization.
        """
        super().__init__()

        self.fs = fs
        self.segment_length = int(segment_length_second * self.fs)
        self.hop_length = int(hop_length_second * self.fs)
        self.n_chunks = n_chunks
        # Segments must divide evenly into chunks.
        assert self.segment_length % self.n_chunks == 0
        self.chunk_size = self.segment_length // self.n_chunks
        self.chunk_epsilon = chunk_epsilon
        self.energy_threshold_quantile = energy_threshold_quantile
        self.segment_epsilon = segment_epsilon
        self.salient_proportion_threshold = salient_proportion_threshold
        self.analysis_stem = analysis_stem

        self.meter = pyln.Meter(self.fs)
        self.target_lufs = target_lufs

        self.output_path = output_path

    def forward(self, data: DataDict) -> None:
        """Analyze one track dict and write its salient segments to disk."""
        # Fall back to the mixture when no specific stem is analyzed.
        stem_ = self.analysis_stem if (self.analysis_stem != "none") else "mixture"

        x = data["audio"][stem_]

        # Measure loudness on the analysis stem, then apply the SAME gain to
        # every stem so their relative levels are preserved.
        xnp = x.numpy()
        loudness = self.meter.integrated_loudness(xnp.T)

        for stem in data["audio"]:
            s = data["audio"][stem]
            s = pyln.normalize.loudness(s.numpy().T, loudness, self.target_lufs).T
            s = torch.as_tensor(s)
            data["audio"][stem] = s

        # Drop a singleton batch dimension if present.
        if x.ndim == 3:
            assert x.shape[0] == 1
            x = x[0]

        n_chan, n_samples = x.shape

        n_segments = (
            int(np.ceil((n_samples - self.segment_length) / self.hop_length)) + 1
        )

        # Slice the analysis signal into overlapping segments; the ragged
        # tail is padded with NaN so it can be excluded from energy stats.
        segments = torch.zeros((n_segments, n_chan, self.segment_length))
        for i in range(n_segments):
            start = i * self.hop_length
            end = start + self.segment_length
            end = min(end, n_samples)

            xseg = x[:, start:end]

            if end - start < self.segment_length:
                xseg = F.pad(
                    xseg, pad=(0, self.segment_length - (end - start)), value=torch.nan
                )

            segments[i, :, :] = xseg

        chunks = segments.reshape((n_segments, n_chan, self.n_chunks, self.chunk_size))

        if self.analysis_stem != "none":
            # Mean energy per (segment, chunk), averaged over channels/time.
            chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3))
            chunk_energies = torch.nan_to_num(chunk_energies, nan=0)
            # Exactly-silent chunks get a tiny floor so quantiles behave.
            chunk_energies[chunk_energies == 0] = self.chunk_epsilon

            energy_threshold = torch.nanquantile(
                chunk_energies, q=self.energy_threshold_quantile
            )

            if energy_threshold < self.segment_epsilon:
                energy_threshold = self.segment_epsilon  # type: ignore[assignment]

            chunks_above_threshold = chunk_energies > energy_threshold
            n_chunks_above_threshold = torch.mean(
                chunks_above_threshold.to(torch.float), dim=-1
            )

            segment_above_threshold = (
                n_chunks_above_threshold > self.salient_proportion_threshold
            )

            # Nothing salient in this track: write nothing.
            if torch.sum(segment_above_threshold) == 0:
                return

        else:
            # "none" mode: every segment is considered salient.
            segment_above_threshold = torch.ones((n_segments,))

        for i in range(n_segments):
            if not segment_above_threshold[i]:
                continue

            outpath = os.path.join(
                self.output_path,
                self.analysis_stem,
                f"{data['track']} - {self.analysis_stem}{i:03d}",
            )
            os.makedirs(outpath, exist_ok=True)

            for stem in data["audio"]:
                if stem == self.analysis_stem:
                    # Reuse the already-sliced segment, zeroing the NaN pad.
                    segment = torch.nan_to_num(segments[i, :, :], nan=0)
                else:
                    # Other stems are sliced on demand with zero padding.
                    start = i * self.hop_length
                    end = start + self.segment_length
                    end = min(n_samples, end)

                    segment = data["audio"][stem][:, start:end]

                    if end - start < self.segment_length:
                        segment = F.pad(
                            segment, (0, self.segment_length - (end - start))
                        )

                assert segment.shape[-1] == self.segment_length, segment.shape

                # np.save appends ".npy", producing "<stem>.wav.npy" — the
                # exact name the SAD dataset's memmap loader expects.
                np.save(os.path.join(outpath, f"{stem}.wav"), segment)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def preprocess(
    analysis_stem: str,
    output_path: str = "/data/MUSDB18/HQ/saded-np",
    fs: int = 44100,
    segment_length_second: float = 6.0,
    hop_length_second: float = 3.0,
    n_chunks: int = 10,
    chunk_epsilon: float = 1e-5,
    energy_threshold_quantile: float = 0.15,
    segment_epsilon: float = 1e-3,
    salient_proportion_threshold: float = 0.5,
) -> None:
    """Run source-activity detection over all MUSDB18 splits.

    Builds a :class:`SourceActivityDetector` for *analysis_stem* and applies
    it to every track of the train/val/test full-track datasets, processing
    tracks in parallel batches of 32 via ``process_map``.

    Args:
        analysis_stem: Stem whose activity is analyzed ("none" keeps all).
        output_path: Root directory the salient segments are written to.
        fs, segment_length_second, hop_length_second, n_chunks,
        chunk_epsilon, energy_threshold_quantile, segment_epsilon,
        salient_proportion_threshold: Forwarded to the detector.
    """
    # Imported here so preprocess() also works when the module is imported;
    # the original relied on the tqdm import inside the __main__ guard and
    # raised NameError otherwise.
    from tqdm.auto import tqdm

    sad = SourceActivityDetector(
        analysis_stem=analysis_stem,
        output_path=output_path,
        fs=fs,
        segment_length_second=segment_length_second,
        hop_length_second=hop_length_second,
        n_chunks=n_chunks,
        chunk_epsilon=chunk_epsilon,
        energy_threshold_quantile=energy_threshold_quantile,
        segment_epsilon=segment_epsilon,
        salient_proportion_threshold=salient_proportion_threshold,
    )

    for split in ["train", "val", "test"]:
        # NOTE(review): the dataset root is hard-coded; confirm this matches
        # the deployment layout.
        ds = MUSDB18FullTrackDataset(
            data_root="/data/MUSDB18/HQ/canonical",
            split=split,
        )

        # Process tracks in parallel batches of 32 to bound memory use.
        tracks = []
        for i, track in enumerate(tqdm(ds, total=len(ds))):
            if i % 32 == 0 and tracks:
                process_map(sad, tracks, max_workers=8)
                tracks = []
            tracks.append(track)
        process_map(sad, tracks, max_workers=8)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def loudness_norm_one(inputs):
    """Loudness-normalize one audio file and save the result as ``.npy``.

    Args:
        inputs: ``(infile, outfile, target_lufs)`` tuple — packed into one
            argument so the function can be used with ``process_map``.
    """
    src_path, dst_path, target_lufs = inputs

    waveform, sample_rate = ta.load(src_path)
    # Downmix to mono; transpose to (samples, 1) as pyloudnorm expects.
    mono = waveform.mean(dim=0, keepdim=True).numpy().T

    meter = pyln.Meter(sample_rate)
    measured = meter.integrated_loudness(mono)
    normalized = pyln.normalize.loudness(mono, measured, target_lufs)

    os.makedirs(os.path.dirname(dst_path), exist_ok=True)
    np.save(dst_path, normalized.T)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def loudness_norm(
    data_path: str,
    target_lufs=-17.0,
):
    """Loudness-normalize every ``.wav`` under *data_path* in parallel.

    Output files mirror the input tree with ``.wav`` → ``.npy`` and the
    ``saded`` path component renamed to ``saded-np``.

    Args:
        data_path: Root directory searched recursively for ``.wav`` files.
        target_lufs: Integrated-loudness target passed to each worker.
    """
    wav_paths = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)

    jobs = []
    for wav_path in wav_paths:
        npy_path = wav_path.replace(".wav", ".npy").replace("saded", "saded-np")
        jobs.append((wav_path, npy_path, target_lufs))

    process_map(loudness_norm_one, jobs, chunksize=2)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
if __name__ == "__main__":

    from tqdm.auto import tqdm
    import fire

    # Expose the module's functions as CLI subcommands, e.g.
    #   python preprocess.py preprocess --analysis_stem=vocals
    #   python preprocess.py loudness_norm --data_path=/data/.../saded
    fire.Fire()
|
|
|
|
|
|
|
|
|
mvsepless/models/bandit/core/data/musdb/validation.yaml
CHANGED
|
@@ -1,15 +1,15 @@
|
|
| 1 |
-
validation:
|
| 2 |
-
- 'Actions - One Minute Smile'
|
| 3 |
-
- 'Clara Berry And Wooldog - Waltz For My Victims'
|
| 4 |
-
- 'Johnny Lokke - Promises & Lies'
|
| 5 |
-
- 'Patrick Talbot - A Reason To Leave'
|
| 6 |
-
- 'Triviul - Angelsaint'
|
| 7 |
-
- 'Alexander Ross - Goodbye Bolero'
|
| 8 |
-
- 'Fergessen - Nos Palpitants'
|
| 9 |
-
- 'Leaf - Summerghost'
|
| 10 |
-
- 'Skelpolu - Human Mistakes'
|
| 11 |
-
- 'Young Griffo - Pennies'
|
| 12 |
-
- 'ANiMAL - Rockshow'
|
| 13 |
-
- 'James May - On The Line'
|
| 14 |
-
- 'Meaxic - Take A Step'
|
| 15 |
- 'Traffic Experiment - Sirens'
|
|
|
|
| 1 |
+
validation:
|
| 2 |
+
- 'Actions - One Minute Smile'
|
| 3 |
+
- 'Clara Berry And Wooldog - Waltz For My Victims'
|
| 4 |
+
- 'Johnny Lokke - Promises & Lies'
|
| 5 |
+
- 'Patrick Talbot - A Reason To Leave'
|
| 6 |
+
- 'Triviul - Angelsaint'
|
| 7 |
+
- 'Alexander Ross - Goodbye Bolero'
|
| 8 |
+
- 'Fergessen - Nos Palpitants'
|
| 9 |
+
- 'Leaf - Summerghost'
|
| 10 |
+
- 'Skelpolu - Human Mistakes'
|
| 11 |
+
- 'Young Griffo - Pennies'
|
| 12 |
+
- 'ANiMAL - Rockshow'
|
| 13 |
+
- 'James May - On The Line'
|
| 14 |
+
- 'Meaxic - Take A Step'
|
| 15 |
- 'Traffic Experiment - Sirens'
|
mvsepless/models/bandit/core/loss/__init__.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
-
from ._multistem import MultiStemWrapperFromConfig
|
| 2 |
-
from ._timefreq import (
|
| 3 |
-
ReImL1Loss,
|
| 4 |
-
ReImL2Loss,
|
| 5 |
-
TimeFreqL1Loss,
|
| 6 |
-
TimeFreqL2Loss,
|
| 7 |
-
TimeFreqSignalNoisePNormRatioLoss,
|
| 8 |
-
)
|
|
|
|
| 1 |
+
from ._multistem import MultiStemWrapperFromConfig
|
| 2 |
+
from ._timefreq import (
|
| 3 |
+
ReImL1Loss,
|
| 4 |
+
ReImL2Loss,
|
| 5 |
+
TimeFreqL1Loss,
|
| 6 |
+
TimeFreqL2Loss,
|
| 7 |
+
TimeFreqSignalNoisePNormRatioLoss,
|
| 8 |
+
)
|
mvsepless/models/bandit/core/loss/_complex.py
CHANGED
|
@@ -1,27 +1,27 @@
|
|
| 1 |
-
from typing import Any
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch import nn
|
| 5 |
-
from torch.nn.modules import loss as _loss
|
| 6 |
-
from torch.nn.modules.loss import _Loss
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class ReImLossWrapper(_Loss):
|
| 10 |
-
def __init__(self, module: _Loss) -> None:
|
| 11 |
-
super().__init__()
|
| 12 |
-
self.module = module
|
| 13 |
-
|
| 14 |
-
def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
| 15 |
-
return self.module(torch.view_as_real(preds), torch.view_as_real(target))
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class ReImL1Loss(ReImLossWrapper):
|
| 19 |
-
def __init__(self, **kwargs: Any) -> None:
|
| 20 |
-
l1_loss = _loss.L1Loss(**kwargs)
|
| 21 |
-
super().__init__(module=(l1_loss))
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
class ReImL2Loss(ReImLossWrapper):
|
| 25 |
-
def __init__(self, **kwargs: Any) -> None:
|
| 26 |
-
l2_loss = _loss.MSELoss(**kwargs)
|
| 27 |
-
super().__init__(module=(l2_loss))
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn.modules import loss as _loss
|
| 6 |
+
from torch.nn.modules.loss import _Loss
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ReImLossWrapper(_Loss):
    """Apply a real-valued loss to complex tensors via their (re, im) parts.

    ``torch.view_as_real`` appends a trailing size-2 axis holding the real
    and imaginary components, letting any real-valued loss act on complex
    predictions and targets.
    """

    def __init__(self, module: _Loss) -> None:
        super().__init__()
        self.module = module

    def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        preds_ri = torch.view_as_real(preds)
        target_ri = torch.view_as_real(target)
        return self.module(preds_ri, target_ri)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ReImL1Loss(ReImLossWrapper):
    """L1 loss on the stacked real/imaginary parts of complex tensors."""

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(module=_loss.L1Loss(**kwargs))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ReImL2Loss(ReImLossWrapper):
    """MSE loss on the stacked real/imaginary parts of complex tensors."""

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(module=_loss.MSELoss(**kwargs))
|
mvsepless/models/bandit/core/loss/_multistem.py
CHANGED
|
@@ -1,43 +1,43 @@
|
|
| 1 |
-
from typing import Any, Dict
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from asteroid import losses as asteroid_losses
|
| 5 |
-
from torch import nn
|
| 6 |
-
from torch.nn.modules.loss import _Loss
|
| 7 |
-
|
| 8 |
-
from . import snr
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def parse_loss(name: str, kwargs: Dict[str, Any]) -> _Loss:
|
| 12 |
-
|
| 13 |
-
for module in [nn.modules.loss, snr, asteroid_losses, asteroid_losses.sdr]:
|
| 14 |
-
if name in module.__dict__:
|
| 15 |
-
return module.__dict__[name](**kwargs)
|
| 16 |
-
|
| 17 |
-
raise NameError
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class MultiStemWrapper(_Loss):
|
| 21 |
-
def __init__(self, module: _Loss, modality: str = "audio") -> None:
|
| 22 |
-
super().__init__()
|
| 23 |
-
self.loss = module
|
| 24 |
-
self.modality = modality
|
| 25 |
-
|
| 26 |
-
def forward(
|
| 27 |
-
self,
|
| 28 |
-
preds: Dict[str, Dict[str, torch.Tensor]],
|
| 29 |
-
target: Dict[str, Dict[str, torch.Tensor]],
|
| 30 |
-
) -> torch.Tensor:
|
| 31 |
-
loss = {
|
| 32 |
-
stem: self.loss(preds[self.modality][stem], target[self.modality][stem])
|
| 33 |
-
for stem in preds[self.modality]
|
| 34 |
-
if stem in target[self.modality]
|
| 35 |
-
}
|
| 36 |
-
|
| 37 |
-
return sum(list(loss.values()))
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
class MultiStemWrapperFromConfig(MultiStemWrapper):
|
| 41 |
-
def __init__(self, name: str, kwargs: Any, modality: str = "audio") -> None:
|
| 42 |
-
loss = parse_loss(name, kwargs)
|
| 43 |
-
super().__init__(module=loss, modality=modality)
|
|
|
|
| 1 |
+
from typing import Any, Dict
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from asteroid import losses as asteroid_losses
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn.modules.loss import _Loss
|
| 7 |
+
|
| 8 |
+
from . import snr
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def parse_loss(name: str, kwargs: Dict[str, Any]) -> _Loss:
    """Instantiate a loss class by name.

    Searches, in order: ``torch.nn`` losses, the local :mod:`snr` module,
    asteroid's losses, and asteroid's SDR losses; the first namespace that
    defines *name* wins.

    Args:
        name: Class name of the loss, e.g. ``"L1Loss"``.
        kwargs: Keyword arguments forwarded to the loss constructor.

    Raises:
        NameError: If *name* is not found in any searched namespace.
    """
    for module in [nn.modules.loss, snr, asteroid_losses, asteroid_losses.sdr]:
        if name in module.__dict__:
            return module.__dict__[name](**kwargs)

    # NameError kept (not ValueError/KeyError) so existing handlers still match;
    # the message makes the failure diagnosable from config errors.
    raise NameError(f"Could not find a loss named {name!r}")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class MultiStemWrapper(_Loss):
    """Apply a loss per stem and sum over the stems present in both inputs.

    Stems that appear in the predictions but not in the target (or vice
    versa) are silently skipped.
    """

    def __init__(self, module: _Loss, modality: str = "audio") -> None:
        """
        Args:
            module: Loss applied independently to each stem.
            modality: Top-level key selecting which modality dict to compare
                (e.g. ``"audio"`` or ``"spectrogram"``).
        """
        super().__init__()
        self.loss = module
        self.modality = modality

    def forward(
        self,
        preds: Dict[str, Dict[str, torch.Tensor]],
        target: Dict[str, Dict[str, torch.Tensor]],
    ) -> torch.Tensor:
        loss = {
            stem: self.loss(preds[self.modality][stem], target[self.modality][stem])
            for stem in preds[self.modality]
            if stem in target[self.modality]
        }

        # sum() consumes the values view directly — no intermediate list.
        # NOTE: returns int 0 when no stems are shared between preds/target.
        return sum(loss.values())
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class MultiStemWrapperFromConfig(MultiStemWrapper):
    """MultiStemWrapper whose inner loss is resolved by name via parse_loss."""

    def __init__(self, name: str, kwargs: Any, modality: str = "audio") -> None:
        inner_loss = parse_loss(name, kwargs)
        super().__init__(module=inner_loss, modality=modality)
|
mvsepless/models/bandit/core/loss/_timefreq.py
CHANGED
|
@@ -1,95 +1,94 @@
|
|
| 1 |
-
from typing import Any, Dict, Optional
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch import nn
|
| 5 |
-
from torch.nn.modules.loss import _Loss
|
| 6 |
-
|
| 7 |
-
from ._multistem import MultiStemWrapper
|
| 8 |
-
from ._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper
|
| 9 |
-
from .snr import SignalNoisePNormRatio
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class TimeFreqWrapper(_Loss):
|
| 13 |
-
def __init__(
|
| 14 |
-
self,
|
| 15 |
-
time_module: _Loss,
|
| 16 |
-
freq_module: Optional[_Loss] = None,
|
| 17 |
-
time_weight: float = 1.0,
|
| 18 |
-
freq_weight: float = 1.0,
|
| 19 |
-
multistem: bool = True,
|
| 20 |
-
) -> None:
|
| 21 |
-
super().__init__()
|
| 22 |
-
|
| 23 |
-
if freq_module is None:
|
| 24 |
-
freq_module = time_module
|
| 25 |
-
|
| 26 |
-
if multistem:
|
| 27 |
-
time_module = MultiStemWrapper(time_module, modality="audio")
|
| 28 |
-
freq_module = MultiStemWrapper(freq_module, modality="spectrogram")
|
| 29 |
-
|
| 30 |
-
self.time_module = time_module
|
| 31 |
-
self.freq_module = freq_module
|
| 32 |
-
|
| 33 |
-
self.time_weight = time_weight
|
| 34 |
-
self.freq_weight = freq_weight
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
freq_module
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
freq_module
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
freq_module
|
| 95 |
-
super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
|
|
|
|
| 1 |
+
from typing import Any, Dict, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn.modules.loss import _Loss
|
| 6 |
+
|
| 7 |
+
from ._multistem import MultiStemWrapper
|
| 8 |
+
from ._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper
|
| 9 |
+
from .snr import SignalNoisePNormRatio
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TimeFreqWrapper(_Loss):
    """Weighted sum of a time-domain and a frequency-domain loss.

    With ``multistem=True`` both losses are wrapped in
    :class:`MultiStemWrapper`, comparing the ``"audio"`` and
    ``"spectrogram"`` modality dicts respectively.
    """

    def __init__(
        self,
        time_module: _Loss,
        freq_module: Optional[_Loss] = None,
        time_weight: float = 1.0,
        freq_weight: float = 1.0,
        multistem: bool = True,
    ) -> None:
        """
        Args:
            time_module: Loss on the time-domain signal.
            freq_module: Loss on the spectrogram; defaults to *time_module*.
            time_weight: Weight of the time-domain term.
            freq_weight: Weight of the frequency-domain term.
            multistem: Wrap both losses for per-stem dict inputs.
        """
        super().__init__()

        # Reuse the time-domain loss in the frequency domain by default.
        if freq_module is None:
            freq_module = time_module

        if multistem:
            time_module = MultiStemWrapper(time_module, modality="audio")
            freq_module = MultiStemWrapper(freq_module, modality="spectrogram")

        self.time_module = time_module
        self.freq_module = freq_module
        self.time_weight = time_weight
        self.freq_weight = freq_weight

    def forward(self, preds: Any, target: Any) -> torch.Tensor:
        time_term = self.time_module(preds, target)
        freq_term = self.freq_module(preds, target)
        return self.time_weight * time_term + self.freq_weight * freq_term
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class TimeFreqL1Loss(TimeFreqWrapper):
    """Time/frequency loss pair: L1 in time, real/imag L1 in frequency."""

    def __init__(
        self,
        time_weight: float = 1.0,
        freq_weight: float = 1.0,
        tkwargs: Optional[Dict[str, Any]] = None,
        fkwargs: Optional[Dict[str, Any]] = None,
        multistem: bool = True,
    ) -> None:
        # `x or {}` maps both None and {} to {} — equivalent for dict kwargs.
        time_module = nn.L1Loss(**(tkwargs or {}))
        freq_module = ReImL1Loss(**(fkwargs or {}))
        super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class TimeFreqL2Loss(TimeFreqWrapper):
    """Time/frequency loss pair: MSE in time, real/imag MSE in frequency."""

    def __init__(
        self,
        time_weight: float = 1.0,
        freq_weight: float = 1.0,
        tkwargs: Optional[Dict[str, Any]] = None,
        fkwargs: Optional[Dict[str, Any]] = None,
        multistem: bool = True,
    ) -> None:
        # `x or {}` maps both None and {} to {} — equivalent for dict kwargs.
        time_module = nn.MSELoss(**(tkwargs or {}))
        freq_module = ReImL2Loss(**(fkwargs or {}))
        super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper):
    """Time/frequency loss using SignalNoisePNormRatio in both domains."""

    def __init__(
        self,
        time_weight: float = 1.0,
        freq_weight: float = 1.0,
        tkwargs: Optional[Dict[str, Any]] = None,
        fkwargs: Optional[Dict[str, Any]] = None,
        multistem: bool = True,
    ) -> None:
        # Empty dicts mean "use the criterion's own defaults".
        time_module = SignalNoisePNormRatio(**(tkwargs if tkwargs is not None else {}))
        freq_module = SignalNoisePNormRatio(**(fkwargs if fkwargs is not None else {}))
        super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
|
|
|
mvsepless/models/bandit/core/loss/snr.py
CHANGED
|
@@ -1,139 +1,131 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from torch.nn.modules.loss import _Loss
|
| 3 |
-
from torch.nn import functional as F
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
class SignalNoisePNormRatio(_Loss):
|
| 7 |
-
def __init__(
|
| 8 |
-
self,
|
| 9 |
-
p: float = 1.0,
|
| 10 |
-
scale_invariant: bool = False,
|
| 11 |
-
zero_mean: bool = False,
|
| 12 |
-
take_log: bool = True,
|
| 13 |
-
reduction: str = "mean",
|
| 14 |
-
EPS: float = 1e-3,
|
| 15 |
-
) -> None:
|
| 16 |
-
assert reduction != "sum", NotImplementedError
|
| 17 |
-
super().__init__(reduction=reduction)
|
| 18 |
-
assert not zero_mean
|
| 19 |
-
|
| 20 |
-
self.p = p
|
| 21 |
-
|
| 22 |
-
self.EPS = EPS
|
| 23 |
-
self.take_log = take_log
|
| 24 |
-
|
| 25 |
-
self.scale_invariant = scale_invariant
|
| 26 |
-
|
| 27 |
-
def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
| 28 |
-
|
| 29 |
-
target_ = target
|
| 30 |
-
if self.scale_invariant:
|
| 31 |
-
ndim = target.ndim
|
| 32 |
-
dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True)
|
| 33 |
-
s_target_energy = torch.sum(
|
| 34 |
-
target * torch.conj(target), dim=-1, keepdim=True
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
if ndim > 2:
|
| 38 |
-
dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True)
|
| 39 |
-
s_target_energy = torch.sum(
|
| 40 |
-
s_target_energy, dim=list(range(1, ndim)), keepdim=True
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8)
|
| 44 |
-
target = target_ * target_scaler
|
| 45 |
-
|
| 46 |
-
if torch.is_complex(est_target):
|
| 47 |
-
est_target = torch.view_as_real(est_target)
|
| 48 |
-
target = torch.view_as_real(target)
|
| 49 |
-
|
| 50 |
-
batch_size = est_target.shape[0]
|
| 51 |
-
est_target = est_target.reshape(batch_size, -1)
|
| 52 |
-
target = target.reshape(batch_size, -1)
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
self.
|
| 94 |
-
self.
|
| 95 |
-
self.
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
scaled_target =
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
else:
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
if self.
|
| 129 |
-
losses = torch.
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
else:
|
| 133 |
-
losses = torch.norm(scaled_target, p=self.p, dim=[1, 2]) / (
|
| 134 |
-
torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS
|
| 135 |
-
)
|
| 136 |
-
if self.take_log:
|
| 137 |
-
losses = 10 * torch.log10(losses + self.EPS)
|
| 138 |
-
losses = losses.mean() if self.reduction == "mean" else losses
|
| 139 |
-
return -losses
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.nn.modules.loss import _Loss
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SignalNoisePNormRatio(_Loss):
    """p-norm error-to-signal ratio loss.

    Computes the mean p-th-power error relative to the mean p-th power of
    the target, optionally in dB (``take_log``) and optionally after
    rescaling the target onto the estimate (``scale_invariant``).
    Only p == 1 and p == 2 are implemented; complex inputs are handled by
    viewing real/imaginary parts as extra real features.
    """

    def __init__(
        self,
        p: float = 1.0,
        scale_invariant: bool = False,
        zero_mean: bool = False,
        take_log: bool = True,
        reduction: str = "mean",
        EPS: float = 1e-3,
    ) -> None:
        assert reduction != "sum", NotImplementedError
        super().__init__(reduction=reduction)
        # Mean removal is not supported by this implementation.
        assert not zero_mean

        self.p = p
        self.EPS = EPS
        self.take_log = take_log
        self.scale_invariant = scale_invariant

    def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        raw_target = target

        if self.scale_invariant:
            ndim = target.ndim
            # Project the estimate onto the target to find the optimal scale.
            projection = torch.sum(
                est_target * torch.conj(target), dim=-1, keepdim=True
            )
            target_energy = torch.sum(
                target * torch.conj(target), dim=-1, keepdim=True
            )

            if ndim > 2:
                # Collapse all non-batch dims so one scaler applies per item.
                projection = torch.sum(
                    projection, dim=list(range(1, ndim)), keepdim=True
                )
                target_energy = torch.sum(
                    target_energy, dim=list(range(1, ndim)), keepdim=True
                )

            scaler = (projection + 1e-8) / (target_energy + 1e-8)
            target = raw_target * scaler

        if torch.is_complex(est_target):
            # Real/imag parts become an extra trailing real-valued axis.
            est_target = torch.view_as_real(est_target)
            target = torch.view_as_real(target)

        batch_size = est_target.shape[0]
        est_target = est_target.reshape(batch_size, -1)
        target = target.reshape(batch_size, -1)

        residual = est_target - target
        if self.p == 1:
            error_power = residual.abs().mean(dim=-1)
            target_power = target.abs().mean(dim=-1)
        elif self.p == 2:
            error_power = residual.square().mean(dim=-1)
            target_power = target.square().mean(dim=-1)
        else:
            raise NotImplementedError

        if self.take_log:
            # dB ratio; EPS keeps the logs finite for silent inputs.
            loss = 10 * (
                torch.log10(error_power + self.EPS)
                - torch.log10(target_power + self.EPS)
            )
        else:
            loss = (error_power + self.EPS) / (target_power + self.EPS)

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class MultichannelSingleSrcNegSDR(_Loss):
    """Negative SNR / SI-SDR / SD-SDR loss for one source with channels.

    Inputs are 3-D tensors shaped (batch, channel, time); energy sums run
    over both channel and time, so all channels share one scaling/ratio.

    Args:
        sdr_type: one of "snr", "sisdr", "sdsdr".
        p: norm order for the energy ratio (2.0 uses the squared-sum path).
        zero_mean: subtract the per-item mean before computing the ratio.
        take_log: return the ratio in dB.
        reduction: "mean" or "none" ("sum" is not supported).
        EPS: numerical stabilizer added to denominators.
    """

    def __init__(
        self,
        sdr_type: str,
        p: float = 2.0,
        zero_mean: bool = True,
        take_log: bool = True,
        reduction: str = "mean",
        EPS: float = 1e-8,
    ) -> None:
        assert reduction != "sum", NotImplementedError
        super().__init__(reduction=reduction)

        assert sdr_type in ["snr", "sisdr", "sdsdr"]
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        # Bug fix: was hard-coded to 1e-8, silently ignoring the EPS argument.
        self.EPS = EPS

        self.p = p

    def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        if target.size() != est_target.size() or target.ndim != 3:
            # Message fixed: the check above requires (batch, channel, time).
            raise TypeError(
                f"Inputs must be of shape [batch, channel, time], got {target.size()} and {est_target.size()} instead"
            )
        if self.zero_mean:
            mean_source = torch.mean(target, dim=[1, 2], keepdim=True)
            mean_estimate = torch.mean(est_target, dim=[1, 2], keepdim=True)
            target = target - mean_source
            est_target = est_target - mean_estimate
        if self.sdr_type in ["sisdr", "sdsdr"]:
            # Scale-invariant variants: project the estimate onto the target.
            dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True)
            s_target_energy = torch.sum(target**2, dim=[1, 2], keepdim=True) + self.EPS
            scaled_target = dot * target / s_target_energy
        else:
            scaled_target = target
        if self.sdr_type in ["sdsdr", "snr"]:
            e_noise = est_target - target
        else:
            e_noise = est_target - scaled_target

        if self.p == 2.0:
            # Fast path: ratio of energies.
            losses = torch.sum(scaled_target**2, dim=[1, 2]) / (
                torch.sum(e_noise**2, dim=[1, 2]) + self.EPS
            )
        else:
            losses = torch.norm(scaled_target, p=self.p, dim=[1, 2]) / (
                torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS
            )
        if self.take_log:
            losses = 10 * torch.log10(losses + self.EPS)
        losses = losses.mean() if self.reduction == "mean" else losses
        # Negated: higher SDR means lower loss.
        return -losses
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mvsepless/models/bandit/core/metrics/__init__.py
CHANGED
|
@@ -1,9 +1,7 @@
|
|
| 1 |
-
from .snr import (
|
| 2 |
-
ChunkMedianScaleInvariantSignalDistortionRatio,
|
| 3 |
-
ChunkMedianScaleInvariantSignalNoiseRatio,
|
| 4 |
-
ChunkMedianSignalDistortionRatio,
|
| 5 |
-
ChunkMedianSignalNoiseRatio,
|
| 6 |
-
SafeSignalDistortionRatio,
|
| 7 |
-
)
|
| 8 |
-
|
| 9 |
-
# from .mushra import EstimatedMushraScore
|
|
|
|
| 1 |
+
from .snr import (
|
| 2 |
+
ChunkMedianScaleInvariantSignalDistortionRatio,
|
| 3 |
+
ChunkMedianScaleInvariantSignalNoiseRatio,
|
| 4 |
+
ChunkMedianSignalDistortionRatio,
|
| 5 |
+
ChunkMedianSignalNoiseRatio,
|
| 6 |
+
SafeSignalDistortionRatio,
|
| 7 |
+
)
|
|
|
|
|
|
mvsepless/models/bandit/core/metrics/_squim.py
CHANGED
|
@@ -1,443 +1,350 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
-
|
| 3 |
-
from torchaudio._internal import load_state_dict_from_url
|
| 4 |
-
|
| 5 |
-
import math
|
| 6 |
-
from typing import List, Optional, Tuple
|
| 7 |
-
|
| 8 |
-
import torch
|
| 9 |
-
import torch.nn as nn
|
| 10 |
-
import torch.nn.functional as F
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def transform_wb_pesq_range(x: float) -> float:
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
)
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
)
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
self.
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
out =
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
)
|
| 95 |
-
|
| 96 |
-
self.
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
)
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
out
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
)
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
.
|
| 173 |
-
)
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
out = out
|
| 203 |
-
out =
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
.
|
| 239 |
-
)
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
out = self.
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
)
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
"""
|
| 351 |
-
|
| 352 |
-
Args:
|
| 353 |
-
feat_dim (int, optional): The feature dimension after Encoder module.
|
| 354 |
-
win_len (int): Kernel size in the Encoder module.
|
| 355 |
-
d_model (int): The number of expected features in the input.
|
| 356 |
-
nhead (int): Number of heads in the multi-head attention model.
|
| 357 |
-
hidden_dim (int): Hidden dimension in the RNN layer of DPRNN.
|
| 358 |
-
num_blocks (int): Number of DPRNN layers.
|
| 359 |
-
rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"].
|
| 360 |
-
chunk_size (int): Chunk size of input for DPRNN.
|
| 361 |
-
chunk_stride (int or None, optional): Stride of chunk input for DPRNN.
|
| 362 |
-
"""
|
| 363 |
-
if chunk_stride is None:
|
| 364 |
-
chunk_stride = chunk_size // 2
|
| 365 |
-
encoder = Encoder(feat_dim, win_len)
|
| 366 |
-
dprnn = DPRNN(
|
| 367 |
-
feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride
|
| 368 |
-
)
|
| 369 |
-
branches = nn.ModuleList(
|
| 370 |
-
[
|
| 371 |
-
_create_branch(d_model, nhead, "stoi"),
|
| 372 |
-
_create_branch(d_model, nhead, "pesq"),
|
| 373 |
-
_create_branch(d_model, nhead, "sisdr"),
|
| 374 |
-
]
|
| 375 |
-
)
|
| 376 |
-
return SquimObjective(encoder, dprnn, branches)
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
def squim_objective_base() -> SquimObjective:
|
| 380 |
-
"""Build :class:`torchaudio.prototype.models.SquimObjective` model with default arguments."""
|
| 381 |
-
return squim_objective_model(
|
| 382 |
-
feat_dim=256,
|
| 383 |
-
win_len=64,
|
| 384 |
-
d_model=256,
|
| 385 |
-
nhead=4,
|
| 386 |
-
hidden_dim=256,
|
| 387 |
-
num_blocks=2,
|
| 388 |
-
rnn_type="LSTM",
|
| 389 |
-
chunk_size=71,
|
| 390 |
-
)
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
@dataclass
|
| 394 |
-
class SquimObjectiveBundle:
|
| 395 |
-
|
| 396 |
-
_path: str
|
| 397 |
-
_sample_rate: float
|
| 398 |
-
|
| 399 |
-
def _get_state_dict(self, dl_kwargs):
|
| 400 |
-
url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
|
| 401 |
-
dl_kwargs = {} if dl_kwargs is None else dl_kwargs
|
| 402 |
-
state_dict = load_state_dict_from_url(url, **dl_kwargs)
|
| 403 |
-
return state_dict
|
| 404 |
-
|
| 405 |
-
def get_model(self, *, dl_kwargs=None) -> SquimObjective:
|
| 406 |
-
"""Construct the SquimObjective model, and load the pretrained weight.
|
| 407 |
-
|
| 408 |
-
The weight file is downloaded from the internet and cached with
|
| 409 |
-
:func:`torch.hub.load_state_dict_from_url`
|
| 410 |
-
|
| 411 |
-
Args:
|
| 412 |
-
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
|
| 413 |
-
|
| 414 |
-
Returns:
|
| 415 |
-
Variation of :py:class:`~torchaudio.models.SquimObjective`.
|
| 416 |
-
"""
|
| 417 |
-
model = squim_objective_base()
|
| 418 |
-
model.load_state_dict(self._get_state_dict(dl_kwargs))
|
| 419 |
-
model.eval()
|
| 420 |
-
return model
|
| 421 |
-
|
| 422 |
-
@property
|
| 423 |
-
def sample_rate(self):
|
| 424 |
-
"""Sample rate of the audio that the model is trained on.
|
| 425 |
-
|
| 426 |
-
:type: float
|
| 427 |
-
"""
|
| 428 |
-
return self._sample_rate
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
SQUIM_OBJECTIVE = SquimObjectiveBundle(
|
| 432 |
-
"squim_objective_dns2020.pth",
|
| 433 |
-
_sample_rate=16000,
|
| 434 |
-
)
|
| 435 |
-
SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
|
| 436 |
-
:cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.
|
| 437 |
-
|
| 438 |
-
The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
|
| 439 |
-
The weights are under `Creative Commons Attribution 4.0 International License
|
| 440 |
-
<https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.
|
| 441 |
-
|
| 442 |
-
Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
|
| 443 |
-
"""
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
from torchaudio._internal import load_state_dict_from_url
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
from typing import List, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def transform_wb_pesq_range(x: float) -> float:
    """Map a raw wideband PESQ score onto the [0.999, 4.999] MOS-LQO range."""
    span = 4.999 - 0.999
    return 0.999 + span / (1 + math.exp(-1.3669 * x + 3.8224))


# Valid output range of the PESQ head: lower bound 1.0, upper bound is the
# mapped value of the maximum raw PESQ score 4.5.
PESQRange: Tuple[float, float] = (
    1.0,
    transform_wb_pesq_range(4.5),
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class RangeSigmoid(nn.Module):
    """Sigmoid rescaled to an arbitrary (low, high) output interval."""

    def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
        super(RangeSigmoid, self).__init__()
        assert isinstance(val_range, tuple) and len(val_range) == 2
        self.val_range: Tuple[float, float] = val_range
        self.sigmoid: nn.modules.Module = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        low, high = self.val_range
        # sigmoid in (0, 1), affinely mapped into (low, high).
        return self.sigmoid(x) * (high - low) + low
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Encoder(nn.Module):
    """Learned 1-D convolutional front end turning a waveform into features."""

    def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
        super(Encoder, self).__init__()
        # Half-window hop, no bias — behaves like a learned filterbank.
        self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, time) -> (batch, 1, time) -> (batch, feat_dim, frames)
        return F.relu(self.conv1d(x.unsqueeze(dim=1)))
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class SingleRNN(nn.Module):
    """Single bidirectional RNN layer with a projection back to input size."""

    def __init__(
        self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0
    ) -> None:
        super(SingleRNN, self).__init__()

        self.rnn_type = rnn_type
        self.input_size = input_size
        self.hidden_size = hidden_size

        # rnn_type names a torch.nn class: "RNN", "LSTM" or "GRU".
        rnn_cls = getattr(nn, rnn_type)
        self.rnn: nn.modules.Module = rnn_cls(
            input_size,
            hidden_size,
            1,
            dropout=dropout,
            batch_first=True,
            bidirectional=True,
        )
        # Bidirectional output is 2*hidden; project back to input_size.
        self.proj = nn.Linear(hidden_size * 2, input_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden, _ = self.rnn(x)
        return self.proj(hidden)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class DPRNN(nn.Module):
    """Dual-path RNN: alternates intra-chunk ("row") and inter-chunk
    ("col") bidirectional RNN passes over a chunked feature sequence,
    then projects to ``d_model`` with a 1x1 conv and overlap-adds the
    chunks back together.

    Input is (batch, feat_dim, seq_len); after merging and the final
    transpose the output is (batch, frames, d_model).
    """

    def __init__(
        self,
        feat_dim: int = 64,
        hidden_dim: int = 128,
        num_blocks: int = 6,
        rnn_type: str = "LSTM",
        d_model: int = 256,
        chunk_size: int = 100,
        chunk_stride: int = 50,
    ) -> None:
        super(DPRNN, self).__init__()

        self.num_blocks = num_blocks

        # One row/col RNN plus GroupNorm pair per dual-path block.
        self.row_rnn = nn.ModuleList([])
        self.col_rnn = nn.ModuleList([])
        self.row_norm = nn.ModuleList([])
        self.col_norm = nn.ModuleList([])
        for _ in range(num_blocks):
            self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
            self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
            self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
            self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
        # 1x1 conv projecting feat_dim -> d_model after the dual-path stack.
        self.conv = nn.Sequential(
            nn.Conv2d(feat_dim, d_model, 1),
            nn.PReLU(),
        )
        self.chunk_size = chunk_size
        self.chunk_stride = chunk_stride

    def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Pad so the sequence splits evenly into chunks of `chunk_size`
        with hop `chunk_stride`.

        Returns the padded tensor and ``rest``, the amount of right
        padding that `merging` must trim off again.
        """
        seq_len = x.shape[-1]

        rest = (
            self.chunk_size
            - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
        )
        # Symmetric chunk_stride padding plus `rest` on the right.
        out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])

        return out, rest

    def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Split (batch, feat, seq) into overlapping chunks.

        Returns a tensor of shape (batch, feat, chunk_size, n_chunks)
        (after the final transpose) and the ``rest`` padding amount.
        """
        out, rest = self.pad_chunk(x)
        batch_size, feat_dim, seq_len = out.shape

        # Two interleaved chunk streams, offset by chunk_stride, give the
        # overlapping segmentation.
        segments1 = (
            out[:, :, : -self.chunk_stride]
            .contiguous()
            .view(batch_size, feat_dim, -1, self.chunk_size)
        )
        segments2 = (
            out[:, :, self.chunk_stride :]
            .contiguous()
            .view(batch_size, feat_dim, -1, self.chunk_size)
        )
        out = torch.cat([segments1, segments2], dim=3)
        out = (
            out.view(batch_size, feat_dim, -1, self.chunk_size)
            .transpose(2, 3)
            .contiguous()
        )

        return out, rest

    def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
        """Overlap-add the chunked tensor back to (batch, dim, seq),
        undoing `chunking` (including the padding removal)."""
        batch_size, dim, _, _ = x.shape
        out = (
            x.transpose(2, 3)
            .contiguous()
            .view(batch_size, dim, -1, self.chunk_size * 2)
        )
        # First half of each double-chunk, shifted left by chunk_stride...
        out1 = (
            out[:, :, :, : self.chunk_size]
            .contiguous()
            .view(batch_size, dim, -1)[:, :, self.chunk_stride :]
        )
        # ...plus second half, trimmed on the right, reconstructs the overlap.
        out2 = (
            out[:, :, :, self.chunk_size :]
            .contiguous()
            .view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
        )
        out = out1 + out2
        if rest > 0:
            # Trim the right padding added in pad_chunk.
            out = out[:, :, :-rest]
        out = out.contiguous()
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, rest = self.chunking(x)
        batch_size, _, dim1, dim2 = x.shape
        out = x
        for row_rnn, row_norm, col_rnn, col_norm in zip(
            self.row_rnn, self.row_norm, self.col_rnn, self.col_norm
        ):
            # Intra-chunk ("row") pass with residual connection.
            row_in = (
                out.permute(0, 3, 2, 1)
                .contiguous()
                .view(batch_size * dim2, dim1, -1)
                .contiguous()
            )
            row_out = row_rnn(row_in)
            row_out = (
                row_out.view(batch_size, dim2, dim1, -1)
                .permute(0, 3, 2, 1)
                .contiguous()
            )
            row_out = row_norm(row_out)
            out = out + row_out

            # Inter-chunk ("col") pass with residual connection.
            col_in = (
                out.permute(0, 2, 3, 1)
                .contiguous()
                .view(batch_size * dim1, dim2, -1)
                .contiguous()
            )
            col_out = col_rnn(col_in)
            col_out = (
                col_out.view(batch_size, dim1, dim2, -1)
                .permute(0, 3, 1, 2)
                .contiguous()
            )
            col_out = col_norm(col_out)
            out = out + col_out
        out = self.conv(out)
        out = self.merging(out, rest)
        # (batch, d_model, seq) -> (batch, seq, d_model)
        out = out.transpose(1, 2).contiguous()
        return out
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class AutoPool(nn.Module):
    """Softmax-weighted pooling with a learnable scale parameter ``alpha``."""

    def __init__(self, pool_dim: int = 1) -> None:
        super(AutoPool, self).__init__()
        self.pool_dim: int = pool_dim
        self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
        self.register_parameter("alpha", nn.Parameter(torch.ones(1)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Weights come from a softmax over alpha-scaled activations.
        weight = self.softmax(x * self.alpha)
        return (x * weight).sum(dim=self.pool_dim)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class SquimObjective(nn.Module):
    """Predict objective speech-quality scores from a raw waveform.

    Pipeline: level normalization -> encoder -> DPRNN -> one score per
    metric branch.
    """

    def __init__(
        self,
        encoder: nn.Module,
        dprnn: nn.Module,
        branches: nn.ModuleList,
    ):
        super(SquimObjective, self).__init__()
        self.encoder = encoder
        self.dprnn = dprnn
        self.branches = branches

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        if x.ndim != 2:
            raise ValueError(
                f"The input must be a 2D Tensor. Found dimension {x.ndim}."
            )
        # Normalize each item by its RMS (scaled by 20).
        x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
        features = self.dprnn(self.encoder(x))
        return [branch(features).squeeze(dim=1) for branch in self.branches]
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
    """Build one metric head: transformer layer -> AutoPool -> MLP regressor.

    The output activation depends on the metric's valid range: sigmoid into
    [0, 1] for "stoi", sigmoid into PESQRange for "pesq", unbounded linear
    otherwise (SI-SDR).
    """
    attention = nn.TransformerEncoderLayer(
        d_model, nhead, d_model * 4, dropout=0.0, batch_first=True
    )
    pooling = AutoPool()
    if metric == "stoi":
        head: nn.modules.Module = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.PReLU(),
            nn.Linear(d_model, 1),
            RangeSigmoid(),
        )
    elif metric == "pesq":
        head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.PReLU(),
            nn.Linear(d_model, 1),
            RangeSigmoid(val_range=PESQRange),
        )
    else:
        head = nn.Sequential(
            nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1)
        )
    return nn.Sequential(attention, pooling, head)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def squim_objective_model(
    feat_dim: int,
    win_len: int,
    d_model: int,
    nhead: int,
    hidden_dim: int,
    num_blocks: int,
    rnn_type: str,
    chunk_size: int,
    chunk_stride: Optional[int] = None,
) -> SquimObjective:
    """Assemble a SquimObjective model from its hyper-parameters.

    Args:
        feat_dim: feature dimension produced by the Encoder.
        win_len: Encoder kernel size.
        d_model: feature size fed to the metric branches.
        nhead: attention heads in each branch's transformer layer.
        hidden_dim: RNN hidden size inside DPRNN.
        num_blocks: number of DPRNN dual-path blocks.
        rnn_type: "RNN", "LSTM" or "GRU".
        chunk_size: DPRNN chunk length.
        chunk_stride: chunk hop; defaults to chunk_size // 2.
    """
    if chunk_stride is None:
        chunk_stride = chunk_size // 2
    encoder = Encoder(feat_dim, win_len)
    dprnn = DPRNN(
        feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride
    )
    # One head per predicted metric, in fixed output order.
    heads = [
        _create_branch(d_model, nhead, name) for name in ("stoi", "pesq", "sisdr")
    ]
    return SquimObjective(encoder, dprnn, nn.ModuleList(heads))
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def squim_objective_base() -> SquimObjective:
    """Build a SquimObjective model with the default (released) configuration."""
    return squim_objective_model(
        feat_dim=256,
        win_len=64,
        d_model=256,
        nhead=4,
        hidden_dim=256,
        num_blocks=2,
        rnn_type="LSTM",
        chunk_size=71,
    )
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
@dataclass
class SquimObjectiveBundle:
    """Pretrained-weights bundle for the SquimObjective model.

    Holds the checkpoint filename and the sample rate the model expects.
    """

    # Checkpoint filename under the torchaudio models download root.
    _path: str
    # Sample rate (Hz) of the audio the model was trained on.
    _sample_rate: float

    def _get_state_dict(self, dl_kwargs):
        """Download (or load from cache) the pretrained state dict."""
        url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
        state_dict = load_state_dict_from_url(url, **dl_kwargs)
        return state_dict

    def get_model(self, *, dl_kwargs=None) -> SquimObjective:
        """Construct the SquimObjective model and load the pretrained weights.

        Args:
            dl_kwargs: keyword arguments passed through to
                ``torch.hub.load_state_dict_from_url``.

        Returns:
            The pretrained model, switched to eval mode.
        """
        model = squim_objective_base()
        model.load_state_dict(self._get_state_dict(dl_kwargs))
        model.eval()
        return model

    @property
    def sample_rate(self):
        """Sample rate (float, Hz) of the audio that the model is trained on."""
        return self._sample_rate
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
# Default pretrained bundle: SquimObjective weights trained on DNS 2020, 16 kHz.
SQUIM_OBJECTIVE = SquimObjectiveBundle(
    "squim_objective_dns2020.pth",
    _sample_rate=16000,
)
SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
:cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.

The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
The weights are under `Creative Commons Attribution 4.0 International License
<https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.

Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mvsepless/models/bandit/core/metrics/snr.py
CHANGED
|
@@ -1,127 +1,124 @@
|
|
| 1 |
-
from typing import Any, Callable
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
import torch
|
| 5 |
-
import torchmetrics as tm
|
| 6 |
-
from torch._C import _LinAlgError
|
| 7 |
-
from torchmetrics import functional as tmF
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class SafeSignalDistortionRatio(tm.SignalDistortionRatio):
|
| 11 |
-
def __init__(self, **kwargs) -> None:
|
| 12 |
-
super().__init__(**kwargs)
|
| 13 |
-
|
| 14 |
-
def update(self, *args, **kwargs) -> Any:
|
| 15 |
-
try:
|
| 16 |
-
super().update(*args, **kwargs)
|
| 17 |
-
except:
|
| 18 |
-
pass
|
| 19 |
-
|
| 20 |
-
def compute(self) -> Any:
|
| 21 |
-
if self.total == 0:
|
| 22 |
-
return torch.tensor(torch.nan)
|
| 23 |
-
return super().compute()
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
class BaseChunkMedianSignalRatio(tm.Metric):
|
| 27 |
-
def __init__(
|
| 28 |
-
self,
|
| 29 |
-
func: Callable,
|
| 30 |
-
window_size: int,
|
| 31 |
-
hop_size: int = None,
|
| 32 |
-
zero_mean: bool = False,
|
| 33 |
-
) -> None:
|
| 34 |
-
super().__init__()
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
self.
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
self.add_state("
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
self.
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
hop_size=hop_size,
|
| 126 |
-
zero_mean=zero_mean,
|
| 127 |
-
)
|
|
|
|
| 1 |
+
from typing import Any, Callable
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torchmetrics as tm
|
| 6 |
+
from torch._C import _LinAlgError
|
| 7 |
+
from torchmetrics import functional as tmF
|
| 8 |
+
|
| 9 |
+
|
class SafeSignalDistortionRatio(tm.SignalDistortionRatio):
    """SDR metric that tolerates failing batch updates.

    ``SignalDistortionRatio.update`` can raise on degenerate inputs; this
    wrapper skips such batches instead of aborting the whole evaluation and
    reports NaN when no batch was successfully accumulated.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def update(self, *args, **kwargs) -> Any:
        """Accumulate one batch, ignoring any error raised by the base metric."""
        try:
            super().update(*args, **kwargs)
        except Exception:
            # Was a bare ``except:`` which also swallowed KeyboardInterrupt and
            # SystemExit; narrowed so user interrupts still propagate.
            pass

    def compute(self) -> Any:
        """Return the aggregated SDR, or NaN if every update failed."""
        if self.total == 0:
            return torch.tensor(torch.nan)
        return super().compute()
| 24 |
+
|
| 25 |
+
|
class BaseChunkMedianSignalRatio(tm.Metric):
    """Chunked signal-ratio metric: median over chunks, mean over items.

    The last axis of the input is split into windows of ``window_size`` samples
    every ``hop_size`` samples; ``func`` computes the ratio on each full
    window, failing or non-finite windows are skipped, and the per-item median
    over windows is accumulated for averaging in :meth:`compute`.
    """

    def __init__(
        self,
        func: Callable,
        window_size: int,
        hop_size: int = None,
        zero_mean: bool = False,
    ) -> None:
        super().__init__()

        # NOTE(review): ``zero_mean`` is accepted for API compatibility with
        # the subclasses but is never stored or forwarded to ``func`` —
        # confirm whether it should be applied.
        self.func = func
        self.window_size = window_size
        if hop_size is None:
            # Default to non-overlapping chunks.
            hop_size = window_size
        self.hop_size = hop_size

        self.add_state("sum_snr", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate the chunk-median ratio for one batch."""
        n_samples = target.shape[-1]

        n_chunks = int(np.ceil((n_samples - self.window_size) / self.hop_size) + 1)

        snr_chunk = []

        for i in range(n_chunks):
            start = i * self.hop_size

            if n_samples - start < self.window_size:
                # Trailing partial window: skip rather than evaluate short data.
                continue

            end = start + self.window_size

            try:
                chunk_snr = self.func(preds[..., start:end], target[..., start:end])

                if torch.all(torch.isfinite(chunk_snr)):
                    snr_chunk.append(chunk_snr)
            except _LinAlgError:
                pass

        if not snr_chunk:
            # Every chunk failed or was non-finite (or the signal was shorter
            # than one window). ``torch.stack`` on an empty list would raise a
            # RuntimeError here; skip the batch instead, matching the
            # best-effort behaviour of SafeSignalDistortionRatio.
            return

        snr_chunk = torch.stack(snr_chunk, dim=-1)
        snr_batch, _ = torch.nanmedian(snr_chunk, dim=-1)

        self.sum_snr += snr_batch.sum()
        self.total += snr_batch.numel()

    def compute(self) -> Any:
        """Return the mean over all accumulated items of the chunk-median ratio."""
        return self.sum_snr / self.total
| 77 |
+
|
| 78 |
+
|
class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio):
    """Chunk-median SNR; see :class:`BaseChunkMedianSignalRatio`."""

    def __init__(
        self, window_size: int, hop_size: int = None, zero_mean: bool = False
    ) -> None:
        # Bind the plain SNR functional as the per-chunk ratio.
        super().__init__(
            tmF.signal_noise_ratio,
            window_size=window_size,
            hop_size=hop_size,
            zero_mean=zero_mean,
        )
| 89 |
+
|
| 90 |
+
|
class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio):
    """Chunk-median SI-SNR; see :class:`BaseChunkMedianSignalRatio`."""

    def __init__(
        self, window_size: int, hop_size: int = None, zero_mean: bool = False
    ) -> None:
        # Bind the scale-invariant SNR functional as the per-chunk ratio.
        super().__init__(
            tmF.scale_invariant_signal_noise_ratio,
            window_size=window_size,
            hop_size=hop_size,
            zero_mean=zero_mean,
        )
| 101 |
+
|
| 102 |
+
|
class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio):
    """Chunk-median SDR; see :class:`BaseChunkMedianSignalRatio`."""

    def __init__(
        self, window_size: int, hop_size: int = None, zero_mean: bool = False
    ) -> None:
        # Bind the plain SDR functional as the per-chunk ratio.
        super().__init__(
            tmF.signal_distortion_ratio,
            window_size=window_size,
            hop_size=hop_size,
            zero_mean=zero_mean,
        )
| 113 |
+
|
| 114 |
+
|
class ChunkMedianScaleInvariantSignalDistortionRatio(BaseChunkMedianSignalRatio):
    """Chunk-median SI-SDR; see :class:`BaseChunkMedianSignalRatio`."""

    def __init__(
        self, window_size: int, hop_size: int = None, zero_mean: bool = False
    ) -> None:
        # Bind the scale-invariant SDR functional as the per-chunk ratio.
        super().__init__(
            tmF.scale_invariant_signal_distortion_ratio,
            window_size=window_size,
            hop_size=hop_size,
            zero_mean=zero_mean,
        )
|
|
|
|
|
|
|
|
mvsepless/models/bandit/core/model/__init__.py
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
-
from .bsrnn.wrapper import (
|
| 2 |
-
MultiMaskMultiSourceBandSplitRNNSimple,
|
| 3 |
-
)
|
|
|
|
| 1 |
+
from .bsrnn.wrapper import (
|
| 2 |
+
MultiMaskMultiSourceBandSplitRNNSimple,
|
| 3 |
+
)
|
mvsepless/models/bandit/core/model/_spectral.py
CHANGED
|
@@ -1,54 +1,54 @@
|
|
| 1 |
-
from typing import Dict, Optional
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
import torchaudio as ta
|
| 5 |
-
from torch import nn
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class _SpectralComponent(nn.Module):
|
| 9 |
-
def __init__(
|
| 10 |
-
self,
|
| 11 |
-
n_fft: int = 2048,
|
| 12 |
-
win_length: Optional[int] = 2048,
|
| 13 |
-
hop_length: int = 512,
|
| 14 |
-
window_fn: str = "hann_window",
|
| 15 |
-
wkwargs: Optional[Dict] = None,
|
| 16 |
-
power: Optional[int] = None,
|
| 17 |
-
center: bool = True,
|
| 18 |
-
normalized: bool = True,
|
| 19 |
-
pad_mode: str = "constant",
|
| 20 |
-
onesided: bool = True,
|
| 21 |
-
**kwargs,
|
| 22 |
-
) -> None:
|
| 23 |
-
super().__init__()
|
| 24 |
-
|
| 25 |
-
assert power is None
|
| 26 |
-
|
| 27 |
-
window_fn = torch.__dict__[window_fn]
|
| 28 |
-
|
| 29 |
-
self.stft = ta.transforms.Spectrogram(
|
| 30 |
-
n_fft=n_fft,
|
| 31 |
-
win_length=win_length,
|
| 32 |
-
hop_length=hop_length,
|
| 33 |
-
pad_mode=pad_mode,
|
| 34 |
-
pad=0,
|
| 35 |
-
window_fn=window_fn,
|
| 36 |
-
wkwargs=wkwargs,
|
| 37 |
-
power=power,
|
| 38 |
-
normalized=normalized,
|
| 39 |
-
center=center,
|
| 40 |
-
onesided=onesided,
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
self.istft = ta.transforms.InverseSpectrogram(
|
| 44 |
-
n_fft=n_fft,
|
| 45 |
-
win_length=win_length,
|
| 46 |
-
hop_length=hop_length,
|
| 47 |
-
pad_mode=pad_mode,
|
| 48 |
-
pad=0,
|
| 49 |
-
window_fn=window_fn,
|
| 50 |
-
wkwargs=wkwargs,
|
| 51 |
-
normalized=normalized,
|
| 52 |
-
center=center,
|
| 53 |
-
onesided=onesided,
|
| 54 |
-
)
|
|
|
|
| 1 |
+
from typing import Dict, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torchaudio as ta
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
|
class _SpectralComponent(nn.Module):
    """Holds a matched STFT / inverse-STFT pair built from shared settings."""

    def __init__(
        self,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        **kwargs,
    ) -> None:
        super().__init__()

        # Only complex spectrograms are supported here (power must stay None
        # so the transform remains invertible).
        assert power is None

        # Resolve the window constructor (e.g. torch.hann_window) by name.
        window_fn = torch.__dict__[window_fn]

        # Both transforms share every parameter; the forward STFT additionally
        # receives ``power`` (which is None, i.e. complex output).
        shared = dict(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )

        self.stft = ta.transforms.Spectrogram(power=power, **shared)
        self.istft = ta.transforms.InverseSpectrogram(**shared)
mvsepless/models/bandit/core/model/bsrnn/__init__.py
CHANGED
|
@@ -1,23 +1,23 @@
|
|
| 1 |
-
from abc import ABC
|
| 2 |
-
from typing import Iterable, Mapping, Union
|
| 3 |
-
|
| 4 |
-
from torch import nn
|
| 5 |
-
|
| 6 |
-
from .bandsplit import BandSplitModule
|
| 7 |
-
from .tfmodel import (
|
| 8 |
-
SeqBandModellingModule,
|
| 9 |
-
TransformerTimeFreqModule,
|
| 10 |
-
)
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class BandsplitCoreBase(nn.Module, ABC):
|
| 14 |
-
band_split: nn.Module
|
| 15 |
-
tf_model: nn.Module
|
| 16 |
-
mask_estim: Union[nn.Module, Mapping[str, nn.Module], Iterable[nn.Module]]
|
| 17 |
-
|
| 18 |
-
def __init__(self) -> None:
|
| 19 |
-
super().__init__()
|
| 20 |
-
|
| 21 |
-
@staticmethod
|
| 22 |
-
def mask(x, m):
|
| 23 |
-
return x * m
|
|
|
|
| 1 |
+
from abc import ABC
|
| 2 |
+
from typing import Iterable, Mapping, Union
|
| 3 |
+
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
from .bandsplit import BandSplitModule
|
| 7 |
+
from .tfmodel import (
|
| 8 |
+
SeqBandModellingModule,
|
| 9 |
+
TransformerTimeFreqModule,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
|
class BandsplitCoreBase(nn.Module, ABC):
    """Abstract skeleton of a band-split model: split -> TF model -> mask."""

    # Concrete subclasses are expected to populate these three components.
    band_split: nn.Module
    tf_model: nn.Module
    mask_estim: Union[nn.Module, Mapping[str, nn.Module], Iterable[nn.Module]]

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def mask(x, m):
        """Apply mask ``m`` to input ``x`` element-wise."""
        return m * x
mvsepless/models/bandit/core/model/bsrnn/bandsplit.py
CHANGED
|
@@ -1,135 +1,119 @@
|
|
| 1 |
-
from typing import List, Tuple
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch import nn
|
| 5 |
-
|
| 6 |
-
from .utils import (
|
| 7 |
-
band_widths_from_specs,
|
| 8 |
-
check_no_gap,
|
| 9 |
-
check_no_overlap,
|
| 10 |
-
check_nonzero_bandwidth,
|
| 11 |
-
)
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class NormFC(nn.Module):
|
| 15 |
-
def __init__(
|
| 16 |
-
self,
|
| 17 |
-
emb_dim: int,
|
| 18 |
-
bandwidth: int,
|
| 19 |
-
in_channel: int,
|
| 20 |
-
normalize_channel_independently: bool = False,
|
| 21 |
-
treat_channel_as_feature: bool = True,
|
| 22 |
-
) -> None:
|
| 23 |
-
super().__init__()
|
| 24 |
-
|
| 25 |
-
self.treat_channel_as_feature = treat_channel_as_feature
|
| 26 |
-
|
| 27 |
-
if normalize_channel_independently:
|
| 28 |
-
raise NotImplementedError
|
| 29 |
-
|
| 30 |
-
reim = 2
|
| 31 |
-
|
| 32 |
-
self.norm = nn.LayerNorm(in_channel * bandwidth * reim)
|
| 33 |
-
|
| 34 |
-
fc_in = bandwidth * reim
|
| 35 |
-
|
| 36 |
-
if treat_channel_as_feature:
|
| 37 |
-
fc_in *= in_channel
|
| 38 |
-
else:
|
| 39 |
-
assert emb_dim % in_channel == 0
|
| 40 |
-
emb_dim = emb_dim // in_channel
|
| 41 |
-
|
| 42 |
-
self.fc = nn.Linear(fc_in, emb_dim)
|
| 43 |
-
|
| 44 |
-
def forward(self, xb):
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
batch, n_time, in_chan
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
xr = torch.view_as_real(x) # batch, in_chan, n_freq, n_time, 2
|
| 121 |
-
xr = torch.permute(xr, (0, 3, 1, 4, 2)) # batch, n_time, in_chan, 2, n_freq
|
| 122 |
-
batch, n_time, in_chan, reim, band_width = xr.shape
|
| 123 |
-
for i, nfm in enumerate(self.norm_fc_modules):
|
| 124 |
-
# print(f"bandsplit/band{i:02d}")
|
| 125 |
-
fstart, fend = self.band_specs[i]
|
| 126 |
-
xb = xr[..., fstart:fend]
|
| 127 |
-
# (batch, n_time, in_chan, reim, band_width)
|
| 128 |
-
xb = torch.reshape(xb, (batch, n_time, in_chan, -1))
|
| 129 |
-
# (batch, n_time, in_chan, reim * band_width)
|
| 130 |
-
# z.append(nfm(xb)) # (batch, n_time, emb_dim)
|
| 131 |
-
z[:, i, :, :] = nfm(xb.contiguous())
|
| 132 |
-
|
| 133 |
-
# z = torch.stack(z, dim=1)
|
| 134 |
-
|
| 135 |
-
return z
|
|
|
|
| 1 |
+
from typing import List, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
from .utils import (
|
| 7 |
+
band_widths_from_specs,
|
| 8 |
+
check_no_gap,
|
| 9 |
+
check_no_overlap,
|
| 10 |
+
check_nonzero_bandwidth,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
class NormFC(nn.Module):
    """LayerNorm followed by a linear embedding for a single frequency band."""

    def __init__(
        self,
        emb_dim: int,
        bandwidth: int,
        in_channel: int,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        self.treat_channel_as_feature = treat_channel_as_feature

        if normalize_channel_independently:
            raise NotImplementedError

        reim = 2  # real + imaginary parts per bin

        # Normalization always runs over the fully flattened
        # (channel, re/im, bandwidth) feature axis.
        self.norm = nn.LayerNorm(in_channel * bandwidth * reim)

        fc_in = bandwidth * reim
        if treat_channel_as_feature:
            # One projection over all channels jointly.
            fc_in = fc_in * in_channel
        else:
            # Per-channel projection; each channel owns an equal slice of the
            # embedding dimension.
            assert emb_dim % in_channel == 0
            emb_dim = emb_dim // in_channel

        self.fc = nn.Linear(fc_in, emb_dim)

    def forward(self, xb):
        """Embed one band; ``xb`` is (batch, time, channel, 2 * bandwidth)."""
        n_batch, n_frames, n_chan, n_feat = xb.shape

        normed = self.norm(xb.reshape(n_batch, n_frames, n_chan * n_feat))

        if self.treat_channel_as_feature:
            return self.fc(normed)

        # Per-channel path: project each channel separately, then flatten the
        # per-channel embeddings back into a single feature axis.
        per_chan = self.fc(normed.reshape(n_batch, n_frames, n_chan, n_feat))
        return per_chan.reshape(n_batch, n_frames, -1)
| 59 |
+
|
| 60 |
+
|
class BandSplitModule(nn.Module):
    """Splits a complex spectrogram into bands and embeds each via NormFC."""

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        in_channel: int,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        # Validate the band layout up front.
        check_nonzero_bandwidth(band_specs)
        if require_no_gap:
            check_no_gap(band_specs)
        if require_no_overlap:
            check_no_overlap(band_specs)

        self.band_specs = band_specs
        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)
        self.emb_dim = emb_dim

        # One NormFC per band, sized by that band's width.
        self.norm_fc_modules = nn.ModuleList(
            [  # type: ignore
                NormFC(
                    emb_dim=emb_dim,
                    bandwidth=width,
                    in_channel=in_channel,
                    normalize_channel_independently=normalize_channel_independently,
                    treat_channel_as_feature=treat_channel_as_feature,
                )
                for width in self.band_widths
            ]
        )

    def forward(self, x: torch.Tensor):
        """Map complex (batch, channel, freq, time) to (batch, band, time, emb_dim)."""
        n_batch, _, _, n_frames = x.shape

        z = torch.zeros(
            (n_batch, self.n_bands, n_frames, self.emb_dim), device=x.device
        )

        # Split complex values into re/im and move frequency to the last axis:
        # result is (batch, time, channel, re/im, freq).
        xr = torch.permute(torch.view_as_real(x), (0, 3, 1, 4, 2))
        n_batch, n_frames, n_chan, _, _ = xr.shape

        for band_idx, (module, (fstart, fend)) in enumerate(
            zip(self.norm_fc_modules, self.band_specs)
        ):
            band = xr[..., fstart:fend]
            band = torch.reshape(band, (n_batch, n_frames, n_chan, -1))
            z[:, band_idx, :, :] = module(band.contiguous())

        return z
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mvsepless/models/bandit/core/model/bsrnn/core.py
CHANGED
|
@@ -1,651 +1,619 @@
|
|
| 1 |
-
from typing import Dict, List, Optional, Tuple
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch import nn
|
| 5 |
-
from torch.nn import functional as F
|
| 6 |
-
|
| 7 |
-
from . import BandsplitCoreBase
|
| 8 |
-
from .bandsplit import BandSplitModule
|
| 9 |
-
from .maskestim import (
|
| 10 |
-
MaskEstimationModule,
|
| 11 |
-
OverlappingMaskEstimationModule,
|
| 12 |
-
)
|
| 13 |
-
from .tfmodel import (
|
| 14 |
-
ConvolutionalTimeFreqModule,
|
| 15 |
-
SeqBandModellingModule,
|
| 16 |
-
TransformerTimeFreqModule,
|
| 17 |
-
)
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class MultiMaskBandSplitCoreBase(BandsplitCoreBase):
|
| 21 |
-
def __init__(self) -> None:
|
| 22 |
-
super().__init__()
|
| 23 |
-
|
| 24 |
-
def forward(self, x, cond=None, compute_residual: bool = True):
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
):
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
def
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
rnn_type=rnn_type,
|
| 621 |
-
)
|
| 622 |
-
|
| 623 |
-
if hidden_activation_kwargs is None:
|
| 624 |
-
hidden_activation_kwargs = {}
|
| 625 |
-
|
| 626 |
-
if overlapping_band:
|
| 627 |
-
assert freq_weights is not None
|
| 628 |
-
assert n_freq is not None
|
| 629 |
-
self.mask_estim = nn.ModuleDict(
|
| 630 |
-
{
|
| 631 |
-
stem: PatchingMaskEstimationModule(
|
| 632 |
-
band_specs=band_specs,
|
| 633 |
-
freq_weights=freq_weights,
|
| 634 |
-
n_freq=n_freq,
|
| 635 |
-
emb_dim=emb_dim,
|
| 636 |
-
mlp_dim=mlp_dim,
|
| 637 |
-
in_channel=in_channel,
|
| 638 |
-
hidden_activation=hidden_activation,
|
| 639 |
-
hidden_activation_kwargs=hidden_activation_kwargs,
|
| 640 |
-
complex_mask=complex_mask,
|
| 641 |
-
mask_kernel_freq=mask_kernel_freq,
|
| 642 |
-
mask_kernel_time=mask_kernel_time,
|
| 643 |
-
conv_kernel_freq=conv_kernel_freq,
|
| 644 |
-
conv_kernel_time=conv_kernel_time,
|
| 645 |
-
kernel_norm_mlp_version=kernel_norm_mlp_version,
|
| 646 |
-
)
|
| 647 |
-
for stem in stems
|
| 648 |
-
}
|
| 649 |
-
)
|
| 650 |
-
else:
|
| 651 |
-
raise NotImplementedError
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
|
| 7 |
+
from . import BandsplitCoreBase
|
| 8 |
+
from .bandsplit import BandSplitModule
|
| 9 |
+
from .maskestim import (
|
| 10 |
+
MaskEstimationModule,
|
| 11 |
+
OverlappingMaskEstimationModule,
|
| 12 |
+
)
|
| 13 |
+
from .tfmodel import (
|
| 14 |
+
ConvolutionalTimeFreqModule,
|
| 15 |
+
SeqBandModellingModule,
|
| 16 |
+
TransformerTimeFreqModule,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
class MultiMaskBandSplitCoreBase(BandsplitCoreBase):
    """Band-split core that estimates one mask per output stem."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x, cond=None, compute_residual: bool = True):
        """Separate ``x`` into per-stem spectrograms.

        ``x`` is (batch, channel, freq, time); channels are folded into the
        batch axis so the core processes mono slices, then unfolded before
        returning. ``cond`` is forwarded to each mask estimator.
        NOTE(review): ``compute_residual`` is accepted but unused here —
        confirm whether subclasses or callers rely on it.
        """
        batch, in_chan, n_freq, n_time = x.shape
        # Fold channels into the batch dimension (process each channel alone).
        x = torch.reshape(x, (-1, 1, n_freq, n_time))

        z = self.band_split(x)

        q = self.tf_model(z)

        out = {}

        for stem, mem in self.mask_estim.items():
            m = mem(q, cond=cond)

            s = self.mask(x, m)
            # Unfold channels back out of the batch dimension.
            s = torch.reshape(s, (batch, in_chan, n_freq, n_time))
            out[stem] = s

        return {"spectrogram": out}

    def instantiate_mask_estim(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        cond_dim: int,
        hidden_activation: str,
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
        mult_add_mask: bool = False,
    ):
        """Build ``self.mask_estim``: one estimator module per stem.

        Overlapping bands require ``freq_weights`` and ``n_freq``; otherwise a
        plain (conditioned) MaskEstimationModule is used per stem.
        """
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        # "mne:+" appears to be a sentinel stem that must not get its own
        # estimator; it is stripped from the list here.
        if "mne:+" in stems:
            stems = [s for s in stems if s != "mne:+"]

        if overlapping_band:
            assert freq_weights is not None
            assert n_freq is not None

            if mult_add_mask:

                # NOTE(review): MultAddMaskEstimationModule is not among this
                # module's visible imports — this branch would raise NameError
                # unless the name is provided elsewhere; confirm.
                self.mask_estim = nn.ModuleDict(
                    {
                        stem: MultAddMaskEstimationModule(
                            band_specs=band_specs,
                            freq_weights=freq_weights,
                            n_freq=n_freq,
                            emb_dim=emb_dim,
                            mlp_dim=mlp_dim,
                            in_channel=in_channel,
                            hidden_activation=hidden_activation,
                            hidden_activation_kwargs=hidden_activation_kwargs,
                            complex_mask=complex_mask,
                            use_freq_weights=use_freq_weights,
                        )
                        for stem in stems
                    }
                )
            else:
                self.mask_estim = nn.ModuleDict(
                    {
                        stem: OverlappingMaskEstimationModule(
                            band_specs=band_specs,
                            freq_weights=freq_weights,
                            n_freq=n_freq,
                            emb_dim=emb_dim,
                            mlp_dim=mlp_dim,
                            in_channel=in_channel,
                            hidden_activation=hidden_activation,
                            hidden_activation_kwargs=hidden_activation_kwargs,
                            complex_mask=complex_mask,
                            use_freq_weights=use_freq_weights,
                        )
                        for stem in stems
                    }
                )
        else:
            self.mask_estim = nn.ModuleDict(
                {
                    stem: MaskEstimationModule(
                        band_specs=band_specs,
                        emb_dim=emb_dim,
                        mlp_dim=mlp_dim,
                        cond_dim=cond_dim,
                        in_channel=in_channel,
                        hidden_activation=hidden_activation,
                        hidden_activation_kwargs=hidden_activation_kwargs,
                        complex_mask=complex_mask,
                    )
                    for stem in stems
                }
            )

    def instantiate_bandsplit(
        self,
        in_channel: int,
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        emb_dim: int = 128,
    ):
        """Build ``self.band_split`` from the given band layout."""
        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )
| 143 |
+
|
| 144 |
+
|
class SingleMaskBandsplitCoreBase(BandsplitCoreBase):
    """Band-split core that predicts a single separation mask."""

    def __init__(self, **kwargs) -> None:
        super().__init__()

    def forward(self, x):
        """Embed, model over time/bands, estimate a mask, and apply it to ``x``."""
        embedding = self.band_split(x)
        modelled = self.tf_model(embedding)
        estimated_mask = self.mask_estim(modelled)

        return self.mask(x, estimated_mask)
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class SingleMaskBandsplitCoreRNN(SingleMaskBandsplitCoreBase):
    """Single-mask band-split core with an RNN time-frequency backbone."""

    def __init__(
        self,
        in_channel: int,
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__()

        # Front end: split the spectrogram into per-band embeddings.
        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

        # Backbone: stacked sequence models over time and band axes.
        self.tf_model = SeqBandModellingModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

        # Head: turn embeddings back into a (complex) spectral mask.
        self.mask_estim = MaskEstimationModule(
            in_channel=in_channel,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class SingleMaskBandsplitCoreTransformer(SingleMaskBandsplitCoreBase):
    """Single-mask band-split core with a Transformer time-frequency backbone."""

    def __init__(
        self,
        in_channel: int,
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        tf_dropout: float = 0.0,
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__()

        # Front end: per-band embedding extraction.
        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

        # Backbone: transformer blocks over the time/band grid.
        # NOTE: ``rnn_dim`` is reused here as the transformer hidden size.
        self.tf_model = TransformerTimeFreqModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            dropout=tf_dropout,
        )

        # Head: mask estimation back to the frequency axis.
        self.mask_estim = MaskEstimationModule(
            in_channel=in_channel,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class MultiSourceMultiMaskBandSplitCoreRNN(MultiMaskBandSplitCoreBase):
    """Multi-stem band-split core with an RNN backbone and one mask per stem.

    Optionally uses a multiply-add mask (two mask channels: multiplicative
    and additive) instead of a plain multiplicative mask.
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        cond_dim: int = 0,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
        mult_add_mask: bool = False,
    ) -> None:
        super().__init__()

        self.instantiate_bandsplit(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

        self.tf_model = SeqBandModellingModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

        self.mult_add_mask = mult_add_mask

        self.instantiate_mask_estim(
            in_channel=in_channel,
            stems=stems,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=overlapping_band,
            freq_weights=freq_weights,
            n_freq=n_freq,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )

    @staticmethod
    def _mult_add_mask(x, m):
        """Apply a multiply-add mask: ``x * m[..., 0] + m[..., 1]``.

        ``m`` must carry the multiplicative and additive components stacked
        on a trailing axis of size 2 (hence the 5-D requirement).
        """
        assert m.ndim == 5
        mult_component = m[..., 0]
        add_component = m[..., 1]
        return x * mult_component + add_component

    def mask(self, x, m):
        """Dispatch to the multiply-add mask or the base multiplicative mask."""
        if self.mult_add_mask:
            return self._mult_add_mask(x, m)
        return super().mask(x, m)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
class MultiSourceMultiMaskBandSplitCoreTransformer(MultiMaskBandSplitCoreBase):
    """Multi-stem band-split core with a Transformer backbone."""

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        tf_dropout: float = 0.0,
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
        rnn_type: str = "LSTM",
        cond_dim: int = 0,
        mult_add_mask: bool = False,
    ) -> None:
        super().__init__()

        self.instantiate_bandsplit(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

        # NOTE: ``rnn_type`` is accepted for signature parity with the RNN
        # variant but is not used by the transformer backbone.
        self.tf_model = TransformerTimeFreqModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            dropout=tf_dropout,
        )

        self.instantiate_mask_estim(
            in_channel=in_channel,
            stems=stems,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=overlapping_band,
            freq_weights=freq_weights,
            n_freq=n_freq,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
class MultiSourceMultiMaskBandSplitCoreConv(MultiMaskBandSplitCoreBase):
    """Multi-stem band-split core with a convolutional backbone."""

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        tf_dropout: float = 0.0,
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
        rnn_type: str = "LSTM",
        cond_dim: int = 0,
        mult_add_mask: bool = False,
    ) -> None:
        super().__init__()

        self.instantiate_bandsplit(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

        # NOTE: ``rnn_type`` is kept only for signature parity with the
        # RNN variant; the convolutional backbone does not use it.
        self.tf_model = ConvolutionalTimeFreqModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            dropout=tf_dropout,
        )

        self.instantiate_mask_estim(
            in_channel=in_channel,
            stems=stems,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=overlapping_band,
            freq_weights=freq_weights,
            n_freq=n_freq,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
class PatchingMaskBandsplitCoreBase(MultiMaskBandSplitCoreBase):
    """Band-split core whose masks act on local TF patches.

    Each output bin receives contributions from a ``kernel_freq`` x
    ``kernel_time`` neighbourhood, implemented with unfold/fold.
    """

    def __init__(self) -> None:
        super().__init__()

    def mask(self, x, m):
        """Apply a patch-wise mask via ``F.unfold``/``F.fold``.

        Args:
            x: input of shape ``(batch, channel, freq, time)``.
            m: mask of shape
               ``(batch, channel, kernel_freq, kernel_time, freq, time)``.

        Returns:
            Masked output of shape ``(batch, channel, freq, time)``.
        """
        _, n_channel, kernel_freq, kernel_time, n_freq, n_time = m.shape
        # "Same" padding so every output bin sees a full neighbourhood.
        padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2)

        # Extract sliding TF patches of x, then view them with the same
        # axis layout as the mask so they can be multiplied elementwise.
        patches = F.unfold(
            x,
            kernel_size=(kernel_freq, kernel_time),
            padding=padding,
            stride=(1, 1),
        ).view(-1, n_channel, kernel_freq, kernel_time, n_freq, n_time)

        masked = (patches * m).view(
            -1,
            n_channel * kernel_freq * kernel_time,
            n_freq * n_time,
        )

        # Fold sums overlapping patch contributions back onto the TF grid.
        return F.fold(
            masked,
            output_size=(n_freq, n_time),
            kernel_size=(kernel_freq, kernel_time),
            padding=padding,
            stride=(1, 1),
        ).view(-1, n_channel, n_freq, n_time)

    def old_mask(self, x, m):
        """Reference patch-mask implementation using explicit rolls.

        Kept for comparison against ``mask``; note the mask layout here is
        ``(kernel_freq, kernel_time, batch?, channel, freq, time)`` —
        different from ``mask`` — TODO confirm against historical callers.
        """
        out = torch.zeros_like(x)

        _, n_channel, n_freq, n_time = x.shape
        kernel_freq, kernel_time, _, _, _, _ = m.shape

        half_f = (kernel_freq - 1) // 2
        half_t = (kernel_time - 1) // 2

        for ifreq in range(kernel_freq):
            for itime in range(kernel_time):
                df, dt = half_f - ifreq, half_t - itime
                # Shift the input so offset (df, dt) aligns with the mask tap.
                x = x.roll(shifts=(df, dt), dims=(2, 3))

                fslice = slice(max(0, df), min(n_freq, n_freq + df))
                tslice = slice(max(0, dt), min(n_time, n_time + dt))

                out[:, :, fslice, tslice] += (
                    x[:, :, fslice, tslice] * m[ifreq, itime, :, :, fslice, tslice]
                )

        return out
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
class MultiSourceMultiPatchingMaskBandSplitCoreRNN(PatchingMaskBandsplitCoreBase):
    """Multi-stem patching-mask core with an RNN backbone.

    Only the overlapping-band configuration is implemented; the
    non-overlapping path raises ``NotImplementedError``.
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        mask_kernel_freq: int,
        mask_kernel_time: int,
        conv_kernel_freq: int,
        conv_kernel_time: int,
        kernel_norm_mlp_version: int,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

        self.tf_model = SeqBandModellingModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        if not overlapping_band:
            raise NotImplementedError

        # Overlapping bands require per-band frequency weights and the
        # total number of frequency bins for recombination.
        assert freq_weights is not None
        assert n_freq is not None

        self.mask_estim = nn.ModuleDict(
            {
                stem: PatchingMaskEstimationModule(
                    band_specs=band_specs,
                    freq_weights=freq_weights,
                    n_freq=n_freq,
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    in_channel=in_channel,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                    mask_kernel_freq=mask_kernel_freq,
                    mask_kernel_time=mask_kernel_time,
                    conv_kernel_freq=conv_kernel_freq,
                    conv_kernel_time=conv_kernel_time,
                    kernel_norm_mlp_version=kernel_norm_mlp_version,
                )
                for stem in stems
            }
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mvsepless/models/bandit/core/model/bsrnn/maskestim.py
CHANGED
|
@@ -1,351 +1,327 @@
|
|
| 1 |
-
import warnings
|
| 2 |
-
from typing import Dict, List, Optional, Tuple, Type
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
from torch import nn
|
| 6 |
-
from torch.nn.modules import activation
|
| 7 |
-
|
| 8 |
-
from .utils import (
|
| 9 |
-
band_widths_from_specs,
|
| 10 |
-
check_no_gap,
|
| 11 |
-
check_no_overlap,
|
| 12 |
-
check_nonzero_bandwidth,
|
| 13 |
-
)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class BaseNormMLP(nn.Module):
|
| 17 |
-
def __init__(
|
| 18 |
-
self,
|
| 19 |
-
emb_dim: int,
|
| 20 |
-
mlp_dim: int,
|
| 21 |
-
bandwidth: int,
|
| 22 |
-
in_channel: Optional[int],
|
| 23 |
-
hidden_activation: str = "Tanh",
|
| 24 |
-
hidden_activation_kwargs=None,
|
| 25 |
-
complex_mask: bool = True,
|
| 26 |
-
):
|
| 27 |
-
|
| 28 |
-
super().__init__()
|
| 29 |
-
if hidden_activation_kwargs is None:
|
| 30 |
-
hidden_activation_kwargs = {}
|
| 31 |
-
self.hidden_activation_kwargs = hidden_activation_kwargs
|
| 32 |
-
self.norm = nn.LayerNorm(emb_dim)
|
| 33 |
-
self.hidden = torch.jit.script(
|
| 34 |
-
nn.Sequential(
|
| 35 |
-
nn.Linear(in_features=emb_dim, out_features=mlp_dim),
|
| 36 |
-
activation.__dict__[hidden_activation](**self.hidden_activation_kwargs),
|
| 37 |
-
)
|
| 38 |
-
)
|
| 39 |
-
|
| 40 |
-
self.bandwidth = bandwidth
|
| 41 |
-
self.in_channel = in_channel
|
| 42 |
-
|
| 43 |
-
self.complex_mask = complex_mask
|
| 44 |
-
self.reim = 2 if complex_mask else 1
|
| 45 |
-
self.glu_mult = 2
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
class NormMLP(BaseNormMLP):
|
| 49 |
-
def __init__(
|
| 50 |
-
self,
|
| 51 |
-
emb_dim: int,
|
| 52 |
-
mlp_dim: int,
|
| 53 |
-
bandwidth: int,
|
| 54 |
-
in_channel: Optional[int],
|
| 55 |
-
hidden_activation: str = "Tanh",
|
| 56 |
-
hidden_activation_kwargs=None,
|
| 57 |
-
complex_mask: bool = True,
|
| 58 |
-
) -> None:
|
| 59 |
-
super().__init__(
|
| 60 |
-
emb_dim=emb_dim,
|
| 61 |
-
mlp_dim=mlp_dim,
|
| 62 |
-
bandwidth=bandwidth,
|
| 63 |
-
in_channel=in_channel,
|
| 64 |
-
hidden_activation=hidden_activation,
|
| 65 |
-
hidden_activation_kwargs=hidden_activation_kwargs,
|
| 66 |
-
complex_mask=complex_mask,
|
| 67 |
-
)
|
| 68 |
-
|
| 69 |
-
self.output = torch.jit.script(
|
| 70 |
-
nn.Sequential(
|
| 71 |
-
nn.Linear(
|
| 72 |
-
in_features=mlp_dim,
|
| 73 |
-
out_features=bandwidth * in_channel * self.reim * 2,
|
| 74 |
-
),
|
| 75 |
-
nn.GLU(dim=-1),
|
| 76 |
-
)
|
| 77 |
-
)
|
| 78 |
-
|
| 79 |
-
def reshape_output(self, mb):
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
mb =
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
self
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
band_specs
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
check_no_overlap(band_specs)
|
| 329 |
-
super().__init__(
|
| 330 |
-
in_channel=in_channel,
|
| 331 |
-
band_specs=band_specs,
|
| 332 |
-
freq_weights=None,
|
| 333 |
-
n_freq=None,
|
| 334 |
-
emb_dim=emb_dim,
|
| 335 |
-
mlp_dim=mlp_dim,
|
| 336 |
-
hidden_activation=hidden_activation,
|
| 337 |
-
hidden_activation_kwargs=hidden_activation_kwargs,
|
| 338 |
-
complex_mask=complex_mask,
|
| 339 |
-
)
|
| 340 |
-
|
| 341 |
-
def forward(self, q, cond=None):
|
| 342 |
-
# q = (batch, n_bands, n_time, emb_dim)
|
| 343 |
-
|
| 344 |
-
masks = self.compute_masks(
|
| 345 |
-
q
|
| 346 |
-
) # [n_bands * (batch, in_channel, bandwidth, n_time)]
|
| 347 |
-
|
| 348 |
-
# TODO: currently this requires band specs to have no gap and no overlap
|
| 349 |
-
masks = torch.concat(masks, dim=2) # (batch, in_channel, n_freq, n_time)
|
| 350 |
-
|
| 351 |
-
return masks
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from typing import Dict, List, Optional, Tuple, Type
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn.modules import activation
|
| 7 |
+
|
| 8 |
+
from .utils import (
|
| 9 |
+
band_widths_from_specs,
|
| 10 |
+
check_no_gap,
|
| 11 |
+
check_no_overlap,
|
| 12 |
+
check_nonzero_bandwidth,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BaseNormMLP(nn.Module):
    """Shared LayerNorm + hidden-MLP trunk for per-band mask estimators.

    Subclasses add output heads; this base only builds the normalization,
    the hidden projection, and bookkeeping attributes.
    """

    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channel: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ):
        super().__init__()

        self.hidden_activation_kwargs = hidden_activation_kwargs or {}

        self.norm = nn.LayerNorm(emb_dim)

        # Activation class looked up by name from torch.nn.modules.activation.
        act_cls = activation.__dict__[hidden_activation]
        self.hidden = torch.jit.script(
            nn.Sequential(
                nn.Linear(in_features=emb_dim, out_features=mlp_dim),
                act_cls(**self.hidden_activation_kwargs),
            )
        )

        self.bandwidth = bandwidth
        self.in_channel = in_channel

        self.complex_mask = complex_mask
        # Complex masks need two real values (re/im) per output coefficient.
        self.reim = 2 if complex_mask else 1
        # GLU halves its input, so output heads double their width.
        self.glu_mult = 2
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class NormMLP(BaseNormMLP):
    """Per-band mask head: LayerNorm -> MLP -> GLU output -> (complex) mask."""

    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channel: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__(
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            bandwidth=bandwidth,
            in_channel=in_channel,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

        # Width is doubled because the trailing GLU halves it again.
        self.output = torch.jit.script(
            nn.Sequential(
                nn.Linear(
                    in_features=mlp_dim,
                    out_features=bandwidth * in_channel * self.reim * 2,
                ),
                nn.GLU(dim=-1),
            )
        )

    def reshape_output(self, mb):
        """Reshape a flat head output into a per-band mask tensor.

        Args:
            mb: ``(batch, n_time, bandwidth * in_channel * reim)``.

        Returns:
            ``(batch, in_channel, bandwidth, n_time)``; complex-valued when
            ``complex_mask`` is set.
        """
        batch, n_time, _ = mb.shape
        if self.complex_mask:
            mb = mb.reshape(
                batch, n_time, self.in_channel, self.bandwidth, self.reim
            ).contiguous()
            # Last axis of size 2 holds (re, im) pairs.
            mb = torch.view_as_complex(mb)
        else:
            mb = mb.reshape(batch, n_time, self.in_channel, self.bandwidth)

        return torch.permute(mb, (0, 2, 3, 1))

    def forward(self, qb):
        """Compute the mask for one band from its embedding sequence ``qb``."""
        normed = self.norm(qb)
        hidden = self.hidden(normed)
        return self.reshape_output(self.output(hidden))
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class MultAddNormMLP(NormMLP):
    """NormMLP variant with two heads: a multiplicative and an additive mask.

    ``forward`` returns the pair ``(mult_mask, add_mask)``, both shaped like
    a regular ``NormMLP`` output.
    """

    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channel: Optional[int],  # normalized from the string annotation "int | None" for consistency with the rest of the file
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__(
            emb_dim,
            mlp_dim,
            bandwidth,
            in_channel,
            hidden_activation,
            hidden_activation_kwargs,
            complex_mask,
        )

        # Second head (additive mask); same shape/GLU layout as self.output.
        self.output2 = torch.jit.script(
            nn.Sequential(
                nn.Linear(
                    in_features=mlp_dim,
                    out_features=bandwidth * in_channel * self.reim * 2,
                ),
                nn.GLU(dim=-1),
            )
        )

    def forward(self, qb):
        """Return ``(multiplicative_mask, additive_mask)`` for one band."""
        qb = self.norm(qb)
        qb = self.hidden(qb)

        mmb = self.reshape_output(self.output(qb))
        amb = self.reshape_output(self.output2(qb))

        return mmb, amb
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class MaskEstimationModuleSuperBase(nn.Module):
    """Marker base class for all mask-estimation modules (no behavior)."""
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
    """Holds one norm-MLP head per band and computes per-band masks."""

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channel: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Dict = None,
    ) -> None:
        super().__init__()

        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)

        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}
        if norm_mlp_kwargs is None:
            norm_mlp_kwargs = {}

        # One independent head per band; bandwidth varies per band.
        self.norm_mlp = nn.ModuleList(
            [
                norm_mlp_cls(
                    bandwidth=self.band_widths[band_idx],
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    in_channel=in_channel,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                    **norm_mlp_kwargs,
                )
                for band_idx in range(self.n_bands)
            ]
        )

    def compute_masks(self, q):
        """Run each band head on its slice of ``q``.

        Args:
            q: embeddings of shape ``(batch, n_bands, n_time, emb_dim)``.

        Returns:
            List of per-band masks, one entry per band.
        """
        return [head(q[:, band_idx, :, :]) for band_idx, head in enumerate(self.norm_mlp)]
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
    """Mask estimator for band specs that may overlap.

    Per-band masks are summed onto a full-frequency canvas, optionally
    weighted by precomputed per-band frequency weights.
    """

    def __init__(
        self,
        in_channel: int,
        band_specs: List[Tuple[float, float]],
        freq_weights: List[torch.Tensor],
        n_freq: int,
        emb_dim: int,
        mlp_dim: int,
        cond_dim: int = 0,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Dict = None,
        use_freq_weights: bool = True,
    ) -> None:
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)

        # cond_dim widens the head input so conditioning can be concatenated.
        super().__init__(
            band_specs=band_specs,
            emb_dim=emb_dim + cond_dim,
            mlp_dim=mlp_dim,
            in_channel=in_channel,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            norm_mlp_cls=norm_mlp_cls,
            norm_mlp_kwargs=norm_mlp_kwargs,
        )

        self.n_freq = n_freq
        self.band_specs = band_specs
        self.in_channel = in_channel

        if freq_weights is not None:
            # Registered as buffers so they move with .to(device)/.cuda().
            for i, fw in enumerate(freq_weights):
                self.register_buffer(f"freq_weights/{i}", fw)
            self.use_freq_weights = use_freq_weights
        else:
            self.use_freq_weights = False

        self.cond_dim = cond_dim

    def forward(self, q, cond=None):
        """Estimate a full-frequency mask from band embeddings.

        Args:
            q: embeddings of shape ``(batch, n_bands, n_time, emb_dim)``.
            cond: optional conditioning, either ``(batch, cond_dim)`` or
                ``(batch, n_time, cond_dim)``.

        Returns:
            Mask of shape ``(batch, in_channel, n_freq, n_time)``.
        """
        batch, n_bands, n_time, emb_dim = q.shape

        if cond is not None:
            # BUGFIX: removed leftover debug `print(cond)` that ran on
            # every conditioned forward pass.
            if cond.ndim == 2:
                cond = cond[:, None, None, :].expand(-1, n_bands, n_time, -1)
            elif cond.ndim == 3:
                assert cond.shape[1] == n_time
            else:
                raise ValueError(f"Invalid cond shape: {cond.shape}")

            q = torch.cat([q, cond], dim=-1)
        elif self.cond_dim > 0:
            # No conditioning supplied but the heads expect cond features:
            # pad with ones to keep the input width consistent.
            cond = torch.ones(
                (batch, n_bands, n_time, self.cond_dim),
                device=q.device,
                dtype=q.dtype,
            )
            q = torch.cat([q, cond], dim=-1)

        mask_list = self.compute_masks(q)

        # Accumulate (possibly overlapping) band masks on a zero canvas.
        masks = torch.zeros(
            (batch, self.in_channel, self.n_freq, n_time),
            device=q.device,
            dtype=mask_list[0].dtype,
        )

        for im, mask in enumerate(mask_list):
            fstart, fend = self.band_specs[im]
            if self.use_freq_weights:
                fw = self.get_buffer(f"freq_weights/{im}")[:, None]
                mask = mask * fw
            masks[:, :, fstart:fend, :] += mask

        return masks
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class MaskEstimationModule(OverlappingMaskEstimationModule):
    """Mask estimator for strictly partitioned band specifications.

    Bands must be non-overlapping, gapless, and non-degenerate, so the
    per-band masks can simply be concatenated along the frequency axis
    instead of being weight-summed as in the overlapping parent class.
    """

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channel: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict = None,
        complex_mask: bool = True,
        **kwargs,
    ) -> None:
        # Frequency-axis concatenation is only valid when the bands tile
        # the spectrum exactly, so validate before construction.
        for validate in (check_nonzero_bandwidth, check_no_gap, check_no_overlap):
            validate(band_specs)

        super().__init__(
            in_channel=in_channel,
            band_specs=band_specs,
            freq_weights=None,
            n_freq=None,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

    def forward(self, q, cond=None):
        """Concatenate the per-band masks along the frequency axis.

        `cond` is accepted for interface compatibility but is unused here.
        """
        return torch.concat(self.compute_masks(q), dim=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mvsepless/models/bandit/core/model/bsrnn/tfmodel.py
CHANGED
|
@@ -1,320 +1,287 @@
|
|
| 1 |
-
import warnings
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch import nn
|
| 5 |
-
from torch.nn import functional as F
|
| 6 |
-
from torch.nn.modules import rnn
|
| 7 |
-
|
| 8 |
-
import torch.backends.cuda
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class TimeFrequencyModellingModule(nn.Module):
|
| 12 |
-
def __init__(self) -> None:
|
| 13 |
-
super().__init__()
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class ResidualRNN(nn.Module):
|
| 17 |
-
def __init__(
|
| 18 |
-
self,
|
| 19 |
-
emb_dim: int,
|
| 20 |
-
rnn_dim: int,
|
| 21 |
-
bidirectional: bool = True,
|
| 22 |
-
rnn_type: str = "LSTM",
|
| 23 |
-
use_batch_trick: bool = True,
|
| 24 |
-
use_layer_norm: bool = True,
|
| 25 |
-
) -> None:
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
self.use_batch_trick
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
z = self.norm(z)
|
| 60 |
-
|
| 61 |
-
z = torch.permute(
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
self.
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
z =
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
def __init__(
|
| 289 |
-
self,
|
| 290 |
-
n_modules: int = 12,
|
| 291 |
-
emb_dim: int = 128,
|
| 292 |
-
rnn_dim: int = 256,
|
| 293 |
-
bidirectional: bool = True,
|
| 294 |
-
dropout: float = 0.0,
|
| 295 |
-
) -> None:
|
| 296 |
-
super().__init__()
|
| 297 |
-
self.seqband = torch.jit.script(
|
| 298 |
-
nn.Sequential(
|
| 299 |
-
*[
|
| 300 |
-
ResidualConvolution(
|
| 301 |
-
emb_dim=emb_dim,
|
| 302 |
-
rnn_dim=rnn_dim,
|
| 303 |
-
bidirectional=bidirectional,
|
| 304 |
-
dropout=dropout,
|
| 305 |
-
)
|
| 306 |
-
for _ in range(2 * n_modules)
|
| 307 |
-
]
|
| 308 |
-
)
|
| 309 |
-
)
|
| 310 |
-
|
| 311 |
-
def forward(self, z):
|
| 312 |
-
# z = (batch, n_bands, n_time, emb_dim)
|
| 313 |
-
|
| 314 |
-
z = torch.permute(z, (0, 3, 1, 2)) # (batch, emb_dim, n_bands, n_time)
|
| 315 |
-
|
| 316 |
-
z = self.seqband(z) # (batch, emb_dim, n_bands, n_time)
|
| 317 |
-
|
| 318 |
-
z = torch.permute(z, (0, 2, 3, 1)) # (batch, n_bands, n_time, emb_dim)
|
| 319 |
-
|
| 320 |
-
return z
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
from torch.nn.modules import rnn
|
| 7 |
+
|
| 8 |
+
import torch.backends.cuda
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TimeFrequencyModellingModule(nn.Module):
    """Common base class for the time-frequency modelling stacks below."""

    def __init__(self) -> None:
        super().__init__()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ResidualRNN(nn.Module):
    """Residual RNN block applied along one axis of a 4-D tensor.

    Input and output shape: (batch, n_uncrossed, n_across, emb_dim).  The
    RNN runs along the `n_across` axis; the `n_uncrossed` rows are either
    folded into the batch dimension (the fast "batch trick") or looped
    over one by one.
    """

    def __init__(
        self,
        emb_dim: int,
        rnn_dim: int,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        use_batch_trick: bool = True,
        use_layer_norm: bool = True,
    ) -> None:
        super().__init__()

        self.use_layer_norm = use_layer_norm
        # LayerNorm acts on the trailing emb_dim axis; GroupNorm (one group
        # per channel) needs a channels-first layout, handled in forward().
        self.norm = (
            nn.LayerNorm(emb_dim)
            if use_layer_norm
            else nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim)
        )

        self.rnn = getattr(rnn, rnn_type)(
            input_size=emb_dim,
            hidden_size=rnn_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=bidirectional,
        )

        self.fc = nn.Linear(
            in_features=rnn_dim * (2 if bidirectional else 1),
            out_features=emb_dim,
        )

        self.use_batch_trick = use_batch_trick
        if not self.use_batch_trick:
            warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")

    def forward(self, z):
        residual = torch.clone(z)

        if self.use_layer_norm:
            z = self.norm(z)
        else:
            # GroupNorm expects (batch, channels, ...): move emb_dim first.
            z = torch.permute(z, (0, 3, 1, 2))
            z = self.norm(z)
            z = torch.permute(z, (0, 2, 3, 1))

        batch, n_uncrossed, n_across, emb_dim = z.shape

        if self.use_batch_trick:
            # Fold the uncrossed axis into the batch so one RNN call covers
            # every row at once.
            z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
            z = self.rnn(z.contiguous())[0]
            z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
        else:
            rows = [self.rnn(z[:, i, :, :])[0] for i in range(n_uncrossed)]
            z = torch.stack(rows, dim=1)

        return self.fc(z) + residual
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class SeqBandModellingModule(TimeFrequencyModellingModule):
    """Alternating time/band RNN stack (BSRNN-style sequence-band modelling).

    Sequential mode (default): 2 * n_modules ResidualRNN blocks are applied,
    transposing the band and time axes after each block so blocks alternate
    between modelling across time and across bands.  Parallel mode: each of
    n_modules stages runs a time RNN and a band RNN on the same input and
    sums their outputs.
    """

    def __init__(
        self,
        n_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        parallel_mode=False,
    ) -> None:
        super().__init__()

        def make_block():
            # One residual RNN block with the shared hyperparameters.
            return ResidualRNN(
                emb_dim=emb_dim,
                rnn_dim=rnn_dim,
                bidirectional=bidirectional,
                rnn_type=rnn_type,
            )

        if parallel_mode:
            self.seqband = nn.ModuleList(
                [
                    nn.ModuleList([make_block(), make_block()])
                    for _ in range(n_modules)
                ]
            )
        else:
            self.seqband = nn.ModuleList(
                [make_block() for _ in range(2 * n_modules)]
            )

        self.parallel_mode = parallel_mode

    def forward(self, z):
        # z: (batch, n_bands, n_time, emb_dim)
        if self.parallel_mode:
            for time_rnn, band_rnn in self.seqband:
                across_time = time_rnn(z)
                across_bands = band_rnn(z.transpose(1, 2))
                z = across_time + across_bands.transpose(1, 2)
        else:
            for block in self.seqband:
                z = block(z).transpose(1, 2)
        return z
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class ResidualTransformer(nn.Module):
    """Single transformer-encoder layer applied along the `n_across` axis.

    Input/output shape: (batch, n_uncrossed, n_across, emb_dim); the
    uncrossed rows are folded into the batch dimension for the attention
    call.  The encoder layer adds its own residual connections internally.
    """

    def __init__(
        self,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.tf = nn.TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=4,
            dim_feedforward=rnn_dim,
            # Fix: `dropout` was accepted and stored but never passed to the
            # encoder layer, which silently kept its 0.1 default.
            dropout=dropout,
            batch_first=True,
        )

        self.is_causal = not bidirectional
        self.dropout = dropout

    def forward(self, z):
        batch, n_uncrossed, n_across, emb_dim = z.shape
        z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
        z = self.tf(z, is_causal=self.is_causal)
        z = torch.reshape(z, (batch, n_uncrossed, n_across, emb_dim))

        return z
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class TransformerTimeFreqModule(TimeFrequencyModellingModule):
    """Stack of 2 * n_modules ResidualTransformer blocks.

    The band and time axes are transposed after every block so the stack
    alternates between modelling across time and across bands.
    """

    def __init__(
        self,
        n_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(emb_dim)
        self.seqband = nn.ModuleList(
            ResidualTransformer(
                emb_dim=emb_dim,
                rnn_dim=rnn_dim,
                bidirectional=bidirectional,
                dropout=dropout,
            )
            for _ in range(2 * n_modules)
        )

    def forward(self, z):
        # z: (batch, n_bands, n_time, emb_dim)
        z = self.norm(z)
        for block in self.seqband:
            z = block(z).transpose(1, 2)
        return z
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class ResidualConvolution(nn.Module):
    """Residual 2-D convolutional block.

    InstanceNorm -> 3x3 conv + Tanhshrink -> 1x1 conv, added back to the
    input.  Operates on (batch, emb_dim, n_bands, n_time) tensors.  The
    `bidirectional` and `dropout` arguments exist for interface parity with
    the RNN/transformer blocks; they only set attributes and do not affect
    the convolution itself.
    """

    def __init__(
        self,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.norm = nn.InstanceNorm2d(emb_dim, affine=True)

        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=emb_dim,
                out_channels=rnn_dim,
                kernel_size=(3, 3),
                padding="same",
                stride=(1, 1),
            ),
            nn.Tanhshrink(),
        )

        self.is_causal = not bidirectional
        self.dropout = dropout

        self.fc = nn.Conv2d(
            in_channels=rnn_dim,
            out_channels=emb_dim,
            kernel_size=(1, 1),
            padding="same",
            stride=(1, 1),
        )

    def forward(self, z):
        residual = torch.clone(z)

        out = self.norm(z)
        out = self.conv(out)
        out = self.fc(out)

        return out + residual
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class ConvolutionalTimeFreqModule(TimeFrequencyModellingModule):
    """Stack of 2 * n_modules ResidualConvolution blocks, TorchScript-compiled.

    Convolutions see both axes at once, so no transposes between blocks are
    needed; only a channels-first permutation around the stack.
    """

    def __init__(
        self,
        n_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        blocks = [
            ResidualConvolution(
                emb_dim=emb_dim,
                rnn_dim=rnn_dim,
                bidirectional=bidirectional,
                dropout=dropout,
            )
            for _ in range(2 * n_modules)
        ]
        # Scripted once at construction time.
        self.seqband = torch.jit.script(nn.Sequential(*blocks))

    def forward(self, z):
        # (batch, n_bands, n_time, emb_dim) -> channels-first for Conv2d.
        z = torch.permute(z, (0, 3, 1, 2))

        z = self.seqband(z)

        # Back to (batch, n_bands, n_time, emb_dim).
        return torch.permute(z, (0, 2, 3, 1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mvsepless/models/bandit/core/model/bsrnn/utils.py
CHANGED
|
@@ -1,525 +1,518 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from abc import abstractmethod
|
| 3 |
-
from typing import Any, Callable
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import torch
|
| 7 |
-
from librosa import hz_to_midi, midi_to_hz
|
| 8 |
-
from torch import Tensor
|
| 9 |
-
from torchaudio import functional as taF
|
| 10 |
-
from spafe.fbanks import bark_fbanks
|
| 11 |
-
from spafe.utils.converters import erb2hz, hz2bark, hz2erb
|
| 12 |
-
from torchaudio.functional.functional import _create_triangular_filterbank
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def band_widths_from_specs(band_specs):
|
| 16 |
-
return [e - i for i, e in band_specs]
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def check_nonzero_bandwidth(band_specs):
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
fstart
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
self.
|
| 47 |
-
self.
|
| 48 |
-
self.
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
self.
|
| 52 |
-
self.
|
| 53 |
-
self.
|
| 54 |
-
self.
|
| 55 |
-
self.
|
| 56 |
-
self.
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
self.
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
upper =
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
)
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
)
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
self.
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
"
|
| 292 |
-
"
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
)
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
fb =
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
for i, (f_min, f_max) in enumerate(mbs):
|
| 520 |
-
band_defs.append(
|
| 521 |
-
{"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
|
| 522 |
-
)
|
| 523 |
-
|
| 524 |
-
df = pd.DataFrame(band_defs)
|
| 525 |
-
df.to_csv("vox7bands.csv", index=False)
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from abc import abstractmethod
|
| 3 |
+
from typing import Any, Callable
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from librosa import hz_to_midi, midi_to_hz
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from torchaudio import functional as taF
|
| 10 |
+
from spafe.fbanks import bark_fbanks
|
| 11 |
+
from spafe.utils.converters import erb2hz, hz2bark, hz2erb
|
| 12 |
+
from torchaudio.functional.functional import _create_triangular_filterbank
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def band_widths_from_specs(band_specs):
    """Return the width (end - start) of each (start, end) band spec."""
    return [fend - fstart for fstart, fend in band_specs]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def check_nonzero_bandwidth(band_specs):
    """Raise ValueError if any band has zero or negative width."""
    if any(fend - fstart <= 0 for fstart, fend in band_specs):
        raise ValueError("Bands cannot be zero-width")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def check_no_overlap(band_specs):
    """Raise ValueError if any band starts before the previous band ends.

    Bands are half-open index intervals (fstart, fend): adjacent bands that
    share a boundary (fstart == previous fend) are valid, consistent with
    check_no_gap.

    Fix: the original never advanced `fend_prev`, so every band was compared
    against the initial -1 and overlaps were never detected.  A strict '<'
    is used so gapless adjacent bands still pass.
    """
    fend_prev = -1
    for fstart_curr, fend_curr in band_specs:
        if fstart_curr < fend_prev:
            raise ValueError("Bands cannot overlap")
        fend_prev = fend_curr
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def check_no_gap(band_specs):
    """Raise ValueError if consecutive bands leave uncovered frequency bins.

    The first band must start at bin 0 (asserted, matching the original
    contract).  A gap exists when a band starts more than one bin after the
    previous band's end.
    """
    first_start, _ = band_specs[0]
    assert first_start == 0

    prev_end = -1
    for start, end in band_specs:
        if start - prev_end > 1:
            raise ValueError("Bands cannot leave gap")
        prev_end = end
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class BandsplitSpecification:
    """Base class mapping frequencies (Hz) to one-sided STFT bin indices
    and carving the spectrum into index bands.

    Subclasses implement get_band_specs() to return a list of
    (start_bin, end_bin) half-open intervals.
    """

    def __init__(self, nfft: int, fs: int) -> None:
        self.fs = fs
        self.nfft = nfft
        self.nyquist = fs / 2
        # Number of bins in a one-sided spectrum.
        self.max_index = nfft // 2 + 1

        # Pre-computed bin indices at the split frequencies the subclasses use.
        self.split500 = self.hertz_to_index(500)
        self.split1k = self.hertz_to_index(1000)
        self.split2k = self.hertz_to_index(2000)
        self.split4k = self.hertz_to_index(4000)
        self.split8k = self.hertz_to_index(8000)
        self.split16k = self.hertz_to_index(16000)
        self.split20k = self.hertz_to_index(20000)

        self.above20k = [(self.split20k, self.max_index)]
        self.above16k = [(self.split16k, self.split20k)] + self.above20k

    def index_to_hertz(self, index: int):
        """Convert an STFT bin index back to a frequency in Hz."""
        return index * self.fs / self.nfft

    def hertz_to_index(self, hz: float, round: bool = True):
        """Convert a frequency in Hz to a (by default rounded) bin index."""
        index = hz * self.nfft / self.fs
        return int(np.round(index)) if round else index

    def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
        """Tile [start_index, end_index) with bands of ~bandwidth_hz each.

        The final band is clipped to end_index, so it may be narrower.
        """
        # Bandwidth in bins is constant, so compute it once.
        step = self.hertz_to_index(bandwidth_hz)

        band_specs = []
        lower = start_index
        while lower < end_index:
            upper = min(int(np.floor(lower + step)), end_index)
            band_specs.append((lower, upper))
            lower = upper

        return band_specs

    @abstractmethod
    def get_band_specs(self):
        raise NotImplementedError
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class VocalBandsplitSpecification(BandsplitSpecification):
    """Hand-designed band layouts for vocal separation, versions "1"-"7".

    Higher versions use progressively finer bands at low frequencies.
    """

    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

        self.version = version

    def get_band_specs(self):
        # Dispatch to versionN(); every version must be a plain method so
        # the trailing call is valid.
        return getattr(self, f"version{self.version}")()

    def version1(self):
        # Fix: this was decorated @property, so getattr(self, "version1")()
        # evaluated the property to a list and then crashed trying to call
        # it — version "1" was unusable.  It is now a regular method like
        # the other versions.
        return self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.max_index, bandwidth_hz=1000
        )

    def version2(self):
        below16k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )

        return below16k + below20k + self.above20k

    def version3(self):
        below8k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )

        return below8k + below16k + self.above16k

    def version4(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )

        return below1k + below8k + below16k + self.above16k

    def version5(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )
        return below1k + below16k + below20k + self.above20k

    def version6(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + self.above16k

    def version7(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + below20k + self.above20k
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class OtherBandsplitSpecification(VocalBandsplitSpecification):
    """Band layout for the "other" stem: reuses the vocal version-7 layout."""

    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs, version="7")
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class BassBandsplitSpecification(BandsplitSpecification):
    """Band layout tuned for bass: fine 50 Hz bands below 500 Hz, coarser above.

    NOTE(review): the `version` argument is accepted but ignored, matching
    the original interface.
    """

    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        # (start_bin, end_bin, bandwidth_hz) segments from low to high.
        segments = [
            (0, self.split500, 50),
            (self.split500, self.split1k, 100),
            (self.split1k, self.split4k, 500),
            (self.split4k, self.split8k, 1000),
            (self.split8k, self.split16k, 2000),
        ]

        bands = []
        for start, end, bandwidth in segments:
            bands += self.get_band_specs_with_bandwidth(
                start_index=start, end_index=end, bandwidth_hz=bandwidth
            )

        # Everything above 16 kHz is a single band.
        bands.append((self.split16k, self.max_index))
        return bands
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class DrumBandsplitSpecification(BandsplitSpecification):
    """Band layout tuned for drums: very fine bands up to 1 kHz, coarser above."""

    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        # (start_bin, end_bin, bandwidth_hz) segments from low to high.
        segments = [
            (0, self.split1k, 50),
            (self.split1k, self.split2k, 100),
            (self.split2k, self.split4k, 250),
            (self.split4k, self.split8k, 500),
            (self.split8k, self.split16k, 1000),
        ]

        bands = []
        for start, end, bandwidth in segments:
            bands += self.get_band_specs_with_bandwidth(
                start_index=start, end_index=end, bandwidth_hz=bandwidth
            )

        # Everything above 16 kHz is a single band.
        bands.append((self.split16k, self.max_index))
        return bands
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class PerceptualBandsplitSpecification(BandsplitSpecification):
    """Band-split specification derived from a perceptual filterbank.

    A filterbank constructor (``fbank_fn``) produces an ``(n_bands, n_freqs)``
    weight matrix; each band's spec is the contiguous span of its nonzero
    bins, and per-bin weights are normalized so every frequency bin's weights
    sum to one across bands.
    """

    def __init__(
        self,
        nfft: int,
        fs: int,
        fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
        n_bands: int,
        f_min: float = 0.0,
        f_max: float = None,
    ) -> None:
        super().__init__(nfft=nfft, fs=fs)
        # Requested number of perceptual bands; bands with no active bins
        # are silently dropped below, so fewer may survive.
        self.n_bands = n_bands
        if f_max is None:
            f_max = fs / 2

        # (n_bands, max_index) filterbank matrix from the supplied constructor.
        self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)

        # Normalize so the weights of each frequency bin sum to 1 across bands.
        weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True)
        normalized_mel_fb = self.filterbank / weight_per_bin

        freq_weights = []
        band_specs = []
        for i in range(self.n_bands):
            active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
            # A single active bin squeezes to a bare int; wrap it so the
            # first/last indexing below works uniformly.
            if isinstance(active_bins, int):
                active_bins = (active_bins, active_bins)
            if len(active_bins) == 0:
                continue
            # Span from first to last nonzero bin (end exclusive). NOTE(review):
            # any interior zero bins inside the span are still included.
            start_index = active_bins[0]
            end_index = active_bins[-1] + 1
            band_specs.append((start_index, end_index))
            freq_weights.append(normalized_mel_fb[i, start_index:end_index])

        self.freq_weights = freq_weights
        self.band_specs = band_specs

    def get_band_specs(self):
        """Return the list of (start_index, end_index) bin spans per band."""
        return self.band_specs

    def get_freq_weights(self):
        """Return the normalized per-bin weight tensor for each band."""
        return self.freq_weights

    def save_to_file(self, dir_path: str) -> None:
        """Pickle the band specs, weights and raw filterbank into *dir_path*."""

        os.makedirs(dir_path, exist_ok=True)

        import pickle

        # File name is historical ("mel") even for non-mel filterbanks.
        with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
            pickle.dump(
                {
                    "band_specs": self.band_specs,
                    "freq_weights": self.freq_weights,
                    "filterbank": self.filterbank,
                },
                f,
            )
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    """Mel-scale filterbank of shape (n_bands, n_freqs)."""
    filterbank = taF.melscale_fbanks(
        n_mels=n_bands,
        sample_rate=fs,
        f_min=f_min,
        f_max=f_max,
        n_freqs=n_freqs,
    ).T

    # The DC bin falls outside every triangular mel filter; assign it to the
    # first band so that every frequency bin belongs to at least one band.
    filterbank[0, 0] = 1.0

    return filterbank
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
class MelBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split built on a mel-scale filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        # Only the filterbank constructor differs between perceptual variants.
        super().__init__(
            nfft=nfft, fs=fs, fbank_fn=mel_filterbank,
            n_bands=n_bands, f_min=f_min, f_max=f_max,
        )
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
    """Rectangular log-spaced ("musical") filterbank of shape (n_bands, n_freqs).

    Band centers are spaced uniformly on the MIDI (log-frequency) scale and
    each band is widened by one band's worth of octaves on either side, so
    neighboring bands overlap. ``scale`` is accepted for interface
    compatibility but unused. The caller's ``f_min`` is always overridden
    with the first non-DC bin frequency (``fs / nfft``).
    """
    nfft = 2 * (n_freqs - 1)
    df = fs / nfft
    f_max = f_max or fs / 2
    # Lowest usable center frequency is the first non-DC bin. (A dead
    # `f_min = f_min or 0` store that was immediately overwritten here has
    # been removed; behavior is unchanged.)
    f_min = fs / nfft

    n_octaves = np.log2(f_max / f_min)
    n_octaves_per_band = n_octaves / n_bands
    bandwidth_mult = np.power(2.0, n_octaves_per_band)

    low_midi = max(0, hz_to_midi(f_min))
    high_midi = hz_to_midi(f_max)
    midi_points = np.linspace(low_midi, high_midi, n_bands)
    hz_pts = midi_to_hz(midi_points)

    # Widen each center by one band's octave span on both sides.
    low_pts = hz_pts / bandwidth_mult
    high_pts = hz_pts * bandwidth_mult

    low_bins = np.floor(low_pts / df).astype(int)
    high_bins = np.ceil(high_pts / df).astype(int)

    fb = np.zeros((n_bands, n_freqs))

    for i in range(n_bands):
        fb[i, low_bins[i] : high_bins[i] + 1] = 1.0

    # Extend the edge bands so the full spectrum is covered.
    fb[0, : low_bins[0]] = 1.0
    fb[-1, high_bins[-1] + 1 :] = 1.0

    return torch.as_tensor(fb)
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split built on a log-spaced musical filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        # Delegate to the perceptual base with the musical filterbank builder.
        super().__init__(
            nfft=nfft, fs=fs, fbank_fn=musical_filterbank,
            n_bands=n_bands, f_min=f_min, f_max=f_max,
        )
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    """Bark-scale filterbank of shape (n_bands, n_freqs), as a torch tensor."""
    # bark_filter_banks takes the FFT size, not the number of frequency bins.
    banks, _ = bark_fbanks.bark_filter_banks(
        nfilts=n_bands,
        nfft=2 * (n_freqs - 1),
        fs=fs,
        low_freq=f_min,
        high_freq=f_max,
        scale="constant",
    )

    return torch.as_tensor(banks)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split built on a Bark-scale filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        # Delegate to the perceptual base with the Bark filterbank builder.
        super().__init__(
            nfft=nfft, fs=fs, fbank_fn=bark_filterbank,
            n_bands=n_bands, f_min=f_min, f_max=f_max,
        )
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def triangular_bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    """Triangular Bark-scale filterbank of shape (n_bands, n_freqs)."""
    bin_freqs = torch.linspace(0, fs // 2, n_freqs)

    # Place n_bands + 2 edge points uniformly on the Bark scale and map them
    # back to Hz via the 600 * sinh(bark / 6) inverse.
    edges_bark = torch.linspace(hz2bark(f_min), hz2bark(f_max), n_bands + 2)
    edges_hz = 600 * torch.sinh(edges_bark / 6)

    filterbank = _create_triangular_filterbank(bin_freqs, edges_hz).T

    # Extend the lowest populated band down to DC so that the bins below its
    # left edge are not orphaned.
    first_band = torch.nonzero(torch.sum(filterbank, dim=-1))[0, 0]
    first_bin = torch.nonzero(filterbank[first_band, :])[0, 0]
    filterbank[first_band, :first_bin] = 1.0

    return filterbank
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split built on a triangular Bark filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        # Delegate to the perceptual base with the triangular Bark builder.
        super().__init__(
            nfft=nfft, fs=fs, fbank_fn=triangular_bark_filterbank,
            n_bands=n_bands, f_min=f_min, f_max=f_max,
        )
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def minibark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    """Bark filterbank with weak weights zeroed out, shape (n_bands, n_freqs)."""
    filterbank = bark_filterbank(n_bands, fs, f_min, f_max, n_freqs)

    # Drop bins whose weight is below sqrt(0.5) (-3 dB) so that the bands
    # barely overlap.
    filterbank[filterbank < np.sqrt(0.5)] = 0.0

    return filterbank
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split built on a thresholded ("mini") Bark filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        # Delegate to the perceptual base with the mini-Bark builder.
        super().__init__(
            nfft=nfft, fs=fs, fbank_fn=minibark_filterbank,
            n_bands=n_bands, f_min=f_min, f_max=f_max,
        )
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def erb_filterbank(
    n_bands: int,
    fs: int,
    f_min: float,
    f_max: float,
    n_freqs: int,
) -> Tensor:
    """Triangular ERB-scale filterbank of shape (n_bands, n_freqs)."""
    # Constant from the Glasberg & Moore ERB-number formula.
    A = (1000 * np.log(10)) / (24.7 * 4.37)
    bin_freqs = torch.linspace(0, fs // 2, n_freqs)

    # n_bands + 2 edge points spaced uniformly in ERB numbers, mapped to Hz
    # with the inverse ERB-number transform.
    erb_edges = torch.linspace(hz2erb(f_min), hz2erb(f_max), n_bands + 2)
    hz_edges = (torch.pow(10, (erb_edges / A)) - 1) / 0.00437

    filterbank = _create_triangular_filterbank(bin_freqs, hz_edges).T

    # Extend the lowest populated band down to DC so no bin is orphaned.
    first_band = torch.nonzero(torch.sum(filterbank, dim=-1))[0, 0]
    first_bin = torch.nonzero(filterbank[first_band, :])[0, 0]
    filterbank[first_band, :first_bin] = 1.0

    return filterbank
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split built on an ERB-scale filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        # Delegate to the perceptual base with the ERB filterbank builder.
        super().__init__(
            nfft=nfft, fs=fs, fbank_fn=erb_filterbank,
            n_bands=n_bands, f_min=f_min, f_max=f_max,
        )
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
if __name__ == "__main__":
    import pandas as pd

    # Dump the vocal band layout to CSV for inspection.
    rows = []
    for spec_cls in [VocalBandsplitSpecification]:
        name = spec_cls.__name__.replace("BandsplitSpecification", "")
        specs = spec_cls(nfft=2048, fs=44100).get_band_specs()
        for idx, (lo, hi) in enumerate(specs):
            rows.append({"band": name, "band_index": idx, "f_min": lo, "f_max": hi})

    pd.DataFrame(rows).to_csv("vox7bands.csv", index=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mvsepless/models/bandit/core/model/bsrnn/wrapper.py
CHANGED
|
@@ -1,829 +1,828 @@
|
|
| 1 |
-
from pprint import pprint
|
| 2 |
-
from typing import Dict, List, Optional, Tuple, Union
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
from torch import nn
|
| 6 |
-
|
| 7 |
-
from .._spectral import _SpectralComponent
|
| 8 |
-
from .utils import (
|
| 9 |
-
BarkBandsplitSpecification,
|
| 10 |
-
BassBandsplitSpecification,
|
| 11 |
-
DrumBandsplitSpecification,
|
| 12 |
-
EquivalentRectangularBandsplitSpecification,
|
| 13 |
-
MelBandsplitSpecification,
|
| 14 |
-
MusicalBandsplitSpecification,
|
| 15 |
-
OtherBandsplitSpecification,
|
| 16 |
-
TriangularBarkBandsplitSpecification,
|
| 17 |
-
VocalBandsplitSpecification,
|
| 18 |
-
)
|
| 19 |
-
from .core import (
|
| 20 |
-
MultiSourceMultiMaskBandSplitCoreConv,
|
| 21 |
-
MultiSourceMultiMaskBandSplitCoreRNN,
|
| 22 |
-
MultiSourceMultiMaskBandSplitCoreTransformer,
|
| 23 |
-
MultiSourceMultiPatchingMaskBandSplitCoreRNN,
|
| 24 |
-
SingleMaskBandsplitCoreRNN,
|
| 25 |
-
SingleMaskBandsplitCoreTransformer,
|
| 26 |
-
)
|
| 27 |
-
|
| 28 |
-
import pytorch_lightning as pl
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def get_band_specs(band_specs, n_fft, fs, n_bands=None):
|
| 32 |
-
if band_specs in ["dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"]:
|
| 33 |
-
bsm = VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs()
|
| 34 |
-
freq_weights = None
|
| 35 |
-
overlapping_band = False
|
| 36 |
-
elif "tribark" in band_specs:
|
| 37 |
-
assert n_bands is not None
|
| 38 |
-
specs = TriangularBarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
|
| 39 |
-
bsm = specs.get_band_specs()
|
| 40 |
-
freq_weights = specs.get_freq_weights()
|
| 41 |
-
overlapping_band = True
|
| 42 |
-
elif "bark" in band_specs:
|
| 43 |
-
assert n_bands is not None
|
| 44 |
-
specs = BarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
|
| 45 |
-
bsm = specs.get_band_specs()
|
| 46 |
-
freq_weights = specs.get_freq_weights()
|
| 47 |
-
overlapping_band = True
|
| 48 |
-
elif "erb" in band_specs:
|
| 49 |
-
assert n_bands is not None
|
| 50 |
-
specs = EquivalentRectangularBandsplitSpecification(
|
| 51 |
-
nfft=n_fft, fs=fs, n_bands=n_bands
|
| 52 |
-
)
|
| 53 |
-
bsm = specs.get_band_specs()
|
| 54 |
-
freq_weights = specs.get_freq_weights()
|
| 55 |
-
overlapping_band = True
|
| 56 |
-
elif "musical" in band_specs:
|
| 57 |
-
assert n_bands is not None
|
| 58 |
-
specs = MusicalBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
|
| 59 |
-
bsm = specs.get_band_specs()
|
| 60 |
-
freq_weights = specs.get_freq_weights()
|
| 61 |
-
overlapping_band = True
|
| 62 |
-
elif band_specs == "dnr:mel" or "mel" in band_specs:
|
| 63 |
-
assert n_bands is not None
|
| 64 |
-
specs = MelBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
|
| 65 |
-
bsm = specs.get_band_specs()
|
| 66 |
-
freq_weights = specs.get_freq_weights()
|
| 67 |
-
overlapping_band = True
|
| 68 |
-
else:
|
| 69 |
-
raise NameError
|
| 70 |
-
|
| 71 |
-
return bsm, freq_weights, overlapping_band
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None):
|
| 75 |
-
if band_specs_map == "musdb:all":
|
| 76 |
-
bsm = {
|
| 77 |
-
"vocals": VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
|
| 78 |
-
"drums": DrumBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
|
| 79 |
-
"bass": BassBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
|
| 80 |
-
"other": OtherBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
|
| 81 |
-
}
|
| 82 |
-
freq_weights = None
|
| 83 |
-
overlapping_band = False
|
| 84 |
-
elif band_specs_map == "dnr:vox7":
|
| 85 |
-
bsm_, freq_weights, overlapping_band = get_band_specs(
|
| 86 |
-
"dnr:speech", n_fft, fs, n_bands
|
| 87 |
-
)
|
| 88 |
-
bsm = {"speech": bsm_, "music": bsm_, "effects": bsm_}
|
| 89 |
-
elif "dnr:vox7:" in band_specs_map:
|
| 90 |
-
stem = band_specs_map.split(":")[-1]
|
| 91 |
-
bsm_, freq_weights, overlapping_band = get_band_specs(
|
| 92 |
-
"dnr:speech", n_fft, fs, n_bands
|
| 93 |
-
)
|
| 94 |
-
bsm = {stem: bsm_}
|
| 95 |
-
else:
|
| 96 |
-
raise NameError
|
| 97 |
-
|
| 98 |
-
return bsm, freq_weights, overlapping_band
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
class BandSplitWrapperBase(pl.LightningModule):
|
| 102 |
-
bsrnn: nn.Module
|
| 103 |
-
|
| 104 |
-
def __init__(self, **kwargs):
|
| 105 |
-
super().__init__()
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
class SingleMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
|
| 109 |
-
def __init__(
|
| 110 |
-
self,
|
| 111 |
-
band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
|
| 112 |
-
fs: int = 44100,
|
| 113 |
-
n_fft: int = 2048,
|
| 114 |
-
win_length: Optional[int] = 2048,
|
| 115 |
-
hop_length: int = 512,
|
| 116 |
-
window_fn: str = "hann_window",
|
| 117 |
-
wkwargs: Optional[Dict] = None,
|
| 118 |
-
power: Optional[int] = None,
|
| 119 |
-
center: bool = True,
|
| 120 |
-
normalized: bool = True,
|
| 121 |
-
pad_mode: str = "constant",
|
| 122 |
-
onesided: bool = True,
|
| 123 |
-
n_bands: int = None,
|
| 124 |
-
) -> None:
|
| 125 |
-
super().__init__(
|
| 126 |
-
n_fft=n_fft,
|
| 127 |
-
win_length=win_length,
|
| 128 |
-
hop_length=hop_length,
|
| 129 |
-
window_fn=window_fn,
|
| 130 |
-
wkwargs=wkwargs,
|
| 131 |
-
power=power,
|
| 132 |
-
center=center,
|
| 133 |
-
normalized=normalized,
|
| 134 |
-
pad_mode=pad_mode,
|
| 135 |
-
onesided=onesided,
|
| 136 |
-
)
|
| 137 |
-
|
| 138 |
-
if isinstance(band_specs_map, str):
|
| 139 |
-
self.band_specs_map, self.freq_weights, self.overlapping_band = (
|
| 140 |
-
get_band_specs_map(band_specs_map, n_fft, fs, n_bands=n_bands)
|
| 141 |
-
)
|
| 142 |
-
|
| 143 |
-
self.stems = list(self.band_specs_map.keys())
|
| 144 |
-
|
| 145 |
-
def forward(self, batch):
|
| 146 |
-
audio = batch["audio"]
|
| 147 |
-
|
| 148 |
-
with torch.no_grad():
|
| 149 |
-
batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}
|
| 150 |
-
|
| 151 |
-
X = batch["spectrogram"]["mixture"]
|
| 152 |
-
length = batch["audio"]["mixture"].shape[-1]
|
| 153 |
-
|
| 154 |
-
output = {"spectrogram": {}, "audio": {}}
|
| 155 |
-
|
| 156 |
-
for stem, bsrnn in self.bsrnn.items():
|
| 157 |
-
S = bsrnn(X)
|
| 158 |
-
s = self.istft(S, length)
|
| 159 |
-
output["spectrogram"][stem] = S
|
| 160 |
-
output["audio"][stem] = s
|
| 161 |
-
|
| 162 |
-
return batch, output
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
class MultiMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
|
| 166 |
-
def __init__(
|
| 167 |
-
self,
|
| 168 |
-
stems: List[str],
|
| 169 |
-
band_specs: Union[str, List[Tuple[float, float]]],
|
| 170 |
-
fs: int = 44100,
|
| 171 |
-
n_fft: int = 2048,
|
| 172 |
-
win_length: Optional[int] = 2048,
|
| 173 |
-
hop_length: int = 512,
|
| 174 |
-
window_fn: str = "hann_window",
|
| 175 |
-
wkwargs: Optional[Dict] = None,
|
| 176 |
-
power: Optional[int] = None,
|
| 177 |
-
center: bool = True,
|
| 178 |
-
normalized: bool = True,
|
| 179 |
-
pad_mode: str = "constant",
|
| 180 |
-
onesided: bool = True,
|
| 181 |
-
n_bands: int = None,
|
| 182 |
-
) -> None:
|
| 183 |
-
super().__init__(
|
| 184 |
-
n_fft=n_fft,
|
| 185 |
-
win_length=win_length,
|
| 186 |
-
hop_length=hop_length,
|
| 187 |
-
window_fn=window_fn,
|
| 188 |
-
wkwargs=wkwargs,
|
| 189 |
-
power=power,
|
| 190 |
-
center=center,
|
| 191 |
-
normalized=normalized,
|
| 192 |
-
pad_mode=pad_mode,
|
| 193 |
-
onesided=onesided,
|
| 194 |
-
)
|
| 195 |
-
|
| 196 |
-
if isinstance(band_specs, str):
|
| 197 |
-
self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
|
| 198 |
-
band_specs, n_fft, fs, n_bands
|
| 199 |
-
)
|
| 200 |
-
|
| 201 |
-
self.stems = stems
|
| 202 |
-
|
| 203 |
-
def forward(self, batch):
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
output =
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
)
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
)
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
self.
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
self.
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
|
| 783 |
-
|
| 784 |
-
|
| 785 |
-
|
| 786 |
-
|
| 787 |
-
|
| 788 |
-
|
| 789 |
-
|
| 790 |
-
|
| 791 |
-
|
| 792 |
-
|
| 793 |
-
|
| 794 |
-
|
| 795 |
-
|
| 796 |
-
|
| 797 |
-
|
| 798 |
-
|
| 799 |
-
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
|
| 806 |
-
|
| 807 |
-
|
| 808 |
-
|
| 809 |
-
|
| 810 |
-
|
| 811 |
-
|
| 812 |
-
|
| 813 |
-
|
| 814 |
-
|
| 815 |
-
|
| 816 |
-
|
| 817 |
-
|
| 818 |
-
|
| 819 |
-
|
| 820 |
-
|
| 821 |
-
|
| 822 |
-
|
| 823 |
-
|
| 824 |
-
|
| 825 |
-
|
| 826 |
-
|
| 827 |
-
|
| 828 |
-
|
| 829 |
-
)
|
|
|
|
| 1 |
+
from pprint import pprint
|
| 2 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
from .._spectral import _SpectralComponent
|
| 8 |
+
from .utils import (
|
| 9 |
+
BarkBandsplitSpecification,
|
| 10 |
+
BassBandsplitSpecification,
|
| 11 |
+
DrumBandsplitSpecification,
|
| 12 |
+
EquivalentRectangularBandsplitSpecification,
|
| 13 |
+
MelBandsplitSpecification,
|
| 14 |
+
MusicalBandsplitSpecification,
|
| 15 |
+
OtherBandsplitSpecification,
|
| 16 |
+
TriangularBarkBandsplitSpecification,
|
| 17 |
+
VocalBandsplitSpecification,
|
| 18 |
+
)
|
| 19 |
+
from .core import (
|
| 20 |
+
MultiSourceMultiMaskBandSplitCoreConv,
|
| 21 |
+
MultiSourceMultiMaskBandSplitCoreRNN,
|
| 22 |
+
MultiSourceMultiMaskBandSplitCoreTransformer,
|
| 23 |
+
MultiSourceMultiPatchingMaskBandSplitCoreRNN,
|
| 24 |
+
SingleMaskBandsplitCoreRNN,
|
| 25 |
+
SingleMaskBandsplitCoreTransformer,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
import pytorch_lightning as pl
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_band_specs(band_specs, n_fft, fs, n_bands=None):
    """Resolve a band-spec name into (band_specs, freq_weights, overlapping_band).

    Fixed vocal layouts (``dnr:speech``, ``dnr:vox7``, ``musdb:vocals``,
    ``musdb:vox7``) return no per-bin weights and non-overlapping bands.
    All other names are dispatched by substring to a perceptual
    specification, which requires ``n_bands``.

    Raises:
        NameError: if ``band_specs`` matches no known layout.
    """
    if band_specs in ("dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"):
        bsm = VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs()
        return bsm, None, False

    # Substring-dispatched perceptual filterbank specs. Order matters:
    # "tribark" contains "bark", and "mel" also matches "dnr:mel" (the old
    # explicit `band_specs == "dnr:mel"` test was redundant).
    perceptual = (
        ("tribark", TriangularBarkBandsplitSpecification),
        ("bark", BarkBandsplitSpecification),
        ("erb", EquivalentRectangularBandsplitSpecification),
        ("musical", MusicalBandsplitSpecification),
        ("mel", MelBandsplitSpecification),
    )
    for key, spec_cls in perceptual:
        if key in band_specs:
            assert n_bands is not None
            specs = spec_cls(nfft=n_fft, fs=fs, n_bands=n_bands)
            return specs.get_band_specs(), specs.get_freq_weights(), True

    raise NameError(f"Unknown band specification: {band_specs!r}")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None):
    """Resolve a named per-stem layout into ({stem: band specs}, weights, flag).

    ``freq_weights`` and ``overlapping_band`` are shared across all stems.

    Raises:
        NameError: if ``band_specs_map`` matches no known layout.
    """
    if band_specs_map == "musdb:all":
        # One hand-designed layout per MUSDB stem; no per-bin weights.
        bsm = {
            "vocals": VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
            "drums": DrumBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
            "bass": BassBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
            "other": OtherBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
        }
        freq_weights = None
        overlapping_band = False
    elif band_specs_map == "dnr:vox7":
        # Exact-match case must precede the "dnr:vox7:" substring case below.
        bsm_, freq_weights, overlapping_band = get_band_specs(
            "dnr:speech", n_fft, fs, n_bands
        )
        bsm = {"speech": bsm_, "music": bsm_, "effects": bsm_}
    elif "dnr:vox7:" in band_specs_map:
        # Single-stem variant: "dnr:vox7:<stem>" applies the speech layout
        # to just that stem.
        stem = band_specs_map.split(":")[-1]
        bsm_, freq_weights, overlapping_band = get_band_specs(
            "dnr:speech", n_fft, fs, n_bands
        )
        bsm = {stem: bsm_}
    else:
        raise NameError(f"Unknown band specification map: {band_specs_map!r}")

    return bsm, freq_weights, overlapping_band
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class BandSplitWrapperBase(pl.LightningModule):
    """Common Lightning base for band-split separation wrappers.

    Concrete subclasses are expected to assign ``self.bsrnn`` (the core
    separation network) before ``forward`` is called.
    """

    # Core band-split network; populated by subclasses.
    bsrnn: nn.Module

    def __init__(self, **kwargs):
        # NOTE(review): extra kwargs are silently discarded here — presumably
        # so subclasses can forward loose config dicts; confirm intended.
        super().__init__()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class SingleMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
    """Wrapper that runs one band-split network per stem on the mixture STFT.

    ``self.bsrnn`` (set by subclasses) is a mapping from stem name to a
    network that maps the mixture spectrogram to that stem's spectrogram.
    """

    def __init__(
        self,
        band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
        fs: int = 44100,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: int = None,
    ) -> None:
        super().__init__(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        if isinstance(band_specs_map, str):
            self.band_specs_map, self.freq_weights, self.overlapping_band = (
                get_band_specs_map(band_specs_map, n_fft, fs, n_bands=n_bands)
            )
        else:
            # Pre-built {stem: band specs} mapping, as the type hint allows.
            # Previously this path left the attributes unset and crashed on
            # the self.stems line below; assume no weights / no overlap.
            self.band_specs_map = band_specs_map
            self.freq_weights = None
            self.overlapping_band = False

        self.stems = list(self.band_specs_map.keys())

    def forward(self, batch):
        audio = batch["audio"]

        # STFTs of the inputs are fixed targets, not learned — no grad needed.
        with torch.no_grad():
            batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}

        X = batch["spectrogram"]["mixture"]
        length = batch["audio"]["mixture"].shape[-1]

        output = {"spectrogram": {}, "audio": {}}

        # Run each stem's dedicated network on the same mixture spectrogram.
        for stem, bsrnn in self.bsrnn.items():
            S = bsrnn(X)
            s = self.istft(S, length)
            output["spectrogram"][stem] = S
            output["audio"][stem] = s

        return batch, output
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class MultiMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
    """Wrapper whose single core network predicts one mask per source.

    ``self.bsrnn`` (set by subclasses) maps the mixture spectrogram (plus an
    optional condition) to a dict with a per-stem "spectrogram" entry.
    """

    def __init__(
        self,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: int = None,
    ) -> None:
        super().__init__(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        if isinstance(band_specs, str):
            self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
                band_specs, n_fft, fs, n_bands
            )
        else:
            # Pre-built list of (start, end) bands, as the type hint allows.
            # Previously this path left the attributes unset, causing an
            # AttributeError later; assume no weights / no overlap.
            self.band_specs = band_specs
            self.freq_weights = None
            self.overlapping_band = False

        self.stems = stems

    def forward(self, batch):
        audio = batch["audio"]
        # Optional conditioning signal forwarded to the core network.
        cond = batch.get("condition", None)
        # Input STFTs are fixed targets — no grad needed.
        with torch.no_grad():
            batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}

        X = batch["spectrogram"]["mixture"]
        length = batch["audio"]["mixture"].shape[-1]

        # The core returns all stem spectrograms in one pass.
        output = self.bsrnn(X, cond=cond)
        output["audio"] = {}

        for stem, S in output["spectrogram"].items():
            s = self.istft(S, length)
            output["audio"][stem] = s

        return batch, output
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class MultiMaskMultiSourceBandSplitBaseSimple(BandSplitWrapperBase, _SpectralComponent):
    """Tensor-in/tensor-out variant of the multi-mask band-split wrapper.

    Unlike the dict-based base class, ``forward`` takes a raw waveform
    tensor and returns the separated stems stacked along dim 1 — a shape
    convenient for export/inference pipelines.
    """

    def __init__(
        self,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,
    ) -> None:
        super().__init__(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        if isinstance(band_specs, str):
            # Named band layout: resolve to explicit (low, high) frequency bins.
            self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
                band_specs, n_fft, fs, n_bands
            )
        else:
            # Fix: an explicit band list previously left these attributes
            # unset, so subclasses crashed with AttributeError when reading
            # self.band_specs. An explicit list is taken as non-overlapping,
            # so no frequency weights are needed.
            self.band_specs = band_specs
            self.freq_weights = None
            self.overlapping_band = False

        self.stems = stems

    def forward(self, batch):
        """Separate a raw waveform tensor.

        Returns a tensor with stems stacked on dim 1, in the insertion
        order of the core's ``output["spectrogram"]`` dict.
        """
        # NOTE(review): only the input STFT is wrapped in no_grad here,
        # mirroring the dict-based base class — confirm if this wrapper is
        # ever used for training rather than inference only.
        with torch.no_grad():
            X = self.stft(batch)
        length = batch.shape[-1]
        output = self.bsrnn(X, cond=None)
        res = []
        for stem, S in output["spectrogram"].items():
            s = self.istft(S, length)
            res.append(s)
        res = torch.stack(res, dim=1)
        return res
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class SingleMaskMultiSourceBandSplitRNN(SingleMaskMultiSourceBandSplitBase):
    """Single-mask band-split RNN separator.

    Builds one independent ``SingleMaskBandsplitCoreRNN`` per source, keyed
    by source name in an ``nn.ModuleDict``, each with its own band layout
    taken from ``band_specs_map``. STFT parameters are forwarded to the base
    class; the remaining arguments configure every per-source core
    identically.
    """

    def __init__(
        self,
        in_channel: int,
        band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ) -> None:
        super().__init__(
            band_specs_map=band_specs_map,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        # One core per source; self.band_specs_map is resolved by the base
        # class __init__ (possibly from a named layout string).
        self.bsrnn = nn.ModuleDict(
            {
                src: SingleMaskBandsplitCoreRNN(
                    band_specs=specs,
                    in_channel=in_channel,
                    require_no_overlap=require_no_overlap,
                    require_no_gap=require_no_gap,
                    normalize_channel_independently=normalize_channel_independently,
                    treat_channel_as_feature=treat_channel_as_feature,
                    n_sqm_modules=n_sqm_modules,
                    emb_dim=emb_dim,
                    rnn_dim=rnn_dim,
                    bidirectional=bidirectional,
                    rnn_type=rnn_type,
                    mlp_dim=mlp_dim,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                )
                for src, specs in self.band_specs_map.items()
            }
        )
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
class SingleMaskMultiSourceBandSplitTransformer(SingleMaskMultiSourceBandSplitBase):
    """Single-mask band-split Transformer separator.

    Same layout as the RNN variant, but each per-source core is a
    ``SingleMaskBandsplitCoreTransformer`` with dropout ``tf_dropout``
    instead of an RNN type. STFT parameters are forwarded to the base class.
    """

    def __init__(
        self,
        in_channel: int,
        band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        tf_dropout: float = 0.0,
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ) -> None:
        super().__init__(
            band_specs_map=band_specs_map,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        # One Transformer core per source, each configured identically.
        self.bsrnn = nn.ModuleDict(
            {
                src: SingleMaskBandsplitCoreTransformer(
                    band_specs=specs,
                    in_channel=in_channel,
                    require_no_overlap=require_no_overlap,
                    require_no_gap=require_no_gap,
                    normalize_channel_independently=normalize_channel_independently,
                    treat_channel_as_feature=treat_channel_as_feature,
                    n_sqm_modules=n_sqm_modules,
                    emb_dim=emb_dim,
                    rnn_dim=rnn_dim,
                    bidirectional=bidirectional,
                    tf_dropout=tf_dropout,
                    mlp_dim=mlp_dim,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                )
                for src, specs in self.band_specs_map.items()
            }
        )
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
class MultiMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
    """Multi-mask band-split RNN separator.

    A single shared ``MultiSourceMultiMaskBandSplitCoreRNN`` produces all
    stem estimates at once (one mask per stem), optionally conditioned via
    ``cond_dim``. ``freeze_encoder`` disables gradients for the band-split
    encoder and the time-frequency model, leaving only the mask estimators
    trainable.
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        cond_dim: int = 0,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,
        use_freq_weights: bool = True,
        normalize_input: bool = False,
        mult_add_mask: bool = False,
        freeze_encoder: bool = False,
    ) -> None:
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        # Shared core for all stems; band layout/weights come from the base
        # class (resolved in its __init__).
        self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )

        self.normalize_input = normalize_input
        self.cond_dim = cond_dim

        if freeze_encoder:
            # Freeze everything except the mask estimators.
            for param in self.bsrnn.band_split.parameters():
                param.requires_grad = False

            for param in self.bsrnn.tf_model.parameters():
                param.requires_grad = False
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
class MultiMaskMultiSourceBandSplitRNNSimple(MultiMaskMultiSourceBandSplitBaseSimple):
    """Tensor-in/tensor-out twin of :class:`MultiMaskMultiSourceBandSplitRNN`.

    Identical construction (same shared RNN core and freezing behaviour),
    but inherits the "simple" base whose ``forward`` takes a raw waveform
    tensor and returns stems stacked along dim 1.
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        cond_dim: int = 0,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,
        use_freq_weights: bool = True,
        normalize_input: bool = False,
        mult_add_mask: bool = False,
        freeze_encoder: bool = False,
    ) -> None:
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        # Shared core for all stems; band layout/weights come from the base
        # class (resolved in its __init__).
        self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )

        self.normalize_input = normalize_input
        self.cond_dim = cond_dim

        if freeze_encoder:
            # Freeze everything except the mask estimators.
            for param in self.bsrnn.band_split.parameters():
                param.requires_grad = False

            for param in self.bsrnn.tf_model.parameters():
                param.requires_grad = False
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
class MultiMaskMultiSourceBandSplitTransformer(MultiMaskMultiSourceBandSplitBase):
    """Multi-mask band-split separator with a Transformer core.

    Same wiring as the RNN variant but builds a
    ``MultiSourceMultiMaskBandSplitCoreTransformer``.

    NOTE(review): ``normalize_input`` is accepted but never stored here,
    unlike the RNN variants which set ``self.normalize_input`` — confirm
    whether that attribute is needed downstream for this class.
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        cond_dim: int = 0,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,
        use_freq_weights: bool = True,
        normalize_input: bool = False,
        mult_add_mask: bool = False,
    ) -> None:
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        # Shared Transformer core for all stems.
        self.bsrnn = MultiSourceMultiMaskBandSplitCoreTransformer(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )
|
| 668 |
+
|
| 669 |
+
|
| 670 |
+
class MultiMaskMultiSourceBandSplitConv(MultiMaskMultiSourceBandSplitBase):
    """Multi-mask band-split separator with a convolutional core.

    Same wiring as the RNN variant but builds a
    ``MultiSourceMultiMaskBandSplitCoreConv``.

    NOTE(review): ``normalize_input`` is accepted but never stored here,
    unlike the RNN variants — confirm whether it is needed downstream.
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        cond_dim: int = 0,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,
        use_freq_weights: bool = True,
        normalize_input: bool = False,
        mult_add_mask: bool = False,
    ) -> None:
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        # Shared convolutional core for all stems.
        self.bsrnn = MultiSourceMultiMaskBandSplitCoreConv(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
class PatchingMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
    """Band-split RNN separator that estimates patch-shaped masks.

    Uses a ``MultiSourceMultiPatchingMaskBandSplitCoreRNN`` whose masks
    cover local time-frequency patches (``mask_kernel_freq`` x
    ``mask_kernel_time``) rather than single bins; the ``conv_kernel_*``
    arguments size the core's convolutions.
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        kernel_norm_mlp_version: int = 1,
        mask_kernel_freq: int = 3,
        mask_kernel_time: int = 3,
        conv_kernel_freq: int = 1,
        conv_kernel_time: int = 1,
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,
    ) -> None:
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        # Shared patching-mask core for all stems.
        self.bsrnn = MultiSourceMultiPatchingMaskBandSplitCoreRNN(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            mask_kernel_freq=mask_kernel_freq,
            mask_kernel_time=mask_kernel_time,
            conv_kernel_freq=conv_kernel_freq,
            conv_kernel_time=conv_kernel_time,
            kernel_norm_mlp_version=kernel_norm_mlp_version,
        )
|
|
|
mvsepless/models/bandit/core/utils/audio.py
CHANGED
|
@@ -1,412 +1,324 @@
|
|
| 1 |
-
from collections import defaultdict
|
| 2 |
-
|
| 3 |
-
from tqdm.auto import tqdm
|
| 4 |
-
from typing import Callable, Dict, List, Optional, Tuple
|
| 5 |
-
|
| 6 |
-
import numpy as np
|
| 7 |
-
import torch
|
| 8 |
-
from torch import nn
|
| 9 |
-
from torch.nn import functional as F
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
@torch.jit.script
|
| 13 |
-
def merge(
|
| 14 |
-
combined: torch.Tensor,
|
| 15 |
-
original_batch_size: int,
|
| 16 |
-
n_channel: int,
|
| 17 |
-
n_chunks: int,
|
| 18 |
-
chunk_size: int,
|
| 19 |
-
):
|
| 20 |
-
combined = torch.reshape(
|
| 21 |
-
combined, (original_batch_size, n_chunks, n_channel, chunk_size)
|
| 22 |
-
)
|
| 23 |
-
combined = torch.permute(combined, (0, 2, 3, 1)).reshape(
|
| 24 |
-
original_batch_size * n_channel, chunk_size, n_chunks
|
| 25 |
-
)
|
| 26 |
-
|
| 27 |
-
return combined
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
@torch.jit.script
|
| 31 |
-
def unfold(
|
| 32 |
-
padded_audio: torch.Tensor,
|
| 33 |
-
original_batch_size: int,
|
| 34 |
-
n_channel: int,
|
| 35 |
-
chunk_size: int,
|
| 36 |
-
hop_size: int,
|
| 37 |
-
) -> torch.Tensor:
|
| 38 |
-
|
| 39 |
-
unfolded_input = F.unfold(
|
| 40 |
-
padded_audio[:, :, None, :], kernel_size=(1, chunk_size), stride=(1, hop_size)
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
_, _, n_chunks = unfolded_input.shape
|
| 44 |
-
unfolded_input = unfolded_input.view(
|
| 45 |
-
original_batch_size, n_channel, chunk_size, n_chunks
|
| 46 |
-
)
|
| 47 |
-
unfolded_input = torch.permute(unfolded_input, (0, 3, 1, 2)).reshape(
|
| 48 |
-
original_batch_size * n_chunks, n_channel, chunk_size
|
| 49 |
-
)
|
| 50 |
-
|
| 51 |
-
return unfolded_input
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
@torch.jit.script
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
)
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
combined =
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
combined
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
self.
|
| 139 |
-
self.
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
)
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
del
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
self.
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
class LinearFader(BaseFader):
|
| 326 |
-
def __init__(
|
| 327 |
-
self,
|
| 328 |
-
chunk_size_second: float,
|
| 329 |
-
hop_size_second: float,
|
| 330 |
-
fs: int,
|
| 331 |
-
fade_edge_frames: bool = False,
|
| 332 |
-
batch_size: int = 1,
|
| 333 |
-
) -> None:
|
| 334 |
-
|
| 335 |
-
assert hop_size_second >= chunk_size_second / 2
|
| 336 |
-
|
| 337 |
-
super().__init__(
|
| 338 |
-
chunk_size_second=chunk_size_second,
|
| 339 |
-
hop_size_second=hop_size_second,
|
| 340 |
-
fs=fs,
|
| 341 |
-
fade_edge_frames=fade_edge_frames,
|
| 342 |
-
batch_size=batch_size,
|
| 343 |
-
)
|
| 344 |
-
|
| 345 |
-
in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1]
|
| 346 |
-
out_fade = torch.linspace(1.0, 0.0, self.overlap_size + 1)[1:]
|
| 347 |
-
center_ones = torch.ones(self.chunk_size - 2 * self.overlap_size)
|
| 348 |
-
inout_ones = torch.ones(self.overlap_size)
|
| 349 |
-
|
| 350 |
-
# using nn.Parameters allows lightning to take care of devices for us
|
| 351 |
-
self.register_buffer(
|
| 352 |
-
"standard_window", torch.concat([in_fade, center_ones, out_fade])
|
| 353 |
-
)
|
| 354 |
-
|
| 355 |
-
self.fade_edge_frames = fade_edge_frames
|
| 356 |
-
self.edge_frame_pad_size = (self.overlap_size, self.overlap_size)
|
| 357 |
-
|
| 358 |
-
if not self.fade_edge_frames:
|
| 359 |
-
self.first_window = nn.Parameter(
|
| 360 |
-
torch.concat([inout_ones, center_ones, out_fade]), requires_grad=False
|
| 361 |
-
)
|
| 362 |
-
self.last_window = nn.Parameter(
|
| 363 |
-
torch.concat([in_fade, center_ones, inout_ones]), requires_grad=False
|
| 364 |
-
)
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
class OverlapAddFader(BaseFader):
|
| 368 |
-
def __init__(
|
| 369 |
-
self,
|
| 370 |
-
window_type: str,
|
| 371 |
-
chunk_size_second: float,
|
| 372 |
-
hop_size_second: float,
|
| 373 |
-
fs: int,
|
| 374 |
-
batch_size: int = 1,
|
| 375 |
-
) -> None:
|
| 376 |
-
assert (chunk_size_second / hop_size_second) % 2 == 0
|
| 377 |
-
assert int(chunk_size_second * fs) % 2 == 0
|
| 378 |
-
|
| 379 |
-
super().__init__(
|
| 380 |
-
chunk_size_second=chunk_size_second,
|
| 381 |
-
hop_size_second=hop_size_second,
|
| 382 |
-
fs=fs,
|
| 383 |
-
fade_edge_frames=True,
|
| 384 |
-
batch_size=batch_size,
|
| 385 |
-
)
|
| 386 |
-
|
| 387 |
-
self.hop_multiplier = self.chunk_size / (2 * self.hop_size)
|
| 388 |
-
# print(f"hop multiplier: {self.hop_multiplier}")
|
| 389 |
-
|
| 390 |
-
self.edge_frame_pad_sizes = (2 * self.overlap_size, 2 * self.overlap_size)
|
| 391 |
-
|
| 392 |
-
self.register_buffer(
|
| 393 |
-
"standard_window",
|
| 394 |
-
torch.windows.__dict__[window_type](
|
| 395 |
-
self.chunk_size,
|
| 396 |
-
sym=False, # dtype=torch.float64
|
| 397 |
-
)
|
| 398 |
-
/ self.hop_multiplier,
|
| 399 |
-
)
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
if __name__ == "__main__":
|
| 403 |
-
import torchaudio as ta
|
| 404 |
-
|
| 405 |
-
fs = 44100
|
| 406 |
-
ola = OverlapAddFader("hann", 6.0, 1.0, fs, batch_size=16)
|
| 407 |
-
audio_, _ = ta.load(
|
| 408 |
-
"$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too " "Much/vocals.wav"
|
| 409 |
-
)
|
| 410 |
-
audio_ = audio_[None, ...]
|
| 411 |
-
out = ola(audio_, lambda x: {"stem": x})["audio"]["stem"]
|
| 412 |
-
print(torch.allclose(out, audio_))
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
|
| 3 |
+
from tqdm.auto import tqdm
|
| 4 |
+
from typing import Callable, Dict, List, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@torch.jit.script
def merge(
    combined: torch.Tensor,
    original_batch_size: int,
    n_channel: int,
    n_chunks: int,
    chunk_size: int,
):
    """Regroup flat chunk outputs for overlap-add.

    Reinterprets ``combined`` as (batch, n_chunks, n_channel, chunk_size)
    and returns it as (batch * n_channel, chunk_size, n_chunks), i.e. with
    the chunk index on the trailing axis.
    """
    grouped = combined.reshape(
        original_batch_size, n_chunks, n_channel, chunk_size
    )
    # Move chunks last, then fold batch and channel into one leading dim.
    merged = grouped.permute(0, 2, 3, 1).reshape(
        original_batch_size * n_channel, chunk_size, n_chunks
    )

    return merged
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@torch.jit.script
def unfold(
    padded_audio: torch.Tensor,
    original_batch_size: int,
    n_channel: int,
    chunk_size: int,
    hop_size: int,
) -> torch.Tensor:
    """Slice padded audio into overlapping fixed-size chunks.

    Returns a tensor of shape (batch * n_chunks, n_channel, chunk_size), so
    every chunk can be fed to the model as an independent batch item.
    """
    # Treat the waveform as a (N, C, 1, T) "image" and let F.unfold extract
    # sliding windows of length `chunk_size` every `hop_size` samples.
    windows = F.unfold(
        padded_audio[:, :, None, :],
        kernel_size=(1, chunk_size),
        stride=(1, hop_size),
    )

    n_chunks = windows.shape[-1]
    windows = windows.view(
        original_batch_size, n_channel, chunk_size, n_chunks
    )
    # Chunk axis to the front, merged into the batch dimension.
    return windows.permute(0, 3, 1, 2).reshape(
        original_batch_size * n_chunks, n_channel, chunk_size
    )
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@torch.jit.script
def merge_chunks_all(
    combined: torch.Tensor,
    original_batch_size: int,
    n_channel: int,
    n_samples: int,
    n_padded_samples: int,
    n_chunks: int,
    chunk_size: int,
    hop_size: int,
    edge_frame_pad_sizes: Tuple[int, int],
    standard_window: torch.Tensor,
    first_window: torch.Tensor,
    last_window: torch.Tensor,
):
    """Overlap-add chunked model output back into a waveform.

    Every chunk (including the edge ones) is multiplied by the same window;
    `first_window` / `last_window` are accepted only for signature parity with
    `merge_chunks_edge` and are not used here.
    """
    combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size)

    # Apply the synthesis window to every chunk column.
    combined = combined * standard_window[:, None].to(combined.device)

    # F.fold performs the overlap-add: chunks placed `hop_size` apart are
    # summed onto a (1, n_padded_samples) canvas.
    combined = F.fold(
        combined.to(torch.float32),
        output_size=(1, n_padded_samples),
        kernel_size=(1, chunk_size),
        stride=(1, hop_size),
    )

    combined = combined.view(original_batch_size, n_channel, n_padded_samples)

    # Strip the padding that BaseFader.prepare added around the edges
    # (reflect padding when fade_edge_frames is set).
    pad_front, pad_back = edge_frame_pad_sizes
    combined = combined[..., pad_front:-pad_back]

    # Trim the zero padding appended to fill out the final chunk.
    combined = combined[..., :n_samples]

    return combined
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def merge_chunks_edge(
    combined: torch.Tensor,
    original_batch_size: int,
    n_channel: int,
    n_samples: int,
    n_padded_samples: int,
    n_chunks: int,
    chunk_size: int,
    hop_size: int,
    edge_frame_pad_sizes: Tuple[int, int],
    standard_window: torch.Tensor,
    first_window: torch.Tensor,
    last_window: torch.Tensor,
):
    """Overlap-add chunked model output, with special windows on edge chunks.

    The first and last chunks get `first_window` / `last_window` (flat at the
    outer edge in LinearFader) instead of the standard crossfade window, so no
    edge padding is required; `edge_frame_pad_sizes` is unused here and kept
    only for signature parity with `merge_chunks_all`.
    """
    combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size)

    # Edge chunks keep full gain at the signal boundary; interior chunks
    # crossfade with the standard window.
    combined[..., 0] = combined[..., 0] * first_window
    combined[..., -1] = combined[..., -1] * last_window
    combined[..., 1:-1] = combined[..., 1:-1] * standard_window[:, None]

    # Overlap-add via F.fold onto a (1, n_padded_samples) canvas.
    combined = F.fold(
        combined,
        output_size=(1, n_padded_samples),
        kernel_size=(1, chunk_size),
        stride=(1, hop_size),
    )

    combined = combined.view(original_batch_size, n_channel, n_padded_samples)

    # Trim the zero padding appended to fill out the final chunk.
    combined = combined[..., :n_samples]

    return combined
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class BaseFader(nn.Module):
    """Run a model over a long waveform in overlapping chunks on CPU-staged
    batches, then stitch the per-chunk outputs back together.

    Subclasses must provide `standard_window` (a registered buffer) and
    `edge_frame_pad_sizes`, and may provide `first_window` / `last_window`
    for edge handling when `fade_edge_frames` is False.
    """

    def __init__(
        self,
        chunk_size_second: float,
        hop_size_second: float,
        fs: int,
        fade_edge_frames: bool,
        batch_size: int,
    ) -> None:
        super().__init__()

        # All sizes below are in samples.
        self.chunk_size = int(chunk_size_second * fs)
        self.hop_size = int(hop_size_second * fs)
        self.overlap_size = self.chunk_size - self.hop_size
        self.fade_edge_frames = fade_edge_frames
        self.batch_size = batch_size

    def prepare(self, audio):
        """Pad `audio` so it divides evenly into hopped chunks.

        Returns the padded audio and the resulting number of chunks.
        """
        if self.fade_edge_frames:
            # Reflect-pad the edges so the first/last true samples are fully
            # covered by windowed chunks; removed again in merge_chunks_all.
            audio = F.pad(audio, self.edge_frame_pad_sizes, mode="reflect")

        n_samples = audio.shape[-1]
        n_chunks = int(np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1)

        # Zero-pad the tail so the last chunk is complete.
        padded_size = (n_chunks - 1) * self.hop_size + self.chunk_size
        pad_size = padded_size - n_samples

        padded_audio = F.pad(audio, (0, pad_size))

        return padded_audio, n_chunks

    def forward(
        self,
        audio: torch.Tensor,
        model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
    ):
        """Chunk `audio`, apply `model_fn` batch-wise, and overlap-add.

        Args:
            audio: (batch, channels, samples) waveform.
            model_fn: maps a (n, channels, chunk_size) tensor to a dict of
                stem name -> tensor of the same shape.

        Returns:
            {"audio": {stem: (batch, channels, samples) tensor}} on the
            input's original device and dtype.
        """
        original_dtype = audio.dtype
        original_device = audio.device

        # Chunking/stitching is staged on CPU to bound accelerator memory;
        # only each model batch is moved to the original device.
        audio = audio.to("cpu")

        original_batch_size, n_channel, n_samples = audio.shape
        padded_audio, n_chunks = self.prepare(audio)
        del audio
        n_padded_samples = padded_audio.shape[-1]

        if n_channel > 1:
            # Fold channels into the batch axis; unfold() restores them.
            padded_audio = padded_audio.view(
                original_batch_size * n_channel, 1, n_padded_samples
            )

        unfolded_input = unfold(
            padded_audio, original_batch_size, n_channel, self.chunk_size, self.hop_size
        )

        n_total_chunks, n_channel, chunk_size = unfolded_input.shape

        n_batch = np.ceil(n_total_chunks / self.batch_size).astype(int)

        chunks_in = [
            unfolded_input[b * self.batch_size : (b + 1) * self.batch_size, ...].clone()
            for b in range(n_batch)
        ]

        # One zero-initialized CPU accumulator per stem, created lazily on
        # first model output.
        all_chunks_out = defaultdict(
            lambda: torch.zeros_like(unfolded_input, device="cpu")
        )

        for b, cin in enumerate(chunks_in):
            # Silent batches are skipped; the accumulator already holds zeros.
            if torch.allclose(cin, torch.tensor(0.0)):
                del cin
                continue

            chunks_out = model_fn(cin.to(original_device))
            del cin
            for s, c in chunks_out.items():
                all_chunks_out[s][
                    b * self.batch_size : (b + 1) * self.batch_size, ...
                ] = c.cpu()
            del chunks_out

        del unfolded_input
        del padded_audio

        if self.fade_edge_frames:
            fn = merge_chunks_all
        else:
            fn = merge_chunks_edge
        outputs = {}

        torch.cuda.empty_cache()

        for s, c in all_chunks_out.items():
            # BUGFIX: this used `self.__dict__.get("first_window", ...)`, but
            # nn.Module stores Parameters in `_parameters`, not `__dict__`,
            # so LinearFader's edge windows were silently ignored. getattr()
            # goes through nn.Module attribute lookup and falls back to the
            # standard window only when no edge window exists.
            combined: torch.Tensor = fn(
                c,
                original_batch_size,
                n_channel,
                n_samples,
                n_padded_samples,
                n_chunks,
                self.chunk_size,
                self.hop_size,
                self.edge_frame_pad_sizes,
                self.standard_window,
                getattr(self, "first_window", self.standard_window),
                getattr(self, "last_window", self.standard_window),
            )

            outputs[s] = combined.to(dtype=original_dtype, device=original_device)

        return {"audio": outputs}
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class LinearFader(BaseFader):
    """Chunked inference with a linear crossfade between overlapping chunks."""

    def __init__(
        self,
        chunk_size_second: float,
        hop_size_second: float,
        fs: int,
        fade_edge_frames: bool = False,
        batch_size: int = 1,
    ) -> None:
        # Linear crossfade assumes at most two chunks overlap at any sample,
        # which requires the hop to be at least half a chunk.
        assert hop_size_second >= chunk_size_second / 2

        super().__init__(
            chunk_size_second=chunk_size_second,
            hop_size_second=hop_size_second,
            fs=fs,
            fade_edge_frames=fade_edge_frames,
            batch_size=batch_size,
        )

        in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1]
        out_fade = torch.linspace(1.0, 0.0, self.overlap_size + 1)[1:]
        center_ones = torch.ones(self.chunk_size - 2 * self.overlap_size)
        inout_ones = torch.ones(self.overlap_size)

        # Ramp up, flat middle, ramp down: overlapping windows sum to 1.
        self.register_buffer(
            "standard_window", torch.concat([in_fade, center_ones, out_fade])
        )

        self.fade_edge_frames = fade_edge_frames
        # BUGFIX: this attribute was misspelled `edge_frame_pad_size`
        # (singular), while BaseFader.prepare and BaseFader.forward read
        # `edge_frame_pad_sizes` — so fade_edge_frames=True crashed with
        # AttributeError.
        self.edge_frame_pad_sizes = (self.overlap_size, self.overlap_size)

        if not self.fade_edge_frames:
            # Edge chunks keep full gain on the outer side instead of fading.
            self.first_window = nn.Parameter(
                torch.concat([inout_ones, center_ones, out_fade]), requires_grad=False
            )
            self.last_window = nn.Parameter(
                torch.concat([in_fade, center_ones, inout_ones]), requires_grad=False
            )
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class OverlapAddFader(BaseFader):
    """Chunked inference with windowed overlap-add reconstruction.

    Always fades (reflect-pads) the edge frames, so every chunk uses the same
    window (see BaseFader with fade_edge_frames=True).
    """

    def __init__(
        self,
        window_type: str,
        chunk_size_second: float,
        hop_size_second: float,
        fs: int,
        batch_size: int = 1,
    ) -> None:
        # The chunk must span an even number of hops and an even number of
        # samples.
        assert (chunk_size_second / hop_size_second) % 2 == 0
        assert int(chunk_size_second * fs) % 2 == 0

        super().__init__(
            chunk_size_second=chunk_size_second,
            hop_size_second=hop_size_second,
            fs=fs,
            fade_edge_frames=True,
            batch_size=batch_size,
        )

        # Scale factor applied to the window; presumably normalizes the
        # overlap-add sum to unity gain — TODO confirm against the
        # round-trip smoke test in __main__.
        self.hop_multiplier = self.chunk_size / (2 * self.hop_size)

        self.edge_frame_pad_sizes = (2 * self.overlap_size, 2 * self.overlap_size)

        self.register_buffer(
            "standard_window",
            # NOTE(review): `torch.windows` — confirm the installed torch
            # version exposes window functions at this path (recent releases
            # put them under `torch.signal.windows`).
            torch.windows.__dict__[window_type](
                self.chunk_size,
                sym=False,
            )
            / self.hop_multiplier,
        )
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
if __name__ == "__main__":
    # Smoke test: running an identity "model" through the fader should
    # reconstruct the input exactly if the window normalization is correct.
    import torchaudio as ta

    fs = 44100
    ola = OverlapAddFader("hann", 6.0, 1.0, fs, batch_size=16)
    # NOTE(review): "$DATA_ROOT" is a literal string here, not an expanded
    # environment variable — this path only works if such a directory exists.
    audio_, _ = ta.load(
        "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too " "Much/vocals.wav"
    )
    audio_ = audio_[None, ...]
    out = ola(audio_, lambda x: {"stem": x})["audio"]["stem"]
    print(torch.allclose(out, audio_))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mvsepless/models/bandit/model_from_config.py
CHANGED
|
@@ -1,26 +1,26 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
import os.path
|
| 3 |
-
import torch
|
| 4 |
-
|
| 5 |
-
import yaml
|
| 6 |
-
from ml_collections import ConfigDict
|
| 7 |
-
|
| 8 |
-
torch.set_float32_matmul_precision("medium")
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def get_model(
|
| 12 |
-
config_path,
|
| 13 |
-
weights_path,
|
| 14 |
-
device,
|
| 15 |
-
):
|
| 16 |
-
from .core.model import MultiMaskMultiSourceBandSplitRNNSimple
|
| 17 |
-
|
| 18 |
-
f = open(config_path)
|
| 19 |
-
config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
|
| 20 |
-
f.close()
|
| 21 |
-
|
| 22 |
-
model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)
|
| 23 |
-
d = torch.load(code_path + "model_bandit_plus_dnr_sdr_11.47.chpt")
|
| 24 |
-
model.load_state_dict(d)
|
| 25 |
-
model.to(device)
|
| 26 |
-
return model, config
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os.path
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
import yaml
|
| 6 |
+
from ml_collections import ConfigDict
|
| 7 |
+
|
| 8 |
+
torch.set_float32_matmul_precision("medium")
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_model(
    config_path,
    weights_path,
    device,
):
    """Build a MultiMaskMultiSourceBandSplitRNNSimple from a YAML config and
    load its checkpoint.

    Args:
        config_path: path to the YAML model configuration file.
        weights_path: path to the checkpoint (state dict) to load.
        device: torch device (or device string) to move the model to.

    Returns:
        Tuple of (model, config).
    """
    from .core.model import MultiMaskMultiSourceBandSplitRNNSimple

    with open(config_path) as f:
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))

    model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)
    # BUGFIX: the original line read
    #   torch.load(code_path + "model_bandit_plus_dnr_sdr_11.47.chpt")
    # where `code_path` was never defined and the `weights_path` argument was
    # ignored entirely.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    return model, config
|
mvsepless/models/bandit_v2/bandit.py
CHANGED
|
@@ -1,363 +1,360 @@
|
|
| 1 |
-
from typing import Dict, List, Optional
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
import torchaudio as ta
|
| 5 |
-
from torch import nn
|
| 6 |
-
import pytorch_lightning as pl
|
| 7 |
-
|
| 8 |
-
from .bandsplit import BandSplitModule
|
| 9 |
-
from .maskestim import OverlappingMaskEstimationModule
|
| 10 |
-
from .tfmodel import SeqBandModellingModule
|
| 11 |
-
from .utils import MusicalBandsplitSpecification
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class BaseEndToEndModule(pl.LightningModule):
|
| 15 |
-
def __init__(
|
| 16 |
-
self,
|
| 17 |
-
) -> None:
|
| 18 |
-
super().__init__()
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
class BaseBandit(BaseEndToEndModule):
|
| 22 |
-
def __init__(
|
| 23 |
-
self,
|
| 24 |
-
in_channels: int,
|
| 25 |
-
fs: int,
|
| 26 |
-
band_type: str = "musical",
|
| 27 |
-
n_bands: int = 64,
|
| 28 |
-
require_no_overlap: bool = False,
|
| 29 |
-
require_no_gap: bool = True,
|
| 30 |
-
normalize_channel_independently: bool = False,
|
| 31 |
-
treat_channel_as_feature: bool = True,
|
| 32 |
-
n_sqm_modules: int = 12,
|
| 33 |
-
emb_dim: int = 128,
|
| 34 |
-
rnn_dim: int = 256,
|
| 35 |
-
bidirectional: bool = True,
|
| 36 |
-
rnn_type: str = "LSTM",
|
| 37 |
-
n_fft: int = 2048,
|
| 38 |
-
win_length: Optional[int] = 2048,
|
| 39 |
-
hop_length: int = 512,
|
| 40 |
-
window_fn: str = "hann_window",
|
| 41 |
-
wkwargs: Optional[Dict] = None,
|
| 42 |
-
power: Optional[int] = None,
|
| 43 |
-
center: bool = True,
|
| 44 |
-
normalized: bool = True,
|
| 45 |
-
pad_mode: str = "constant",
|
| 46 |
-
onesided: bool = True,
|
| 47 |
-
):
|
| 48 |
-
super().__init__()
|
| 49 |
-
|
| 50 |
-
self.in_channels = in_channels
|
| 51 |
-
|
| 52 |
-
self.instantitate_spectral(
|
| 53 |
-
n_fft=n_fft,
|
| 54 |
-
win_length=win_length,
|
| 55 |
-
hop_length=hop_length,
|
| 56 |
-
window_fn=window_fn,
|
| 57 |
-
wkwargs=wkwargs,
|
| 58 |
-
power=power,
|
| 59 |
-
normalized=normalized,
|
| 60 |
-
center=center,
|
| 61 |
-
pad_mode=pad_mode,
|
| 62 |
-
onesided=onesided,
|
| 63 |
-
)
|
| 64 |
-
|
| 65 |
-
self.instantiate_bandsplit(
|
| 66 |
-
in_channels=in_channels,
|
| 67 |
-
band_type=band_type,
|
| 68 |
-
n_bands=n_bands,
|
| 69 |
-
require_no_overlap=require_no_overlap,
|
| 70 |
-
require_no_gap=require_no_gap,
|
| 71 |
-
normalize_channel_independently=normalize_channel_independently,
|
| 72 |
-
treat_channel_as_feature=treat_channel_as_feature,
|
| 73 |
-
emb_dim=emb_dim,
|
| 74 |
-
n_fft=n_fft,
|
| 75 |
-
fs=fs,
|
| 76 |
-
)
|
| 77 |
-
|
| 78 |
-
self.instantiate_tf_modelling(
|
| 79 |
-
n_sqm_modules=n_sqm_modules,
|
| 80 |
-
emb_dim=emb_dim,
|
| 81 |
-
rnn_dim=rnn_dim,
|
| 82 |
-
bidirectional=bidirectional,
|
| 83 |
-
rnn_type=rnn_type,
|
| 84 |
-
)
|
| 85 |
-
|
| 86 |
-
def instantitate_spectral(
|
| 87 |
-
self,
|
| 88 |
-
n_fft: int = 2048,
|
| 89 |
-
win_length: Optional[int] = 2048,
|
| 90 |
-
hop_length: int = 512,
|
| 91 |
-
window_fn: str = "hann_window",
|
| 92 |
-
wkwargs: Optional[Dict] = None,
|
| 93 |
-
power: Optional[int] = None,
|
| 94 |
-
normalized: bool = True,
|
| 95 |
-
center: bool = True,
|
| 96 |
-
pad_mode: str = "constant",
|
| 97 |
-
onesided: bool = True,
|
| 98 |
-
):
|
| 99 |
-
assert power is None
|
| 100 |
-
|
| 101 |
-
window_fn = torch.__dict__[window_fn]
|
| 102 |
-
|
| 103 |
-
self.stft = ta.transforms.Spectrogram(
|
| 104 |
-
n_fft=n_fft,
|
| 105 |
-
win_length=win_length,
|
| 106 |
-
hop_length=hop_length,
|
| 107 |
-
pad_mode=pad_mode,
|
| 108 |
-
pad=0,
|
| 109 |
-
window_fn=window_fn,
|
| 110 |
-
wkwargs=wkwargs,
|
| 111 |
-
power=power,
|
| 112 |
-
normalized=normalized,
|
| 113 |
-
center=center,
|
| 114 |
-
onesided=onesided,
|
| 115 |
-
)
|
| 116 |
-
|
| 117 |
-
self.istft = ta.transforms.InverseSpectrogram(
|
| 118 |
-
n_fft=n_fft,
|
| 119 |
-
win_length=win_length,
|
| 120 |
-
hop_length=hop_length,
|
| 121 |
-
pad_mode=pad_mode,
|
| 122 |
-
pad=0,
|
| 123 |
-
window_fn=window_fn,
|
| 124 |
-
wkwargs=wkwargs,
|
| 125 |
-
normalized=normalized,
|
| 126 |
-
center=center,
|
| 127 |
-
onesided=onesided,
|
| 128 |
-
)
|
| 129 |
-
|
| 130 |
-
def instantiate_bandsplit(
|
| 131 |
-
self,
|
| 132 |
-
in_channels: int,
|
| 133 |
-
band_type: str = "musical",
|
| 134 |
-
n_bands: int = 64,
|
| 135 |
-
require_no_overlap: bool = False,
|
| 136 |
-
require_no_gap: bool = True,
|
| 137 |
-
normalize_channel_independently: bool = False,
|
| 138 |
-
treat_channel_as_feature: bool = True,
|
| 139 |
-
emb_dim: int = 128,
|
| 140 |
-
n_fft: int = 2048,
|
| 141 |
-
fs: int = 44100,
|
| 142 |
-
):
|
| 143 |
-
assert band_type == "musical"
|
| 144 |
-
|
| 145 |
-
self.band_specs = MusicalBandsplitSpecification(
|
| 146 |
-
nfft=n_fft, fs=fs, n_bands=n_bands
|
| 147 |
-
)
|
| 148 |
-
|
| 149 |
-
self.band_split = BandSplitModule(
|
| 150 |
-
in_channels=in_channels,
|
| 151 |
-
band_specs=self.band_specs.get_band_specs(),
|
| 152 |
-
require_no_overlap=require_no_overlap,
|
| 153 |
-
require_no_gap=require_no_gap,
|
| 154 |
-
normalize_channel_independently=normalize_channel_independently,
|
| 155 |
-
treat_channel_as_feature=treat_channel_as_feature,
|
| 156 |
-
emb_dim=emb_dim,
|
| 157 |
-
)
|
| 158 |
-
|
| 159 |
-
def instantiate_tf_modelling(
|
| 160 |
-
self,
|
| 161 |
-
n_sqm_modules: int = 12,
|
| 162 |
-
emb_dim: int = 128,
|
| 163 |
-
rnn_dim: int = 256,
|
| 164 |
-
bidirectional: bool = True,
|
| 165 |
-
rnn_type: str = "LSTM",
|
| 166 |
-
):
|
| 167 |
-
try:
|
| 168 |
-
self.tf_model = torch.compile(
|
| 169 |
-
SeqBandModellingModule(
|
| 170 |
-
n_modules=n_sqm_modules,
|
| 171 |
-
emb_dim=emb_dim,
|
| 172 |
-
rnn_dim=rnn_dim,
|
| 173 |
-
bidirectional=bidirectional,
|
| 174 |
-
rnn_type=rnn_type,
|
| 175 |
-
),
|
| 176 |
-
disable=True,
|
| 177 |
-
)
|
| 178 |
-
except Exception as e:
|
| 179 |
-
self.tf_model = SeqBandModellingModule(
|
| 180 |
-
n_modules=n_sqm_modules,
|
| 181 |
-
emb_dim=emb_dim,
|
| 182 |
-
rnn_dim=rnn_dim,
|
| 183 |
-
bidirectional=bidirectional,
|
| 184 |
-
rnn_type=rnn_type,
|
| 185 |
-
)
|
| 186 |
-
|
| 187 |
-
def mask(self, x, m):
|
| 188 |
-
return x * m
|
| 189 |
-
|
| 190 |
-
def forward(self, batch, mode="train"):
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
s =
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
}
|
| 362 |
-
|
| 363 |
-
return batch
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torchaudio as ta
|
| 5 |
+
from torch import nn
|
| 6 |
+
import pytorch_lightning as pl
|
| 7 |
+
|
| 8 |
+
from .bandsplit import BandSplitModule
|
| 9 |
+
from .maskestim import OverlappingMaskEstimationModule
|
| 10 |
+
from .tfmodel import SeqBandModellingModule
|
| 11 |
+
from .utils import MusicalBandsplitSpecification
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class BaseEndToEndModule(pl.LightningModule):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
) -> None:
|
| 18 |
+
super().__init__()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class BaseBandit(BaseEndToEndModule):
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
in_channels: int,
|
| 25 |
+
fs: int,
|
| 26 |
+
band_type: str = "musical",
|
| 27 |
+
n_bands: int = 64,
|
| 28 |
+
require_no_overlap: bool = False,
|
| 29 |
+
require_no_gap: bool = True,
|
| 30 |
+
normalize_channel_independently: bool = False,
|
| 31 |
+
treat_channel_as_feature: bool = True,
|
| 32 |
+
n_sqm_modules: int = 12,
|
| 33 |
+
emb_dim: int = 128,
|
| 34 |
+
rnn_dim: int = 256,
|
| 35 |
+
bidirectional: bool = True,
|
| 36 |
+
rnn_type: str = "LSTM",
|
| 37 |
+
n_fft: int = 2048,
|
| 38 |
+
win_length: Optional[int] = 2048,
|
| 39 |
+
hop_length: int = 512,
|
| 40 |
+
window_fn: str = "hann_window",
|
| 41 |
+
wkwargs: Optional[Dict] = None,
|
| 42 |
+
power: Optional[int] = None,
|
| 43 |
+
center: bool = True,
|
| 44 |
+
normalized: bool = True,
|
| 45 |
+
pad_mode: str = "constant",
|
| 46 |
+
onesided: bool = True,
|
| 47 |
+
):
|
| 48 |
+
super().__init__()
|
| 49 |
+
|
| 50 |
+
self.in_channels = in_channels
|
| 51 |
+
|
| 52 |
+
self.instantitate_spectral(
|
| 53 |
+
n_fft=n_fft,
|
| 54 |
+
win_length=win_length,
|
| 55 |
+
hop_length=hop_length,
|
| 56 |
+
window_fn=window_fn,
|
| 57 |
+
wkwargs=wkwargs,
|
| 58 |
+
power=power,
|
| 59 |
+
normalized=normalized,
|
| 60 |
+
center=center,
|
| 61 |
+
pad_mode=pad_mode,
|
| 62 |
+
onesided=onesided,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
self.instantiate_bandsplit(
|
| 66 |
+
in_channels=in_channels,
|
| 67 |
+
band_type=band_type,
|
| 68 |
+
n_bands=n_bands,
|
| 69 |
+
require_no_overlap=require_no_overlap,
|
| 70 |
+
require_no_gap=require_no_gap,
|
| 71 |
+
normalize_channel_independently=normalize_channel_independently,
|
| 72 |
+
treat_channel_as_feature=treat_channel_as_feature,
|
| 73 |
+
emb_dim=emb_dim,
|
| 74 |
+
n_fft=n_fft,
|
| 75 |
+
fs=fs,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
self.instantiate_tf_modelling(
|
| 79 |
+
n_sqm_modules=n_sqm_modules,
|
| 80 |
+
emb_dim=emb_dim,
|
| 81 |
+
rnn_dim=rnn_dim,
|
| 82 |
+
bidirectional=bidirectional,
|
| 83 |
+
rnn_type=rnn_type,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
def instantitate_spectral(
|
| 87 |
+
self,
|
| 88 |
+
n_fft: int = 2048,
|
| 89 |
+
win_length: Optional[int] = 2048,
|
| 90 |
+
hop_length: int = 512,
|
| 91 |
+
window_fn: str = "hann_window",
|
| 92 |
+
wkwargs: Optional[Dict] = None,
|
| 93 |
+
power: Optional[int] = None,
|
| 94 |
+
normalized: bool = True,
|
| 95 |
+
center: bool = True,
|
| 96 |
+
pad_mode: str = "constant",
|
| 97 |
+
onesided: bool = True,
|
| 98 |
+
):
|
| 99 |
+
assert power is None
|
| 100 |
+
|
| 101 |
+
window_fn = torch.__dict__[window_fn]
|
| 102 |
+
|
| 103 |
+
self.stft = ta.transforms.Spectrogram(
|
| 104 |
+
n_fft=n_fft,
|
| 105 |
+
win_length=win_length,
|
| 106 |
+
hop_length=hop_length,
|
| 107 |
+
pad_mode=pad_mode,
|
| 108 |
+
pad=0,
|
| 109 |
+
window_fn=window_fn,
|
| 110 |
+
wkwargs=wkwargs,
|
| 111 |
+
power=power,
|
| 112 |
+
normalized=normalized,
|
| 113 |
+
center=center,
|
| 114 |
+
onesided=onesided,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
self.istft = ta.transforms.InverseSpectrogram(
|
| 118 |
+
n_fft=n_fft,
|
| 119 |
+
win_length=win_length,
|
| 120 |
+
hop_length=hop_length,
|
| 121 |
+
pad_mode=pad_mode,
|
| 122 |
+
pad=0,
|
| 123 |
+
window_fn=window_fn,
|
| 124 |
+
wkwargs=wkwargs,
|
| 125 |
+
normalized=normalized,
|
| 126 |
+
center=center,
|
| 127 |
+
onesided=onesided,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
def instantiate_bandsplit(
|
| 131 |
+
self,
|
| 132 |
+
in_channels: int,
|
| 133 |
+
band_type: str = "musical",
|
| 134 |
+
n_bands: int = 64,
|
| 135 |
+
require_no_overlap: bool = False,
|
| 136 |
+
require_no_gap: bool = True,
|
| 137 |
+
normalize_channel_independently: bool = False,
|
| 138 |
+
treat_channel_as_feature: bool = True,
|
| 139 |
+
emb_dim: int = 128,
|
| 140 |
+
n_fft: int = 2048,
|
| 141 |
+
fs: int = 44100,
|
| 142 |
+
):
|
| 143 |
+
assert band_type == "musical"
|
| 144 |
+
|
| 145 |
+
self.band_specs = MusicalBandsplitSpecification(
|
| 146 |
+
nfft=n_fft, fs=fs, n_bands=n_bands
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
self.band_split = BandSplitModule(
|
| 150 |
+
in_channels=in_channels,
|
| 151 |
+
band_specs=self.band_specs.get_band_specs(),
|
| 152 |
+
require_no_overlap=require_no_overlap,
|
| 153 |
+
require_no_gap=require_no_gap,
|
| 154 |
+
normalize_channel_independently=normalize_channel_independently,
|
| 155 |
+
treat_channel_as_feature=treat_channel_as_feature,
|
| 156 |
+
emb_dim=emb_dim,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
def instantiate_tf_modelling(
|
| 160 |
+
self,
|
| 161 |
+
n_sqm_modules: int = 12,
|
| 162 |
+
emb_dim: int = 128,
|
| 163 |
+
rnn_dim: int = 256,
|
| 164 |
+
bidirectional: bool = True,
|
| 165 |
+
rnn_type: str = "LSTM",
|
| 166 |
+
):
|
| 167 |
+
try:
|
| 168 |
+
self.tf_model = torch.compile(
|
| 169 |
+
SeqBandModellingModule(
|
| 170 |
+
n_modules=n_sqm_modules,
|
| 171 |
+
emb_dim=emb_dim,
|
| 172 |
+
rnn_dim=rnn_dim,
|
| 173 |
+
bidirectional=bidirectional,
|
| 174 |
+
rnn_type=rnn_type,
|
| 175 |
+
),
|
| 176 |
+
disable=True,
|
| 177 |
+
)
|
| 178 |
+
except Exception as e:
|
| 179 |
+
self.tf_model = SeqBandModellingModule(
|
| 180 |
+
n_modules=n_sqm_modules,
|
| 181 |
+
emb_dim=emb_dim,
|
| 182 |
+
rnn_dim=rnn_dim,
|
| 183 |
+
bidirectional=bidirectional,
|
| 184 |
+
rnn_type=rnn_type,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
def mask(self, x, m):
|
| 188 |
+
return x * m
|
| 189 |
+
|
| 190 |
+
def forward(self, batch, mode="train"):
|
| 191 |
+
init_shape = batch.shape
|
| 192 |
+
if not isinstance(batch, dict):
|
| 193 |
+
mono = batch.view(-1, 1, batch.shape[-1])
|
| 194 |
+
batch = {"mixture": {"audio": mono}}
|
| 195 |
+
|
| 196 |
+
with torch.no_grad():
|
| 197 |
+
mixture = batch["mixture"]["audio"]
|
| 198 |
+
|
| 199 |
+
x = self.stft(mixture)
|
| 200 |
+
batch["mixture"]["spectrogram"] = x
|
| 201 |
+
|
| 202 |
+
if "sources" in batch.keys():
|
| 203 |
+
for stem in batch["sources"].keys():
|
| 204 |
+
s = batch["sources"][stem]["audio"]
|
| 205 |
+
s = self.stft(s)
|
| 206 |
+
batch["sources"][stem]["spectrogram"] = s
|
| 207 |
+
|
| 208 |
+
batch = self.separate(batch)
|
| 209 |
+
|
| 210 |
+
if 1:
|
| 211 |
+
b = []
|
| 212 |
+
for s in self.stems:
|
| 213 |
+
r = batch["estimates"][s]["audio"].view(
|
| 214 |
+
-1, init_shape[1], init_shape[2]
|
| 215 |
+
)
|
| 216 |
+
b.append(r)
|
| 217 |
+
batch = torch.stack(b, dim=1)
|
| 218 |
+
return batch
|
| 219 |
+
|
| 220 |
+
def encode(self, batch):
|
| 221 |
+
x = batch["mixture"]["spectrogram"]
|
| 222 |
+
length = batch["mixture"]["audio"].shape[-1]
|
| 223 |
+
|
| 224 |
+
z = self.band_split(x)
|
| 225 |
+
q = self.tf_model(z)
|
| 226 |
+
|
| 227 |
+
return x, q, length
|
| 228 |
+
|
| 229 |
+
def separate(self, batch):
|
| 230 |
+
raise NotImplementedError
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class Bandit(BaseBandit):
|
| 234 |
+
def __init__(
|
| 235 |
+
self,
|
| 236 |
+
in_channels: int,
|
| 237 |
+
stems: List[str],
|
| 238 |
+
band_type: str = "musical",
|
| 239 |
+
n_bands: int = 64,
|
| 240 |
+
require_no_overlap: bool = False,
|
| 241 |
+
require_no_gap: bool = True,
|
| 242 |
+
normalize_channel_independently: bool = False,
|
| 243 |
+
treat_channel_as_feature: bool = True,
|
| 244 |
+
n_sqm_modules: int = 12,
|
| 245 |
+
emb_dim: int = 128,
|
| 246 |
+
rnn_dim: int = 256,
|
| 247 |
+
bidirectional: bool = True,
|
| 248 |
+
rnn_type: str = "LSTM",
|
| 249 |
+
mlp_dim: int = 512,
|
| 250 |
+
hidden_activation: str = "Tanh",
|
| 251 |
+
hidden_activation_kwargs: Dict | None = None,
|
| 252 |
+
complex_mask: bool = True,
|
| 253 |
+
use_freq_weights: bool = True,
|
| 254 |
+
n_fft: int = 2048,
|
| 255 |
+
win_length: int | None = 2048,
|
| 256 |
+
hop_length: int = 512,
|
| 257 |
+
window_fn: str = "hann_window",
|
| 258 |
+
wkwargs: Dict | None = None,
|
| 259 |
+
power: int | None = None,
|
| 260 |
+
center: bool = True,
|
| 261 |
+
normalized: bool = True,
|
| 262 |
+
pad_mode: str = "constant",
|
| 263 |
+
onesided: bool = True,
|
| 264 |
+
fs: int = 44100,
|
| 265 |
+
stft_precisions="32",
|
| 266 |
+
bandsplit_precisions="bf16",
|
| 267 |
+
tf_model_precisions="bf16",
|
| 268 |
+
mask_estim_precisions="bf16",
|
| 269 |
+
):
|
| 270 |
+
super().__init__(
|
| 271 |
+
in_channels=in_channels,
|
| 272 |
+
band_type=band_type,
|
| 273 |
+
n_bands=n_bands,
|
| 274 |
+
require_no_overlap=require_no_overlap,
|
| 275 |
+
require_no_gap=require_no_gap,
|
| 276 |
+
normalize_channel_independently=normalize_channel_independently,
|
| 277 |
+
treat_channel_as_feature=treat_channel_as_feature,
|
| 278 |
+
n_sqm_modules=n_sqm_modules,
|
| 279 |
+
emb_dim=emb_dim,
|
| 280 |
+
rnn_dim=rnn_dim,
|
| 281 |
+
bidirectional=bidirectional,
|
| 282 |
+
rnn_type=rnn_type,
|
| 283 |
+
n_fft=n_fft,
|
| 284 |
+
win_length=win_length,
|
| 285 |
+
hop_length=hop_length,
|
| 286 |
+
window_fn=window_fn,
|
| 287 |
+
wkwargs=wkwargs,
|
| 288 |
+
power=power,
|
| 289 |
+
center=center,
|
| 290 |
+
normalized=normalized,
|
| 291 |
+
pad_mode=pad_mode,
|
| 292 |
+
onesided=onesided,
|
| 293 |
+
fs=fs,
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
self.stems = stems
|
| 297 |
+
|
| 298 |
+
self.instantiate_mask_estim(
|
| 299 |
+
in_channels=in_channels,
|
| 300 |
+
stems=stems,
|
| 301 |
+
emb_dim=emb_dim,
|
| 302 |
+
mlp_dim=mlp_dim,
|
| 303 |
+
hidden_activation=hidden_activation,
|
| 304 |
+
hidden_activation_kwargs=hidden_activation_kwargs,
|
| 305 |
+
complex_mask=complex_mask,
|
| 306 |
+
n_freq=n_fft // 2 + 1,
|
| 307 |
+
use_freq_weights=use_freq_weights,
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
def instantiate_mask_estim(
    self,
    in_channels: int,
    stems: List[str],
    emb_dim: int,
    mlp_dim: int,
    hidden_activation: str,
    hidden_activation_kwargs: Optional[Dict] = None,
    complex_mask: bool = True,
    n_freq: Optional[int] = None,
    use_freq_weights: bool = False,
):
    """Create one mask-estimation head per output stem.

    Builds ``self.mask_estim`` as an ``nn.ModuleDict`` mapping each stem
    name to an ``OverlappingMaskEstimationModule`` that shares this
    model's band specification.

    Args:
        in_channels: Number of audio channels fed to each head.
        stems: Names of the sources to separate; one head is built per name.
        emb_dim: Embedding dimension of the band features consumed by the heads.
        mlp_dim: Hidden width of each head's MLP.
        hidden_activation: Name of the activation class used inside the heads.
        hidden_activation_kwargs: Keyword arguments for that activation
            (defaults to an empty dict).
        complex_mask: Whether the heads predict complex-valued masks.
        n_freq: Number of STFT frequency bins; must be provided (asserted).
        use_freq_weights: Whether heads weight overlapping bands by
            per-band frequency weights.
    """
    if hidden_activation_kwargs is None:
        hidden_activation_kwargs = {}

    # n_freq has a None default only for signature flexibility; a concrete
    # bin count is required to size the masks.
    assert n_freq is not None

    self.mask_estim = nn.ModuleDict(
        {
            # All heads share the model's band layout and frequency weights
            # (self.band_specs is set by the enclosing class, outside this view).
            stem: OverlappingMaskEstimationModule(
                band_specs=self.band_specs.get_band_specs(),
                freq_weights=self.band_specs.get_freq_weights(),
                n_freq=n_freq,
                emb_dim=emb_dim,
                mlp_dim=mlp_dim,
                in_channels=in_channels,
                hidden_activation=hidden_activation,
                hidden_activation_kwargs=hidden_activation_kwargs,
                complex_mask=complex_mask,
                use_freq_weights=use_freq_weights,
            )
            for stem in stems
        }
    )
|
| 344 |
+
|
| 345 |
+
def separate(self, batch):
    """Run source separation and attach per-stem estimates to *batch*.

    Encodes the batch once, then applies each stem's mask-estimation head
    to the shared band embedding and reconstructs audio via the inverse STFT.

    Args:
        batch: Mutable mapping holding the model input; an ``"estimates"``
            dict is added, mapping stem name -> ``{"audio", "spectrogram"}``.

    Returns:
        The same *batch* object, augmented in place with the estimates.
    """
    batch["estimates"] = {}

    # x: mixture spectrogram to be masked; q: band embedding shared by all
    # heads; length: original sample count for the iSTFT.
    # (encode() is defined on the enclosing class — outside this view.)
    x, q, length = self.encode(batch)

    for stem, mem in self.mask_estim.items():
        m = mem(q)

        # Cast the mask to the spectrogram dtype before applying it
        # (heads may run in a different precision than the STFT).
        s = self.mask(x, m.to(x.dtype))
        s = torch.reshape(s, x.shape)
        batch["estimates"][stem] = {
            "audio": self.istft(s, length),
            "spectrogram": s,
        }

    return batch
|
|
|
|
|
|
|
|
|
mvsepless/models/bandit_v2/bandsplit.py
CHANGED
|
@@ -1,130 +1,127 @@
|
|
| 1 |
-
from typing import List, Tuple
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch import nn
|
| 5 |
-
from torch.utils.checkpoint import checkpoint_sequential
|
| 6 |
-
|
| 7 |
-
from .utils import (
|
| 8 |
-
band_widths_from_specs,
|
| 9 |
-
check_no_gap,
|
| 10 |
-
check_no_overlap,
|
| 11 |
-
check_nonzero_bandwidth,
|
| 12 |
-
)
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
class NormFC(nn.Module):
|
| 16 |
-
def __init__(
|
| 17 |
-
self,
|
| 18 |
-
emb_dim: int,
|
| 19 |
-
bandwidth: int,
|
| 20 |
-
in_channels: int,
|
| 21 |
-
normalize_channel_independently: bool = False,
|
| 22 |
-
treat_channel_as_feature: bool = True,
|
| 23 |
-
) -> None:
|
| 24 |
-
super().__init__()
|
| 25 |
-
|
| 26 |
-
if not treat_channel_as_feature:
|
| 27 |
-
raise NotImplementedError
|
| 28 |
-
|
| 29 |
-
self.treat_channel_as_feature = treat_channel_as_feature
|
| 30 |
-
|
| 31 |
-
if normalize_channel_independently:
|
| 32 |
-
raise NotImplementedError
|
| 33 |
-
|
| 34 |
-
reim = 2
|
| 35 |
-
|
| 36 |
-
norm = nn.LayerNorm(in_channels * bandwidth * reim)
|
| 37 |
-
|
| 38 |
-
fc_in = bandwidth * reim
|
| 39 |
-
|
| 40 |
-
if treat_channel_as_feature:
|
| 41 |
-
fc_in *= in_channels
|
| 42 |
-
else:
|
| 43 |
-
assert emb_dim % in_channels == 0
|
| 44 |
-
emb_dim = emb_dim // in_channels
|
| 45 |
-
|
| 46 |
-
fc = nn.Linear(fc_in, emb_dim)
|
| 47 |
-
|
| 48 |
-
self.combined = nn.Sequential(norm, fc)
|
| 49 |
-
|
| 50 |
-
def forward(self, xb):
|
| 51 |
-
return checkpoint_sequential(self.combined, 1, xb, use_reentrant=False)
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
class BandSplitModule(nn.Module):
|
| 55 |
-
def __init__(
|
| 56 |
-
self,
|
| 57 |
-
band_specs: List[Tuple[float, float]],
|
| 58 |
-
emb_dim: int,
|
| 59 |
-
in_channels: int,
|
| 60 |
-
require_no_overlap: bool = False,
|
| 61 |
-
require_no_gap: bool = True,
|
| 62 |
-
normalize_channel_independently: bool = False,
|
| 63 |
-
treat_channel_as_feature: bool = True,
|
| 64 |
-
) -> None:
|
| 65 |
-
super().__init__()
|
| 66 |
-
|
| 67 |
-
check_nonzero_bandwidth(band_specs)
|
| 68 |
-
|
| 69 |
-
if require_no_gap:
|
| 70 |
-
check_no_gap(band_specs)
|
| 71 |
-
|
| 72 |
-
if require_no_overlap:
|
| 73 |
-
check_no_overlap(band_specs)
|
| 74 |
-
|
| 75 |
-
self.band_specs = band_specs
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
self.
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
z[:, i, :, :] = nfm(xb)
|
| 129 |
-
|
| 130 |
-
return z
|
|
|
|
| 1 |
+
from typing import List, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.utils.checkpoint import checkpoint_sequential
|
| 6 |
+
|
| 7 |
+
from .utils import (
|
| 8 |
+
band_widths_from_specs,
|
| 9 |
+
check_no_gap,
|
| 10 |
+
check_no_overlap,
|
| 11 |
+
check_nonzero_bandwidth,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class NormFC(nn.Module):
    """LayerNorm followed by a Linear projection for a single frequency band.

    Projects the flattened (channels x band bins x real/imag) features of one
    band to the shared embedding dimension. The forward pass is gradient-
    checkpointed to trade compute for memory.
    """

    def __init__(
        self,
        emb_dim: int,
        bandwidth: int,
        in_channels: int,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        # Only the channels-as-features path is implemented.
        if not treat_channel_as_feature:
            raise NotImplementedError

        self.treat_channel_as_feature = treat_channel_as_feature

        if normalize_channel_independently:
            raise NotImplementedError

        # Two values per complex bin (real + imaginary).
        reim = 2

        norm = nn.LayerNorm(in_channels * bandwidth * reim)

        fc_in = bandwidth * reim

        if treat_channel_as_feature:
            fc_in *= in_channels
        else:
            # Dead branch given the NotImplementedError above; kept for parity
            # with the original layout.
            assert emb_dim % in_channels == 0
            emb_dim = emb_dim // in_channels

        fc = nn.Linear(fc_in, emb_dim)

        self.combined = nn.Sequential(norm, fc)

    def forward(self, xb):
        # Checkpoint the norm+fc pair as a single segment to save activation
        # memory during training.
        return checkpoint_sequential(self.combined, 1, xb, use_reentrant=False)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class BandSplitModule(nn.Module):
    """Split a complex spectrogram into frequency bands and embed each band.

    Each band defined by *band_specs* is flattened (real/imag interleaved with
    channels) and passed through its own :class:`NormFC`, producing a
    ``(batch, n_bands, n_time, emb_dim)`` embedding tensor.
    """

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        in_channels: int,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        check_nonzero_bandwidth(band_specs)

        if require_no_gap:
            check_no_gap(band_specs)

        if require_no_overlap:
            check_no_overlap(band_specs)

        self.band_specs = band_specs
        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)
        self.emb_dim = emb_dim

        # Prefer torch.compile-wrapped modules (currently disabled via
        # disable=True); fall back to plain modules on any failure, e.g. on
        # torch builds without compile support.
        try:
            self.norm_fc_modules = nn.ModuleList(
                [  # type: ignore
                    torch.compile(
                        NormFC(
                            emb_dim=emb_dim,
                            bandwidth=bw,
                            in_channels=in_channels,
                            normalize_channel_independently=normalize_channel_independently,
                            treat_channel_as_feature=treat_channel_as_feature,
                        ),
                        disable=True,
                    )
                    for bw in self.band_widths
                ]
            )
        except Exception as e:
            self.norm_fc_modules = nn.ModuleList(
                [  # type: ignore
                    NormFC(
                        emb_dim=emb_dim,
                        bandwidth=bw,
                        in_channels=in_channels,
                        normalize_channel_independently=normalize_channel_independently,
                        treat_channel_as_feature=treat_channel_as_feature,
                    )
                    for bw in self.band_widths
                ]
            )

    def forward(self, x: torch.Tensor):
        """Embed each band of *x*.

        Args:
            x: Complex spectrogram; indexed here as
               ``(batch, in_channels, n_freq, n_time)`` — the band slices are
               taken along the frequency axis after permutation.

        Returns:
            Real tensor of shape ``(batch, n_bands, n_time, emb_dim)``.
        """
        batch, in_chan, band_width, n_time = x.shape

        # NOTE(review): z is allocated as a real (default-dtype) tensor even
        # when x is complex — the NormFC outputs written into it are real.
        z = torch.zeros(
            size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
        )

        # Move time before frequency so each band slice is contiguous in the
        # last dimension: (batch, n_time, in_chan, n_freq).
        x = torch.permute(x, (0, 3, 1, 2)).contiguous()

        for i, nfm in enumerate(self.norm_fc_modules):
            fstart, fend = self.band_specs[i]
            xb = x[:, :, :, fstart:fend]
            # Expand complex values to a trailing (real, imag) pair, then
            # flatten channels x bins x 2 into one feature axis per frame.
            xb = torch.view_as_real(xb)
            xb = torch.reshape(xb, (batch, n_time, -1))
            z[:, i, :, :] = nfm(xb)

        return z
|
|
|
|
|
|
|
|
|
mvsepless/models/bandit_v2/film.py
CHANGED
|
@@ -1,23 +1,23 @@
|
|
| 1 |
-
from torch import nn
|
| 2 |
-
import torch
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
class FiLM(nn.Module):
|
| 6 |
-
def __init__(self):
|
| 7 |
-
super().__init__()
|
| 8 |
-
|
| 9 |
-
def forward(self, x, gamma, beta):
|
| 10 |
-
return gamma * x + beta
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class BTFBroadcastedFiLM(nn.Module):
|
| 14 |
-
def __init__(self):
|
| 15 |
-
super().__init__()
|
| 16 |
-
self.film = FiLM()
|
| 17 |
-
|
| 18 |
-
def forward(self, x, gamma, beta):
|
| 19 |
-
|
| 20 |
-
gamma = gamma[None, None, None, :]
|
| 21 |
-
beta = beta[None, None, None, :]
|
| 22 |
-
|
| 23 |
-
return self.film(x, gamma, beta)
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class FiLM(nn.Module):
    """Feature-wise linear modulation: ``out = gamma * x + beta``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, gamma, beta):
        # Element-wise affine transform; broadcasting is the caller's
        # responsibility.
        return gamma * x + beta
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class BTFBroadcastedFiLM(nn.Module):
    """FiLM with gamma/beta broadcast over (batch, ..., ...) leading axes.

    Inserts three leading singleton axes on ``gamma`` and ``beta`` so a
    per-feature modulation vector broadcasts across a 4-D input
    (presumably (batch, T, F, feature) given the class name — TODO confirm
    against callers).
    """

    def __init__(self):
        super().__init__()
        self.film = FiLM()

    def forward(self, x, gamma, beta):

        gamma = gamma[None, None, None, :]
        beta = beta[None, None, None, :]

        return self.film(x, gamma, beta)
|
mvsepless/models/bandit_v2/maskestim.py
CHANGED
|
@@ -1,281 +1,269 @@
|
|
| 1 |
-
from typing import Dict, List, Optional, Tuple, Type
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch import nn
|
| 5 |
-
from torch.nn.modules import activation
|
| 6 |
-
from torch.utils.checkpoint import checkpoint_sequential
|
| 7 |
-
|
| 8 |
-
from .utils import (
|
| 9 |
-
band_widths_from_specs,
|
| 10 |
-
check_no_gap,
|
| 11 |
-
check_no_overlap,
|
| 12 |
-
check_nonzero_bandwidth,
|
| 13 |
-
)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class BaseNormMLP(nn.Module):
|
| 17 |
-
def __init__(
|
| 18 |
-
self,
|
| 19 |
-
emb_dim: int,
|
| 20 |
-
mlp_dim: int,
|
| 21 |
-
bandwidth: int,
|
| 22 |
-
in_channels: Optional[int],
|
| 23 |
-
hidden_activation: str = "Tanh",
|
| 24 |
-
hidden_activation_kwargs=None,
|
| 25 |
-
complex_mask: bool = True,
|
| 26 |
-
):
|
| 27 |
-
super().__init__()
|
| 28 |
-
if hidden_activation_kwargs is None:
|
| 29 |
-
hidden_activation_kwargs = {}
|
| 30 |
-
self.hidden_activation_kwargs = hidden_activation_kwargs
|
| 31 |
-
self.norm = nn.LayerNorm(emb_dim)
|
| 32 |
-
self.hidden = nn.Sequential(
|
| 33 |
-
nn.Linear(in_features=emb_dim, out_features=mlp_dim),
|
| 34 |
-
activation.__dict__[hidden_activation](**self.hidden_activation_kwargs),
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
self.bandwidth = bandwidth
|
| 38 |
-
self.in_channels = in_channels
|
| 39 |
-
|
| 40 |
-
self.complex_mask = complex_mask
|
| 41 |
-
self.reim = 2 if complex_mask else 1
|
| 42 |
-
self.glu_mult = 2
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
class NormMLP(BaseNormMLP):
|
| 46 |
-
def __init__(
|
| 47 |
-
self,
|
| 48 |
-
emb_dim: int,
|
| 49 |
-
mlp_dim: int,
|
| 50 |
-
bandwidth: int,
|
| 51 |
-
in_channels: Optional[int],
|
| 52 |
-
hidden_activation: str = "Tanh",
|
| 53 |
-
hidden_activation_kwargs=None,
|
| 54 |
-
complex_mask: bool = True,
|
| 55 |
-
) -> None:
|
| 56 |
-
super().__init__(
|
| 57 |
-
emb_dim=emb_dim,
|
| 58 |
-
mlp_dim=mlp_dim,
|
| 59 |
-
bandwidth=bandwidth,
|
| 60 |
-
in_channels=in_channels,
|
| 61 |
-
hidden_activation=hidden_activation,
|
| 62 |
-
hidden_activation_kwargs=hidden_activation_kwargs,
|
| 63 |
-
complex_mask=complex_mask,
|
| 64 |
-
)
|
| 65 |
-
|
| 66 |
-
self.output = nn.Sequential(
|
| 67 |
-
nn.Linear(
|
| 68 |
-
in_features=mlp_dim,
|
| 69 |
-
out_features=bandwidth * in_channels * self.reim * 2,
|
| 70 |
-
),
|
| 71 |
-
nn.GLU(dim=-1),
|
| 72 |
-
)
|
| 73 |
-
|
| 74 |
-
try:
|
| 75 |
-
self.combined = torch.compile(
|
| 76 |
-
nn.Sequential(self.norm, self.hidden, self.output), disable=True
|
| 77 |
-
)
|
| 78 |
-
except Exception as e:
|
| 79 |
-
self.combined = nn.Sequential(self.norm, self.hidden, self.output)
|
| 80 |
-
|
| 81 |
-
def reshape_output(self, mb):
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
mb =
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
self.
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
def forward(self, q, cond=None):
|
| 272 |
-
# q = (batch, n_bands, n_time, emb_dim)
|
| 273 |
-
|
| 274 |
-
masks = self.compute_masks(
|
| 275 |
-
q
|
| 276 |
-
) # [n_bands * (batch, in_channels, bandwidth, n_time)]
|
| 277 |
-
|
| 278 |
-
# TODO: currently this requires band specs to have no gap and no overlap
|
| 279 |
-
masks = torch.concat(masks, dim=2) # (batch, in_channels, n_freq, n_time)
|
| 280 |
-
|
| 281 |
-
return masks
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional, Tuple, Type
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn.modules import activation
|
| 6 |
+
from torch.utils.checkpoint import checkpoint_sequential
|
| 7 |
+
|
| 8 |
+
from .utils import (
|
| 9 |
+
band_widths_from_specs,
|
| 10 |
+
check_no_gap,
|
| 11 |
+
check_no_overlap,
|
| 12 |
+
check_nonzero_bandwidth,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BaseNormMLP(nn.Module):
    """Shared state for the per-band mask MLPs: LayerNorm + hidden layer.

    Subclasses add the output projection; this base only builds ``self.norm``
    and ``self.hidden`` and records the band geometry.
    """

    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ):
        super().__init__()
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}
        self.hidden_activation_kwargs = hidden_activation_kwargs
        self.norm = nn.LayerNorm(emb_dim)
        self.hidden = nn.Sequential(
            nn.Linear(in_features=emb_dim, out_features=mlp_dim),
            # Activation class is looked up by name from torch.nn.modules.activation.
            activation.__dict__[hidden_activation](**self.hidden_activation_kwargs),
        )

        self.bandwidth = bandwidth
        self.in_channels = in_channels

        self.complex_mask = complex_mask
        # Two output values per bin (real/imag) for complex masks, one otherwise.
        self.reim = 2 if complex_mask else 1
        # GLU halves the feature dimension, so outputs are over-provisioned 2x.
        self.glu_mult = 2
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class NormMLP(BaseNormMLP):
    """Per-band mask MLP: norm -> hidden -> GLU-gated linear output.

    Maps a band embedding ``(batch, n_time, emb_dim)`` to a (possibly
    complex) mask of shape ``(batch, in_channels, bandwidth, n_time)``.
    """

    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__(
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            bandwidth=bandwidth,
            in_channels=in_channels,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

        self.output = nn.Sequential(
            nn.Linear(
                in_features=mlp_dim,
                # Factor of 2 feeds the GLU gate, which halves it again.
                out_features=bandwidth * in_channels * self.reim * 2,
            ),
            nn.GLU(dim=-1),
        )

        # torch.compile is wrapped but disabled; fall back to the plain
        # Sequential if compile is unavailable on this torch build.
        try:
            self.combined = torch.compile(
                nn.Sequential(self.norm, self.hidden, self.output), disable=True
            )
        except Exception as e:
            self.combined = nn.Sequential(self.norm, self.hidden, self.output)

    def reshape_output(self, mb):
        """Reshape the flat MLP output into a per-channel band mask.

        Input: ``(batch, n_time, in_channels * bandwidth * reim)``.
        Output: ``(batch, in_channels, bandwidth, n_time)``; complex-valued
        when ``self.complex_mask`` is set.
        """
        batch, n_time, _ = mb.shape
        if self.complex_mask:
            # Last axis of size 2 becomes the (real, imag) pair;
            # view_as_complex requires a contiguous tensor.
            mb = mb.reshape(
                batch, n_time, self.in_channels, self.bandwidth, self.reim
            ).contiguous()
            mb = torch.view_as_complex(mb)
        else:
            mb = mb.reshape(batch, n_time, self.in_channels, self.bandwidth)

        # (batch, n_time, chan, band) -> (batch, chan, band, n_time)
        mb = torch.permute(mb, (0, 2, 3, 1))

        return mb

    def forward(self, qb):
        # Checkpoint the 3-stage pipeline in 2 segments to save memory.
        mb = checkpoint_sequential(self.combined, 2, qb, use_reentrant=False)
        mb = self.reshape_output(mb)

        return mb
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class MaskEstimationModuleSuperBase(nn.Module):
    """Marker base class shared by all mask-estimation modules."""

    pass
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
    """Holds one NormMLP per band and computes per-band masks from embeddings."""

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Dict = None,
    ) -> None:
        super().__init__()

        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)

        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        if norm_mlp_kwargs is None:
            norm_mlp_kwargs = {}

        # One independent MLP per band; all share the same hyperparameters
        # apart from the band's own bandwidth.
        self.norm_mlp = nn.ModuleList(
            [
                norm_mlp_cls(
                    bandwidth=self.band_widths[b],
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    in_channels=in_channels,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                    **norm_mlp_kwargs,
                )
                for b in range(self.n_bands)
            ]
        )

    def compute_masks(self, q):
        """Return a list of per-band masks for embedding q.

        q has shape (batch, n_bands, n_time, emb_dim); each entry of the
        returned list is the corresponding NormMLP's output for one band.
        """
        batch, n_bands, n_time, emb_dim = q.shape

        masks = []

        for b, nmlp in enumerate(self.norm_mlp):
            qb = q[:, b, :, :]
            mb = nmlp(qb)
            masks.append(mb)

        return masks

    def compute_mask(self, q, b):
        """Compute the mask for a single band index *b*."""
        batch, n_bands, n_time, emb_dim = q.shape
        qb = q[:, b, :, :]
        mb = self.norm_mlp[b](qb)
        return mb
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
    """Mask estimation for possibly-overlapping band layouts.

    Per-band masks are summed into a full ``(batch, in_channels, n_freq,
    n_time)`` complex mask; overlap regions accumulate contributions from
    every band covering them, optionally scaled by per-band frequency
    weights.
    """

    def __init__(
        self,
        in_channels: int,
        band_specs: List[Tuple[float, float]],
        freq_weights: List[torch.Tensor],
        n_freq: int,
        emb_dim: int,
        mlp_dim: int,
        cond_dim: int = 0,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Dict = None,
        use_freq_weights: bool = False,
    ) -> None:
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)

        # Conditioning is not implemented; cond_dim exists only in the
        # signature (and widens emb_dim below when zero, a no-op).
        if cond_dim > 0:
            raise NotImplementedError

        super().__init__(
            band_specs=band_specs,
            emb_dim=emb_dim + cond_dim,
            mlp_dim=mlp_dim,
            in_channels=in_channels,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            norm_mlp_cls=norm_mlp_cls,
            norm_mlp_kwargs=norm_mlp_kwargs,
        )

        self.n_freq = n_freq
        self.band_specs = band_specs
        self.in_channels = in_channels

        if freq_weights is not None and use_freq_weights:
            # Register each band's weight vector as a buffer so it moves with
            # the module across devices/dtypes.
            for i, fw in enumerate(freq_weights):
                self.register_buffer(f"freq_weights/{i}", fw)

            self.use_freq_weights = use_freq_weights
        else:
            self.use_freq_weights = False

    def forward(self, q):
        """Accumulate per-band masks into one full-spectrum complex mask."""

        batch, n_bands, n_time, emb_dim = q.shape

        masks = torch.zeros(
            (batch, self.in_channels, self.n_freq, n_time),
            device=q.device,
            dtype=torch.complex64,
        )

        for im in range(n_bands):
            fstart, fend = self.band_specs[im]

            mask = self.compute_mask(q, im)

            if self.use_freq_weights:
                # fw: (bandwidth,) -> (bandwidth, 1) to broadcast over time.
                fw = self.get_buffer(f"freq_weights/{im}")[:, None]
                mask = mask * fw
            # += so overlapping bands blend rather than overwrite.
            masks[:, :, fstart:fend, :] += mask

        return masks
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class MaskEstimationModule(OverlappingMaskEstimationModule):
    """Mask estimation for strictly non-overlapping, gap-free band layouts.

    Because bands tile the spectrum exactly, per-band masks can simply be
    concatenated along the frequency axis instead of accumulated.
    """

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict = None,
        complex_mask: bool = True,
        **kwargs,
    ) -> None:
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)
        # Stricter than the parent: overlap is forbidden so concat is valid.
        check_no_overlap(band_specs)
        super().__init__(
            in_channels=in_channels,
            band_specs=band_specs,
            # No frequency weighting / bin count needed on the concat path.
            freq_weights=None,
            n_freq=None,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

    def forward(self, q, cond=None):
        """Return the full mask by concatenating per-band masks over frequency."""

        masks = self.compute_masks(q)

        # dim=2 is the frequency axis of (batch, in_channels, bandwidth, n_time).
        masks = torch.concat(masks, dim=2)

        return masks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mvsepless/models/bandit_v2/tfmodel.py
CHANGED
|
@@ -1,145 +1,141 @@
|
|
| 1 |
-
import warnings
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
import torch.backends.cuda
|
| 5 |
-
from torch import nn
|
| 6 |
-
from torch.nn.modules import rnn
|
| 7 |
-
from torch.utils.checkpoint import checkpoint_sequential
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class TimeFrequencyModellingModule(nn.Module):
|
| 11 |
-
def __init__(self) -> None:
|
| 12 |
-
super().__init__()
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
class ResidualRNN(nn.Module):
|
| 16 |
-
def __init__(
|
| 17 |
-
self,
|
| 18 |
-
emb_dim: int,
|
| 19 |
-
rnn_dim: int,
|
| 20 |
-
bidirectional: bool = True,
|
| 21 |
-
rnn_type: str = "LSTM",
|
| 22 |
-
use_batch_trick: bool = True,
|
| 23 |
-
use_layer_norm: bool = True,
|
| 24 |
-
) -> None:
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
assert
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
self.
|
| 32 |
-
self.
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
self.use_batch_trick
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
z = torch.reshape(z, (batch
|
| 57 |
-
|
| 58 |
-
z =
|
| 59 |
-
|
| 60 |
-
z =
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
)
|
| 143 |
-
|
| 144 |
-
q = z
|
| 145 |
-
return q # (batch, n_bands, n_time, emb_dim)
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.backends.cuda
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn.modules import rnn
|
| 7 |
+
from torch.utils.checkpoint import checkpoint_sequential
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TimeFrequencyModellingModule(nn.Module):
    """Base class for modules that model the time/frequency band grid."""

    def __init__(self) -> None:
        super().__init__()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ResidualRNN(nn.Module):
    """LayerNorm -> RNN -> Linear with a residual connection.

    Processes the third axis of a 4-D input as the RNN sequence, folding the
    second axis into the batch ("batch trick") so one RNN call covers all
    uncrossed positions.
    """

    def __init__(
        self,
        emb_dim: int,
        rnn_dim: int,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        use_batch_trick: bool = True,
        use_layer_norm: bool = True,
    ) -> None:
        super().__init__()

        # Only the norm + batch-trick configuration is supported.
        assert use_layer_norm
        assert use_batch_trick

        self.use_layer_norm = use_layer_norm
        self.norm = nn.LayerNorm(emb_dim)
        # RNN class (LSTM/GRU/...) is looked up by name from torch.nn.modules.rnn.
        self.rnn = rnn.__dict__[rnn_type](
            input_size=emb_dim,
            hidden_size=rnn_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=bidirectional,
        )

        self.fc = nn.Linear(
            in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim
        )

        self.use_batch_trick = use_batch_trick
        if not self.use_batch_trick:
            warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")

    def forward(self, z):
        """Apply the residual RNN.

        z: (batch, n_uncrossed, n_across, emb_dim); the RNN runs along
        n_across. Output has the same shape.
        """

        z0 = torch.clone(z)
        z = self.norm(z)

        batch, n_uncrossed, n_across, emb_dim = z.shape
        # Fold the uncrossed axis into the batch so a single RNN call
        # processes every sequence.
        z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
        z = self.rnn(z)[0]
        z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))

        z = self.fc(z)

        # Residual connection around norm + RNN + projection.
        z = z + z0

        return z
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class Transpose(nn.Module):
    """Module wrapper around ``Tensor.transpose`` for use in ``nn.Sequential``."""

    def __init__(self, dim0: int, dim1: int) -> None:
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, z):
        return z.transpose(self.dim0, self.dim1)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class SeqBandModellingModule(TimeFrequencyModellingModule):
    """Stack of ResidualRNNs alternating between the time and band axes.

    In sequential mode (default), each of ``n_modules`` stages applies one
    RNN along one axis then transposes, so consecutive stages alternate
    time/band modelling; the stack is gradient-checkpointed. In parallel
    mode, each stage runs a time-RNN and a band-RNN on the same input and
    sums the results.
    """

    def __init__(
        self,
        n_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        parallel_mode=False,
    ) -> None:
        super().__init__()

        self.n_modules = n_modules

        if parallel_mode:
            # One (time-RNN, band-RNN) pair per stage.
            self.seqband = nn.ModuleList([])
            for _ in range(n_modules):
                self.seqband.append(
                    nn.ModuleList(
                        [
                            ResidualRNN(
                                emb_dim=emb_dim,
                                rnn_dim=rnn_dim,
                                bidirectional=bidirectional,
                                rnn_type=rnn_type,
                            ),
                            ResidualRNN(
                                emb_dim=emb_dim,
                                rnn_dim=rnn_dim,
                                bidirectional=bidirectional,
                                rnn_type=rnn_type,
                            ),
                        ]
                    )
                )
        else:
            # 2*n_modules (RNN, Transpose) pairs: the transpose swaps the
            # band and time axes so successive RNNs alternate direction.
            seqband = []
            for _ in range(2 * n_modules):
                seqband += [
                    ResidualRNN(
                        emb_dim=emb_dim,
                        rnn_dim=rnn_dim,
                        bidirectional=bidirectional,
                        rnn_type=rnn_type,
                    ),
                    Transpose(1, 2),
                ]

            self.seqband = nn.Sequential(*seqband)

        self.parallel_mode = parallel_mode

    def forward(self, z):
        """z: (batch, n_bands, n_time, emb_dim) -> same shape.

        An even number of Transpose(1, 2) applications in sequential mode
        returns the axes to their original order.
        """

        if self.parallel_mode:
            for sbm_pair in self.seqband:
                sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
                zt = sbm_t(z)
                zf = sbm_f(z.transpose(1, 2))
                z = zt + zf.transpose(1, 2)
        else:
            # Checkpoint in n_modules segments (4 layers per segment).
            z = checkpoint_sequential(
                self.seqband, self.n_modules, z, use_reentrant=False
            )

        q = z
        return q
|
|
|
|
|
|
|
|
|
|
|
|
mvsepless/models/bandit_v2/utils.py
CHANGED
|
@@ -1,523 +1,384 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from abc import abstractmethod
|
| 3 |
-
from typing import Callable
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import torch
|
| 7 |
-
from librosa import hz_to_midi, midi_to_hz
|
| 8 |
-
from torchaudio import functional as taF
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
self.
|
| 47 |
-
self.
|
| 48 |
-
self.
|
| 49 |
-
|
| 50 |
-
self.
|
| 51 |
-
self.
|
| 52 |
-
self.
|
| 53 |
-
|
| 54 |
-
self.
|
| 55 |
-
self.
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
def
|
| 97 |
-
return
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
below20k
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
below16k
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
below16k
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
)
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
)
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
f_max
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
self.
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
self.
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
os.
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
fb =
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
# return torch.as_tensor(fb)
|
| 387 |
-
|
| 388 |
-
# class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
|
| 389 |
-
# def __init__(
|
| 390 |
-
# self,
|
| 391 |
-
# nfft: int,
|
| 392 |
-
# fs: int,
|
| 393 |
-
# n_bands: int,
|
| 394 |
-
# f_min: float = 0.0,
|
| 395 |
-
# f_max: float = None
|
| 396 |
-
# ) -> None:
|
| 397 |
-
# super().__init__(fbank_fn=bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
# def triangular_bark_filterbank(
|
| 401 |
-
# n_bands, fs, f_min, f_max, n_freqs
|
| 402 |
-
# ):
|
| 403 |
-
|
| 404 |
-
# all_freqs = torch.linspace(0, fs // 2, n_freqs)
|
| 405 |
-
|
| 406 |
-
# # calculate mel freq bins
|
| 407 |
-
# m_min = hz2bark(f_min)
|
| 408 |
-
# m_max = hz2bark(f_max)
|
| 409 |
-
|
| 410 |
-
# m_pts = torch.linspace(m_min, m_max, n_bands + 2)
|
| 411 |
-
# f_pts = 600 * torch.sinh(m_pts / 6)
|
| 412 |
-
|
| 413 |
-
# # create filterbank
|
| 414 |
-
# fb = _create_triangular_filterbank(all_freqs, f_pts)
|
| 415 |
-
|
| 416 |
-
# fb = fb.T
|
| 417 |
-
|
| 418 |
-
# first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
|
| 419 |
-
# first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
|
| 420 |
-
|
| 421 |
-
# fb[first_active_band, :first_active_bin] = 1.0
|
| 422 |
-
|
| 423 |
-
# return fb
|
| 424 |
-
|
| 425 |
-
# class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
|
| 426 |
-
# def __init__(
|
| 427 |
-
# self,
|
| 428 |
-
# nfft: int,
|
| 429 |
-
# fs: int,
|
| 430 |
-
# n_bands: int,
|
| 431 |
-
# f_min: float = 0.0,
|
| 432 |
-
# f_max: float = None
|
| 433 |
-
# ) -> None:
|
| 434 |
-
# super().__init__(fbank_fn=triangular_bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
# def minibark_filterbank(
|
| 438 |
-
# n_bands, fs, f_min, f_max, n_freqs
|
| 439 |
-
# ):
|
| 440 |
-
# fb = bark_filterbank(
|
| 441 |
-
# n_bands,
|
| 442 |
-
# fs,
|
| 443 |
-
# f_min,
|
| 444 |
-
# f_max,
|
| 445 |
-
# n_freqs
|
| 446 |
-
# )
|
| 447 |
-
|
| 448 |
-
# fb[fb < np.sqrt(0.5)] = 0.0
|
| 449 |
-
|
| 450 |
-
# return fb
|
| 451 |
-
|
| 452 |
-
# class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
|
| 453 |
-
# def __init__(
|
| 454 |
-
# self,
|
| 455 |
-
# nfft: int,
|
| 456 |
-
# fs: int,
|
| 457 |
-
# n_bands: int,
|
| 458 |
-
# f_min: float = 0.0,
|
| 459 |
-
# f_max: float = None
|
| 460 |
-
# ) -> None:
|
| 461 |
-
# super().__init__(fbank_fn=minibark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
# def erb_filterbank(
|
| 465 |
-
# n_bands: int,
|
| 466 |
-
# fs: int,
|
| 467 |
-
# f_min: float,
|
| 468 |
-
# f_max: float,
|
| 469 |
-
# n_freqs: int,
|
| 470 |
-
# ) -> Tensor:
|
| 471 |
-
# # freq bins
|
| 472 |
-
# A = (1000 * np.log(10)) / (24.7 * 4.37)
|
| 473 |
-
# all_freqs = torch.linspace(0, fs // 2, n_freqs)
|
| 474 |
-
|
| 475 |
-
# # calculate mel freq bins
|
| 476 |
-
# m_min = hz2erb(f_min)
|
| 477 |
-
# m_max = hz2erb(f_max)
|
| 478 |
-
|
| 479 |
-
# m_pts = torch.linspace(m_min, m_max, n_bands + 2)
|
| 480 |
-
# f_pts = (torch.pow(10, (m_pts / A)) - 1)/ 0.00437
|
| 481 |
-
|
| 482 |
-
# # create filterbank
|
| 483 |
-
# fb = _create_triangular_filterbank(all_freqs, f_pts)
|
| 484 |
-
|
| 485 |
-
# fb = fb.T
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
# first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
|
| 489 |
-
# first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
|
| 490 |
-
|
| 491 |
-
# fb[first_active_band, :first_active_bin] = 1.0
|
| 492 |
-
|
| 493 |
-
# return fb
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
# class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
|
| 497 |
-
# def __init__(
|
| 498 |
-
# self,
|
| 499 |
-
# nfft: int,
|
| 500 |
-
# fs: int,
|
| 501 |
-
# n_bands: int,
|
| 502 |
-
# f_min: float = 0.0,
|
| 503 |
-
# f_max: float = None
|
| 504 |
-
# ) -> None:
|
| 505 |
-
# super().__init__(fbank_fn=erb_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
|
| 506 |
-
|
| 507 |
-
if __name__ == "__main__":
|
| 508 |
-
import pandas as pd
|
| 509 |
-
|
| 510 |
-
band_defs = []
|
| 511 |
-
|
| 512 |
-
for bands in [VocalBandsplitSpecification]:
|
| 513 |
-
band_name = bands.__name__.replace("BandsplitSpecification", "")
|
| 514 |
-
|
| 515 |
-
mbs = bands(nfft=2048, fs=44100).get_band_specs()
|
| 516 |
-
|
| 517 |
-
for i, (f_min, f_max) in enumerate(mbs):
|
| 518 |
-
band_defs.append(
|
| 519 |
-
{"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
|
| 520 |
-
)
|
| 521 |
-
|
| 522 |
-
df = pd.DataFrame(band_defs)
|
| 523 |
-
df.to_csv("vox7bands.csv", index=False)
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from abc import abstractmethod
|
| 3 |
+
from typing import Callable
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from librosa import hz_to_midi, midi_to_hz
|
| 8 |
+
from torchaudio import functional as taF
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def band_widths_from_specs(band_specs):
    """Return the width (end - start) of each ``(start, end)`` band spec."""
    widths = []
    for start, end in band_specs:
        widths.append(end - start)
    return widths
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def check_nonzero_bandwidth(band_specs):
    """Raise ``ValueError`` unless every ``(start, end)`` spec has positive width."""
    for start, end in band_specs:
        width = end - start
        if width <= 0:
            raise ValueError("Bands cannot be zero-width")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def check_no_overlap(band_specs):
    """Raise ``ValueError`` if any band starts at or before the previous band's end.

    Band specs are ``(start, end)`` index pairs assumed sorted by start,
    with inclusive endpoints (consistent with ``check_no_gap``, which treats
    ``start_curr == end_prev + 1`` as contiguous).
    """
    fend_prev = -1
    for fstart_curr, fend_curr in band_specs:
        if fstart_curr <= fend_prev:
            raise ValueError("Bands cannot overlap")
        # BUG FIX: ``fend_prev`` was never advanced inside the loop, so every
        # band after the first compared against -1 and overlaps were never
        # detected.
        fend_prev = fend_curr
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def check_no_gap(band_specs):
    """Raise ``ValueError`` if consecutive band specs leave uncovered indices.

    The first band must start at index 0; each later band must start no more
    than one index past the previous band's end.
    """
    first_start, _ = band_specs[0]
    assert first_start == 0

    prev_end = -1
    for cur_start, cur_end in band_specs:
        gap = cur_start - prev_end
        if gap > 1:
            raise ValueError("Bands cannot leave gap")
        prev_end = cur_end
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class BandsplitSpecification:
    """Base class describing how one-sided STFT bins are partitioned into bands.

    Precomputes the bin index of a set of common split frequencies and offers
    helpers to convert between Hz and FFT bin indices.
    """

    def __init__(self, nfft: int, fs: int) -> None:
        self.fs = fs
        self.nfft = nfft
        self.nyquist = fs / 2
        # Number of bins in a one-sided spectrum.
        self.max_index = nfft // 2 + 1

        # Cache bin indices for the frequently used split frequencies.
        for attr, hz in (
            ("split500", 500),
            ("split1k", 1000),
            ("split2k", 2000),
            ("split4k", 4000),
            ("split8k", 8000),
            ("split16k", 16000),
            ("split20k", 20000),
        ):
            setattr(self, attr, self.hertz_to_index(hz))

        # Catch-all bands covering everything above 20 kHz / 16 kHz.
        self.above20k = [(self.split20k, self.max_index)]
        self.above16k = [(self.split16k, self.split20k)] + self.above20k

    def index_to_hertz(self, index: int):
        """Convert an FFT bin index to its center frequency in Hz."""
        return index * self.fs / self.nfft

    def hertz_to_index(self, hz: float, round: bool = True):
        """Convert a frequency in Hz to an FFT bin index (rounded by default)."""
        index = hz * self.nfft / self.fs
        if round:
            index = int(np.round(index))
        return index

    def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
        """Tile ``[start_index, end_index)`` with bands about ``bandwidth_hz`` wide."""
        band_specs = []
        lower = start_index

        while lower < end_index:
            upper = min(
                int(np.floor(lower + self.hertz_to_index(bandwidth_hz))),
                end_index,
            )
            band_specs.append((lower, upper))
            lower = upper

        return band_specs

    @abstractmethod
    def get_band_specs(self):
        raise NotImplementedError
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class VocalBandsplitSpecification(BandsplitSpecification):
    """Hand-designed band splits for vocal separation.

    ``version`` selects one of several layouts ("1" .. "7") trading band
    count against resolution; higher versions use finer bands at low/mid
    frequencies.
    """

    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

        self.version = version

    def get_band_specs(self):
        # Dispatch to the matching ``versionN`` method by name.
        return getattr(self, f"version{self.version}")()

    # BUG FIX: ``version1`` was decorated with ``@property``, so the
    # ``getattr(...)()`` dispatch above would retrieve the computed list and
    # then try to call it, raising TypeError for version "1". It is now a
    # plain method like the other versions.
    def version1(self):
        """Uniform 1 kHz bands over the whole spectrum."""
        return self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.max_index, bandwidth_hz=1000
        )

    def version2(self):
        """1 kHz bands to 16 kHz, 2 kHz bands to 20 kHz, one band above."""
        below16k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )

        return below16k + below20k + self.above20k

    def version3(self):
        """1 kHz bands to 8 kHz, 2 kHz bands to 16 kHz, coarse above."""
        below8k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )

        return below8k + below16k + self.above16k

    def version4(self):
        """100 Hz bands to 1 kHz, then 1/2 kHz bands, coarse above 16 kHz."""
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )

        return below1k + below8k + below16k + self.above16k

    def version5(self):
        """100 Hz bands to 1 kHz, 1 kHz bands to 16 kHz, 2 kHz to 20 kHz."""
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )
        return below1k + below16k + below20k + self.above20k

    def version6(self):
        """100/500/1000/2000 Hz bands in increasing ranges, coarse above 16 kHz."""
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + self.above16k

    def version7(self):
        """Finest layout: 100/250/500/1000/2000 Hz bands in increasing ranges."""
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + below20k + self.above20k
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class OtherBandsplitSpecification(VocalBandsplitSpecification):
    """Band split for the "other" stem: reuses the vocal version-7 layout."""

    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs, version="7")
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class BassBandsplitSpecification(BandsplitSpecification):
    """Hand-designed band split for bass separation.

    ``version`` is accepted for signature parity with the other
    specifications but is not used.
    """

    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        """Return (start, end) bin pairs: fine bands low, coarser upward."""
        # (start_index, end_index, bandwidth_hz) segments, low to high.
        segments = (
            (0, self.split500, 50),
            (self.split500, self.split1k, 100),
            (self.split1k, self.split4k, 500),
            (self.split4k, self.split8k, 1000),
            (self.split8k, self.split16k, 2000),
        )

        specs = []
        for start, end, bandwidth in segments:
            specs += self.get_band_specs_with_bandwidth(
                start_index=start, end_index=end, bandwidth_hz=bandwidth
            )

        # Single catch-all band above 16 kHz.
        specs += [(self.split16k, self.max_index)]
        return specs
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class DrumBandsplitSpecification(BandsplitSpecification):
    """Hand-designed band split for drum separation."""

    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        """Return (start, end) bin pairs: very fine bands low, coarser upward."""
        # (start_index, end_index, bandwidth_hz) segments, low to high.
        segments = (
            (0, self.split1k, 50),
            (self.split1k, self.split2k, 100),
            (self.split2k, self.split4k, 250),
            (self.split4k, self.split8k, 500),
            (self.split8k, self.split16k, 1000),
        )

        specs = []
        for start, end, bandwidth in segments:
            specs += self.get_band_specs_with_bandwidth(
                start_index=start, end_index=end, bandwidth_hz=bandwidth
            )

        # Single catch-all band above 16 kHz.
        specs += [(self.split16k, self.max_index)]
        return specs
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class PerceptualBandsplitSpecification(BandsplitSpecification):
    """Band split derived from a perceptual filterbank (e.g. mel, musical).

    ``fbank_fn(n_bands, fs, f_min, f_max, n_freqs)`` must return an
    ``(n_bands, n_freqs)`` filterbank matrix; each band's spec is the range
    between the first and last non-zero bin of the corresponding filter.
    """

    def __init__(
        self,
        nfft: int,
        fs: int,
        fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
        n_bands: int,
        f_min: float = 0.0,
        f_max: float = None,
    ) -> None:
        super().__init__(nfft=nfft, fs=fs)
        self.n_bands = n_bands
        if f_max is None:
            f_max = fs / 2

        self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)

        # Normalize so the filter weights over each frequency bin sum to one.
        weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True)
        normalized_mel_fb = self.filterbank / weight_per_bin

        freq_weights = []
        band_specs = []
        for i in range(self.n_bands):
            active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
            # squeeze() collapses a single-bin result to a scalar; re-wrap it
            # so the indexing below still works.
            if isinstance(active_bins, int):
                active_bins = (active_bins, active_bins)
            if len(active_bins) == 0:
                continue
            # NOTE(review): only the first/last non-zero bins are used, so a
            # filter with non-contiguous support would be treated as spanning
            # the whole range -- assumed contiguous, confirm for new fbank_fn.
            start_index = active_bins[0]
            end_index = active_bins[-1] + 1
            band_specs.append((start_index, end_index))
            freq_weights.append(normalized_mel_fb[i, start_index:end_index])

        self.freq_weights = freq_weights
        self.band_specs = band_specs

    def get_band_specs(self):
        """Return the list of (start, end) bin index pairs, one per band."""
        return self.band_specs

    def get_freq_weights(self):
        """Return the per-band normalized filter weights over each band's bins."""
        return self.freq_weights

    def save_to_file(self, dir_path: str) -> None:
        """Pickle band specs, weights and the raw filterbank under ``dir_path``."""
        os.makedirs(dir_path, exist_ok=True)

        import pickle

        with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
            pickle.dump(
                {
                    "band_specs": self.band_specs,
                    "freq_weights": self.freq_weights,
                    "filterbank": self.filterbank,
                },
                f,
            )
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    """Build an ``(n_bands, n_freqs)`` mel-scale filterbank.

    The DC entry of the first filter is forced to 1.0 so that bin 0 is
    always covered by a band.
    """
    filterbank = taF.melscale_fbanks(
        n_mels=n_bands,
        sample_rate=fs,
        f_min=f_min,
        f_max=f_max,
        n_freqs=n_freqs,
    )
    filterbank = filterbank.T

    filterbank[0, 0] = 1.0

    return filterbank
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
class MelBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split using a mel-scale filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        super().__init__(
            nfft=nfft,
            fs=fs,
            fbank_fn=mel_filterbank,
            n_bands=n_bands,
            f_min=f_min,
            f_max=f_max,
        )
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
    """Build an ``(n_bands, n_freqs)`` rectangular filterbank on a musical scale.

    Band centers are spaced uniformly in MIDI (log-frequency); each band is a
    flat 0/1 window one "band-octave" wide on either side of its center.
    ``scale`` is currently unused.
    """
    nfft = 2 * (n_freqs - 1)
    df = fs / nfft  # frequency resolution per bin
    f_max = f_max or fs / 2
    f_min = f_min or 0
    # NOTE(review): this unconditionally overrides any caller-supplied f_min
    # with the first bin's frequency -- confirm this is intentional.
    f_min = fs / nfft

    n_octaves = np.log2(f_max / f_min)
    n_octaves_per_band = n_octaves / n_bands
    bandwidth_mult = np.power(2.0, n_octaves_per_band)

    low_midi = max(0, hz_to_midi(f_min))
    high_midi = hz_to_midi(f_max)
    midi_points = np.linspace(low_midi, high_midi, n_bands)
    hz_pts = midi_to_hz(midi_points)

    # Each band spans [center / mult, center * mult] in Hz.
    low_pts = hz_pts / bandwidth_mult
    high_pts = hz_pts * bandwidth_mult

    low_bins = np.floor(low_pts / df).astype(int)
    high_bins = np.ceil(high_pts / df).astype(int)

    fb = np.zeros((n_bands, n_freqs))

    for i in range(n_bands):
        fb[i, low_bins[i] : high_bins[i] + 1] = 1.0

    # Extend the edge bands so every bin down to 0 and up to Nyquist is covered.
    fb[0, : low_bins[0]] = 1.0
    fb[-1, high_bins[-1] + 1 :] = 1.0

    return torch.as_tensor(fb)
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split using the log-frequency musical filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        super().__init__(
            nfft=nfft,
            fs=fs,
            fbank_fn=musical_filterbank,
            n_bands=n_bands,
            f_min=f_min,
            f_max=f_max,
        )
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
if __name__ == "__main__":
    # Dump the vocal band layout to CSV for manual inspection.
    import pandas as pd

    band_defs = []

    for bands in [VocalBandsplitSpecification]:
        band_name = bands.__name__.replace("BandsplitSpecification", "")

        for i, (f_min, f_max) in enumerate(bands(nfft=2048, fs=44100).get_band_specs()):
            band_defs.append(
                {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
            )

    pd.DataFrame(band_defs).to_csv("vox7bands.csv", index=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mvsepless/models/bs_roformer/__init__.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
-
from .bs_roformer import BSRoformer
|
| 2 |
-
from .bs_roformer_sw import BSRoformer_SW
|
| 3 |
-
from .bs_roformer_fno import BSRoformer_FNO
|
| 4 |
-
from .bs_roformer_hyperace import BSRoformerHyperACE
|
| 5 |
-
from .bs_roformer_hyperace2 import BSRoformerHyperACE_2
|
| 6 |
-
from .mel_band_roformer import MelBandRoformer
|
|
|
|
| 1 |
+
from .bs_roformer import BSRoformer
|
| 2 |
+
from .bs_roformer_sw import BSRoformer_SW
|
| 3 |
+
from .bs_roformer_fno import BSRoformer_FNO
|
| 4 |
+
from .bs_roformer_hyperace import BSRoformerHyperACE
|
| 5 |
+
from .bs_roformer_hyperace2 import BSRoformerHyperACE_2
|
| 6 |
+
from .mel_band_roformer import MelBandRoformer
|
mvsepless/models/bs_roformer/attend.py
CHANGED
|
@@ -1,126 +1,120 @@
|
|
| 1 |
-
from functools import wraps
|
| 2 |
-
from packaging import version
|
| 3 |
-
from collections import namedtuple
|
| 4 |
-
|
| 5 |
-
import os
|
| 6 |
-
import torch
|
| 7 |
-
from torch import nn, einsum
|
| 8 |
-
import torch.nn.functional as F
|
| 9 |
-
|
| 10 |
-
from einops import rearrange, reduce
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
FlashAttentionConfig
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def exists(val):
|
| 19 |
-
return val is not None
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
self.
|
| 48 |
-
self.
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
self.
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
self.cuda_config = FlashAttentionConfig(
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
)
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
|
| 107 |
-
|
| 108 |
-
scale = default(self.scale, q.shape[-1] ** -0.5)
|
| 109 |
-
|
| 110 |
-
if self.flash:
|
| 111 |
-
return self.flash_attn(q, k, v)
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
# aggregate values
|
| 123 |
-
|
| 124 |
-
out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
|
| 125 |
-
|
| 126 |
-
return out
|
|
|
|
| 1 |
+
from functools import wraps
|
| 2 |
+
from packaging import version
|
| 3 |
+
from collections import namedtuple
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn, einsum
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from einops import rearrange, reduce
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Backend-selection triple forwarded to torch's SDP kernel context manager:
# which attention implementations (flash / math / memory-efficient) may run.
FlashAttentionConfig = namedtuple(
    "FlashAttentionConfig",
    ["enable_flash", "enable_math", "enable_mem_efficient"],
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def exists(val):
    """Return True when ``val`` is not None."""
    if val is None:
        return False
    return True
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def default(v, d):
    """Return ``v`` unless it is None, in which case return the fallback ``d``."""
    if v is None:
        return d
    return v
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def once(fn):
    """Wrap ``fn`` so it executes at most once; later calls return None."""
    state = {"called": False}

    @wraps(fn)
    def inner(x):
        if state["called"]:
            return None
        state["called"] = True
        return fn(x)

    return inner
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Print a given message at most once per process (used for one-time backend notices).
print_once = once(print)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class Attend(nn.Module):
    """Scaled dot-product attention with an optional PyTorch 2.0 flash path.

    With ``flash=True`` and CUDA available, an SDP-kernel config is chosen
    once at construction based on the GPU's compute capability; otherwise a
    plain einsum/softmax implementation is used.
    """

    def __init__(self, dropout=0.0, flash=False, scale=None):
        super().__init__()
        self.scale = scale  # optional override of the default 1/sqrt(d) scale
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (
            flash and version.parse(torch.__version__) < version.parse("2.0.0")
        ), "in order to use flash attention, you must be using pytorch 2.0 or above"

        # CPU path: allow every backend and let torch choose.
        self.cpu_config = FlashAttentionConfig(True, True, True)
        self.cuda_config = None

        # Nothing more to configure without CUDA or when flash is disabled.
        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
        device_version = version.parse(
            f"{device_properties.major}.{device_properties.minor}"
        )

        if device_version >= version.parse("8.0"):
            if os.name == "nt":
                # Flash kernels are avoided on Windows; fall back to
                # math / memory-efficient backends.
                print_once(
                    "Windows OS detected, using math or mem efficient attention if input tensor is on cuda"
                )
                self.cuda_config = FlashAttentionConfig(False, True, True)
            else:
                print_once(
                    "GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda"
                )
                self.cuda_config = FlashAttentionConfig(True, False, False)
        else:
            print_once(
                "GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda"
            )
            self.cuda_config = FlashAttentionConfig(False, True, True)

    def flash_attn(self, q, k, v):
        """Run attention through ``F.scaled_dot_product_attention``.

        q, k, v are ``(batch, heads, seq, dim)`` tensors (4-D, per the
        unpacking below).
        """
        _, heads, q_len, _, k_len, is_cuda, device = (
            *q.shape,
            k.shape[-2],
            q.is_cuda,
            q.device,
        )

        # SDPA applies the default 1/sqrt(d) scale internally, so fold any
        # custom scale into q as a ratio against that default.
        if exists(self.scale):
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)

        config = self.cuda_config if is_cuda else self.cpu_config

        # NOTE(review): torch.backends.cuda.sdp_kernel is deprecated in recent
        # torch releases in favor of torch.nn.attention.sdpa_kernel -- verify
        # against the targeted torch version.
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v, dropout_p=self.dropout if self.training else 0.0
            )

        return out

    def forward(self, q, k, v):
        """Compute attention; dispatches to the flash path when enabled."""

        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        scale = default(self.scale, q.shape[-1] ** -0.5)

        if self.flash:
            return self.flash_attn(q, k, v)

        # similarity
        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale

        # attention weights with dropout
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)

        return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|