noblebarkrr committed
Commit 6cc8dc1 · verified · 1 parent: 099fa7a

Removed comments and reformatted the code

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. MVSepLess_Epsilon_Colab.ipynb +16 -6
  2. mvsepless/__init__.py +0 -0
  3. mvsepless/__main__.py +61 -14
  4. mvsepless/audio.py +789 -781
  5. mvsepless/downloader.py +90 -92
  6. mvsepless/ensemble.py +206 -224
  7. mvsepless/infer.py +116 -65
  8. mvsepless/infer_utils.py +41 -69
  9. mvsepless/model_manager.py +682 -609
  10. mvsepless/models.json +0 -0
  11. mvsepless/models/bandit/core/__init__.py +669 -691
  12. mvsepless/models/bandit/core/data/__init__.py +2 -2
  13. mvsepless/models/bandit/core/data/_types.py +17 -17
  14. mvsepless/models/bandit/core/data/augmentation.py +102 -102
  15. mvsepless/models/bandit/core/data/augmented.py +34 -34
  16. mvsepless/models/bandit/core/data/base.py +60 -60
  17. mvsepless/models/bandit/core/data/dnr/datamodule.py +64 -68
  18. mvsepless/models/bandit/core/data/dnr/dataset.py +360 -366
  19. mvsepless/models/bandit/core/data/dnr/preprocess.py +51 -51
  20. mvsepless/models/bandit/core/data/musdb/datamodule.py +75 -75
  21. mvsepless/models/bandit/core/data/musdb/dataset.py +241 -273
  22. mvsepless/models/bandit/core/data/musdb/preprocess.py +223 -226
  23. mvsepless/models/bandit/core/data/musdb/validation.yaml +14 -14
  24. mvsepless/models/bandit/core/loss/__init__.py +8 -8
  25. mvsepless/models/bandit/core/loss/_complex.py +27 -27
  26. mvsepless/models/bandit/core/loss/_multistem.py +43 -43
  27. mvsepless/models/bandit/core/loss/_timefreq.py +94 -95
  28. mvsepless/models/bandit/core/loss/snr.py +131 -139
  29. mvsepless/models/bandit/core/metrics/__init__.py +7 -9
  30. mvsepless/models/bandit/core/metrics/_squim.py +350 -443
  31. mvsepless/models/bandit/core/metrics/snr.py +124 -127
  32. mvsepless/models/bandit/core/model/__init__.py +3 -3
  33. mvsepless/models/bandit/core/model/_spectral.py +54 -54
  34. mvsepless/models/bandit/core/model/bsrnn/__init__.py +23 -23
  35. mvsepless/models/bandit/core/model/bsrnn/bandsplit.py +119 -135
  36. mvsepless/models/bandit/core/model/bsrnn/core.py +619 -651
  37. mvsepless/models/bandit/core/model/bsrnn/maskestim.py +327 -351
  38. mvsepless/models/bandit/core/model/bsrnn/tfmodel.py +287 -320
  39. mvsepless/models/bandit/core/model/bsrnn/utils.py +518 -525
  40. mvsepless/models/bandit/core/model/bsrnn/wrapper.py +828 -829
  41. mvsepless/models/bandit/core/utils/audio.py +324 -412
  42. mvsepless/models/bandit/model_from_config.py +26 -26
  43. mvsepless/models/bandit_v2/bandit.py +360 -363
  44. mvsepless/models/bandit_v2/bandsplit.py +127 -130
  45. mvsepless/models/bandit_v2/film.py +23 -23
  46. mvsepless/models/bandit_v2/maskestim.py +269 -281
  47. mvsepless/models/bandit_v2/tfmodel.py +141 -145
  48. mvsepless/models/bandit_v2/utils.py +384 -523
  49. mvsepless/models/bs_roformer/__init__.py +6 -6
  50. mvsepless/models/bs_roformer/attend.py +120 -126
MVSepLess_Epsilon_Colab.ipynb CHANGED
@@ -51,7 +51,7 @@
51
  "prodigyopt\n",
52
  "torch_log_wmse\n",
53
  "rotary_embedding_torch\n",
54
- "gradio\n",
55
  "omegaconf\n",
56
  "beartype\n",
57
  "spafe\n",
@@ -395,6 +395,18 @@
395
  "!python mvsepless/model_manager.py vbach list"
396
  ]
397
  },
398
  {
399
  "cell_type": "markdown",
400
  "metadata": {
@@ -461,9 +473,7 @@
461
  },
462
  {
463
  "cell_type": "markdown",
464
- "metadata": {
465
- "id": "4GSGHM4rYSVp"
466
- },
467
  "source": [
468
  "## Инференс"
469
  ]
@@ -485,7 +495,7 @@
485
  "# @markdown ---\n",
486
  "# @markdown ### Hubert\n",
487
  "# @markdown * Стэк\n",
488
- "stack = \"transformers\" # @param [\"fairseq\",\"transformers\"]\n",
489
  "# @markdown * Имя модели для fairseq\n",
490
  "fairseq_embedder = \"hubert_base\" # @param [\"hubert_base\",\"contentvec_base\",\"korean_hubert_base\",\"chinese_hubert_base\",\"portuguese_hubert_base\",\"japanese_hubert_base\"]\n",
491
  "# @markdown * Имя модели для transformers\n",
@@ -497,7 +507,7 @@
497
  "# @markdown * Стерео режим\n",
498
  "stereo_mode = \"mono\" # @param [\"mono\",\"left/right\",\"sim/dif\"]\n",
499
  "# @markdown * Метод определения тона\n",
500
- "method_pitch = \"rmvpe+\" # @param [\"rmvpe+\",\"mangio-crepe\",\"fcpe\"]\n",
501
  "# @markdown * Изменение высоты тона (полутона)\n",
502
  "pitch = 0 # @param {\"type\":\"slider\",\"min\":-48,\"max\":48,\"step\":1}\n",
503
  "# @markdown * Длина шага (для mangio-crepe)\n",
 
51
  "prodigyopt\n",
52
  "torch_log_wmse\n",
53
  "rotary_embedding_torch\n",
54
+ "gradio<=6.0\n",
55
  "omegaconf\n",
56
  "beartype\n",
57
  "spafe\n",
 
395
  "!python mvsepless/model_manager.py vbach list"
396
  ]
397
  },
398
+ {
399
+ "cell_type": "code",
400
+ "execution_count": null,
401
+ "metadata": {},
402
+ "outputs": [],
403
+ "source": [
404
+ "#@title Удаление голосовой модели\n",
405
+ "%cd $mvsepless_dir\n",
406
+ "voicemodel_name = \"\" # @param {\"type\":\"string\",\"placeholder\":\"Имя модели\"}\n",
407
+ "!python mvsepless/model_manager.py vbach remove --model_name \"$voicemodel_name\""
408
+ ]
409
+ },
410
  {
411
  "cell_type": "markdown",
412
  "metadata": {
 
473
  },
474
  {
475
  "cell_type": "markdown",
476
+ "metadata": {},
477
  "source": [
478
  "## Инференс"
479
  ]
 
495
  "# @markdown ---\n",
496
  "# @markdown ### Hubert\n",
497
  "# @markdown * Стэк\n",
498
+ "stack = \"fairseq\" # @param [\"fairseq\",\"transformers\"]\n",
499
  "# @markdown * Имя модели для fairseq\n",
500
  "fairseq_embedder = \"hubert_base\" # @param [\"hubert_base\",\"contentvec_base\",\"korean_hubert_base\",\"chinese_hubert_base\",\"portuguese_hubert_base\",\"japanese_hubert_base\"]\n",
501
  "# @markdown * Имя модели для transformers\n",
 
507
  "# @markdown * Стерео режим\n",
508
  "stereo_mode = \"mono\" # @param [\"mono\",\"left/right\",\"sim/dif\"]\n",
509
  "# @markdown * Метод определения тона\n",
510
+ "method_pitch = \"rmvpe+\" # @param [\"rmvpe+\",\"mangio-crepe\",\"mangio-crepe-tiny\",\"fcpe\",\"harvest\",\"pm\",\"pyin\"]\n",
511
  "# @markdown * Изменение высоты тона (полутона)\n",
512
  "pitch = 0 # @param {\"type\":\"slider\",\"min\":-48,\"max\":48,\"step\":1}\n",
513
  "# @markdown * Длина шага (для mangio-crepe)\n",
mvsepless/__init__.py CHANGED
The diff for this file is too large to render.
 
mvsepless/__main__.py CHANGED
@@ -11,18 +11,59 @@ if __name__ == "__main__":
11
  parser = argparse.ArgumentParser(description="MVSepless")
12
  subparsers = parser.add_subparsers(dest="command")
13
  app_parser = subparsers.add_parser("app", help="Приложение MVSepless")
14
- app_parser.add_argument("--port", type=int, default=None, help="Порт для запуска сервера Gradio.")
15
- app_parser.add_argument("--share", action="store_true", help="Создать публичную ссылку для приложения Gradio.")
16
  cli_parser = subparsers.add_parser("cli", help="CLI MVSepless Lite")
17
- cli_parser.add_argument("--input", type=str, required=True, help="Входной аудиофайл или каталог.")
18
- cli_parser.add_argument("--output_dir", type=str, default=None, help="Каталог для выходных файлов.")
19
- cli_parser.add_argument("--model_type", type=str, default="mel_band_roformer", help="Тип модели разделения.")
20
- cli_parser.add_argument("--model_name", type=str, default="Mel-Band-Roformer_Vocals_kimberley_jensen", help="Имя модели разделения.")
21
- cli_parser.add_argument("--ext_inst", action="store_true", help="Извлечь инструментал.")
22
- cli_parser.add_argument("--output_format", type=str, default="mp3", choices=Separator.audio.output_formats, help="Формат выходного файла.")
23
- cli_parser.add_argument("--output_bitrate", type=str, default="320k", help="Битрейт выходного файла.")
24
- cli_parser.add_argument("--template", type=str, default="NAME (STEM) MODEL", help="Шаблон именования выходных файлов.")
25
- cli_parser.add_argument("--selected_stems", type=str, nargs='*', default=None, help="Выбранные стемы для разделения.")
26
  args = parser.parse_args()
27
 
28
  if args.command == "app":
@@ -39,7 +80,13 @@ if __name__ == "__main__":
39
  ],
40
  )
41
  mvsepless_lite_app = mvsepless_app(theme=theme)
42
- mvsepless_lite_app.launch(server_name="0.0.0.0", server_port=args.port, share=args.share, allowed_paths=["/"], debug=True)
43
  elif args.command == "cli":
44
  input_data = args.input
45
  if os.path.isdir(input_data):
@@ -62,6 +109,6 @@ if __name__ == "__main__":
62
  output_format=args.output_format,
63
  output_bitrate=args.output_bitrate,
64
  template=args.template,
65
- selected_stems=args.selected_stems
66
  )
67
- print("Разделение завершено.")
 
11
  parser = argparse.ArgumentParser(description="MVSepless")
12
  subparsers = parser.add_subparsers(dest="command")
13
  app_parser = subparsers.add_parser("app", help="Приложение MVSepless")
14
+ app_parser.add_argument(
15
+ "--port", type=int, default=None, help="Порт для запуска сервера Gradio."
16
+ )
17
+ app_parser.add_argument(
18
+ "--share",
19
+ action="store_true",
20
+ help="Создать публичную ссылку для приложения Gradio.",
21
+ )
22
  cli_parser = subparsers.add_parser("cli", help="CLI MVSepless Lite")
23
+ cli_parser.add_argument(
24
+ "--input", type=str, required=True, help="Входной аудиофайл или каталог."
25
+ )
26
+ cli_parser.add_argument(
27
+ "--output_dir", type=str, default=None, help="Каталог для выходных файлов."
28
+ )
29
+ cli_parser.add_argument(
30
+ "--model_type",
31
+ type=str,
32
+ default="mel_band_roformer",
33
+ help="Тип модели разделения.",
34
+ )
35
+ cli_parser.add_argument(
36
+ "--model_name",
37
+ type=str,
38
+ default="Mel-Band-Roformer_Vocals_kimberley_jensen",
39
+ help="Имя модели разделения.",
40
+ )
41
+ cli_parser.add_argument(
42
+ "--ext_inst", action="store_true", help="Извлечь инструментал."
43
+ )
44
+ cli_parser.add_argument(
45
+ "--output_format",
46
+ type=str,
47
+ default="mp3",
48
+ choices=Separator.audio.output_formats,
49
+ help="Формат выходного файла.",
50
+ )
51
+ cli_parser.add_argument(
52
+ "--output_bitrate", type=str, default="320k", help="Битрейт выходного файла."
53
+ )
54
+ cli_parser.add_argument(
55
+ "--template",
56
+ type=str,
57
+ default="NAME (STEM) MODEL",
58
+ help="Шаблон именования выходных файлов.",
59
+ )
60
+ cli_parser.add_argument(
61
+ "--selected_stems",
62
+ type=str,
63
+ nargs="*",
64
+ default=None,
65
+ help="Выбранные стемы для разделения.",
66
+ )
67
  args = parser.parse_args()
68
 
69
  if args.command == "app":
 
80
  ],
81
  )
82
  mvsepless_lite_app = mvsepless_app(theme=theme)
83
+ mvsepless_lite_app.launch(
84
+ server_name="0.0.0.0",
85
+ server_port=args.port,
86
+ share=args.share,
87
+ allowed_paths=["/"],
88
+ debug=True,
89
+ )
90
  elif args.command == "cli":
91
  input_data = args.input
92
  if os.path.isdir(input_data):
 
109
  output_format=args.output_format,
110
  output_bitrate=args.output_bitrate,
111
  template=args.template,
112
+ selected_stems=args.selected_stems,
113
  )
114
+ print("Разделение завершено.")
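
For context, the reformatting above does not change the CLI surface, so the entry point can still be exercised as in this minimal sketch. The module invocation is inferred from mvsepless/__main__.py; the input path, output directory, and bitrate are placeholders, not values from the commit:

import subprocess

# Run the "cli" subcommand defined by the argparse setup above.
# "song.mp3" and "stems" are hypothetical paths.
subprocess.run(
    [
        "python", "-m", "mvsepless", "cli",
        "--input", "song.mp3",
        "--output_dir", "stems",
        "--output_format", "mp3",
        "--output_bitrate", "320k",
    ],
    check=True,
)
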
mvsepless/audio.py CHANGED
@@ -1,781 +1,789 @@
1
- import os
2
- from pathlib import Path
3
- import sys
4
- import json
5
- import subprocess
6
- import numpy as np
7
- from typing import Literal
8
- from collections.abc import Callable
9
- from pathlib import Path
10
- from numpy.typing import DTypeLike
11
- import tempfile
12
- import librosa
13
- if not __package__:
14
- from namer import Namer
15
- else:
16
- from .namer import Namer
17
- class NotInputFileSpecified(Exception): pass
18
- class NotOutputFileSpecified(Exception): pass
19
- class NotSupportedDataType(Exception): pass
20
- class ErrorDecode(Exception): pass
21
- class ErrorEncode(Exception): pass
22
- class NotSupportedFormat(Exception): pass
23
- class SampleRateError(Exception): pass
24
- class FileIsNotAudio(Exception): pass
25
-
26
- class Audio(Namer):
27
- def __init__(self):
28
- """
29
- Чтение и запись аудио файла через ffmpeg
30
-
31
- Поддерживаемые типы данных: - int16, int32, float32, float64
32
- """
33
- super().__init__()
34
- self.ffmpeg_path = os.environ.get("MVSEPLESS_FFMPEG", "ffmpeg")
35
- self.ffprobe_path = os.environ.get("MVSEPLESS_FFPROBE", "ffprobe")
36
- self.output_formats = ("mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff")
37
- self.input_formats = ("mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff", "mp4", "mkv", "webm", "avi", "mov", "ts")
38
- self.supported_dtypes = ("int16", "int32", "float32", "float64")
39
- self.dtypes_dict = {
40
- "int16": "s16le",
41
- "int32": "s32le",
42
- "float32": "f32le",
43
- "float64": "f64le",
44
- np.int16: "s16le",
45
- np.int32: "s32le",
46
- np.float32: "f32le",
47
- np.float64: "f64le",
48
- }
49
- self.bitrate_limit = {
50
- "mp3": {"min": 8, "max": 320},
51
- "aac": {"min": 8, "max": 512},
52
- "m4a": {"min": 8, "max": 512},
53
- "ac3": {"min": 32, "max": 640},
54
- "ogg": {"min": 64, "max": 500},
55
- "opus": {"min": 6, "max": 512},
56
- }
57
- self.sample_rates = {
58
- "mp3": {
59
- "supported": (44100, 48000, 32000, 22050, 24000, 16000, 11025, 12000, 8000)
60
- },
61
- "opus": {"supported": (48000, 24000, 16000, 12000, 8000)},
62
- "m4a": {
63
- "supported": (
64
- 96000,
65
- 88200,
66
- 64000,
67
- 48000,
68
- 44100,
69
- 32000,
70
- 24000,
71
- 22050,
72
- 16000,
73
- 12000,
74
- 11025,
75
- 8000,
76
- 7350,
77
- )
78
- },
79
- "aac": {
80
- "supported": (
81
- 96000,
82
- 88200,
83
- 64000,
84
- 48000,
85
- 44100,
86
- 32000,
87
- 24000,
88
- 22050,
89
- 16000,
90
- 12000,
91
- 11025,
92
- 8000,
93
- 7350,
94
- )
95
- },
96
- "ac3": {
97
- "supported": (
98
- 48000,
99
- 44100,
100
- 32000,
101
- )
102
- },
103
- "ogg": {"min": 6, "max": 192000},
104
- "wav": {"min": 0, "max": float("inf")},
105
- "aiff": {"min": 0, "max": float("inf")},
106
- "flac": {"min": 0, "max": 192000},
107
- }
108
- self.check_ffmpeg()
109
- self.check_ffprobe()
110
-
111
- def check_ffmpeg(self):
112
- """
113
- Проверяет, установлен ли ffmpeg?
114
- """
115
- try:
116
- ffmpeg_version_output = subprocess.check_output(
117
- [self.ffmpeg_path, "-version"], text=True
118
- )
119
- except FileNotFoundError:
120
- if "PYTEST_CURRENT_TEST" not in os.environ:
121
- raise FileNotFoundError("FFMPEG не установлен. Укажите путь к установленному FFMPEG через переменную окружения MVSEPLESS_FFMPEG")
122
-
123
- def check_ffprobe(self):
124
- """
125
- Проверяет, установлен ли ffprobe?
126
- """
127
- try:
128
- ffmpeg_version_output = subprocess.check_output(
129
- [self.ffprobe_path, "-version"], text=True
130
- )
131
- except FileNotFoundError:
132
- if "PYTEST_CURRENT_TEST" not in os.environ:
133
- raise FileNotFoundError("FFPROBE не установлен. Укажите путь к установленному FFPROBE через переменную окружения MVSEPLESS_FFPROBE")
134
-
135
-
136
- def fit_sr(
137
- self,
138
- f: str | Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] = "mp3",
139
- sr: int = 44100
140
- ) -> int:
141
- """
142
- Исправляет значение частоты дисректизации выходного файла
143
-
144
- Параметры:
145
- f: Формат вывода
146
- sr: Частота дискретизации (целое число)
147
- Возвращает:
148
- sr: Исправленная частота дискретизации
149
- """
150
- format_info = self.sample_rates.get(f.lower())
151
-
152
- if not format_info:
153
- return None # Формат не найден
154
-
155
- if "supported" in format_info:
156
- # Для форматов с конкретным списком
157
- supported_rates = format_info["supported"]
158
- if sr in supported_rates:
159
- return sr
160
-
161
- # Находим ближайшую поддерживаемую частоту
162
- return min(supported_rates, key=lambda x: abs(x - sr))
163
-
164
- elif "min" in format_info and "max" in format_info:
165
- # Для форматов с диапазоном - обрезаем до границ
166
- min_rate = format_info["min"]
167
- max_rate = format_info["max"]
168
-
169
- if sr < min_rate:
170
- return min_rate
171
- elif sr > max_rate:
172
- return max_rate
173
- else:
174
- return sr
175
-
176
- return None
177
-
178
- def fit_br(
179
- self,
180
- f: str | Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] = "mp3",
181
- br: int = 320
182
- ) -> int:
183
- """
184
- Исправляет значение битрейта выходного файла
185
-
186
- Параметры:
187
- f: Формат вывода
188
- br: Битрейт (целое число)
189
- Возвращает:
190
- br: Исправленный битрейт
191
- """
192
- if f not in self.bitrate_limit:
193
- raise NotSupportedFormat(f"Формат {f} не поддерживается")
194
-
195
- limits = self.bitrate_limit[f]
196
-
197
- if br < limits["min"]:
198
- return limits["min"]
199
- elif br > limits["max"]:
200
- return limits["max"]
201
- else:
202
- return br
203
-
204
- def get_info(
205
- self,
206
- i: str | os.PathLike | Callable | None = None,
207
- ) -> dict[int, dict[int, float]]:
208
- """
209
- Получает информацию о аудио потоках из файла напрямую через FFMPEG
210
-
211
- Параметры:
212
- i: Путь к выходному файлу
213
- Возвращает:
214
- audio_info: Словарь с информацией о аудиопотоках вида:
215
-
216
- {Номер потока:
217
- {
218
- "sample_rate": Частота дисректизации (является целым числом),
219
- "duration": Длительность аудиопотока (является числом с плавающей точкой)
220
- }
221
- }
222
- """
223
- audio_info = {}
224
- if i:
225
- if isinstance(i, Path):
226
- i = str(i)
227
- if os.path.exists(i):
228
- cmd = [self.ffprobe_path, "-i", i, "-v", "quiet", "-hide_banner",
229
- "-show_entries", "stream=index,sample_rate,duration", "-select_streams", "a", "-of", "json"]
230
-
231
- process = subprocess.Popen(
232
- cmd,
233
- stdin=subprocess.PIPE,
234
- stdout=subprocess.PIPE,
235
- stderr=subprocess.PIPE,
236
- )
237
-
238
- stdout, stderr = process.communicate()
239
-
240
- if process.returncode != 0:
241
- print(f"STDERR: {stderr.decode('utf-8')}")
242
- print(f"STDOUT: {stdout.decode('utf-8')}")
243
-
244
- json_output = json.loads(stdout)
245
- streams = json_output["streams"]
246
- if not streams:
247
- pass
248
-
249
- else:
250
- for a, stream in enumerate(streams):
251
- audio_info[a] = {
252
- "sample_rate": int(stream.get("sample_rate", 0)),
253
- "duration": float(stream.get("duration", 0))
254
- }
255
-
256
- return audio_info
257
-
258
- else:
259
- raise FileExistsError("Указанного файла не существует")
260
-
261
- else:
262
- raise NotInputFileSpecified("Не указан путь к файлу")
263
-
264
- def check(
265
- self,
266
- i: str | os.PathLike | Callable | None = None
267
- ) -> bool:
268
- """
269
- Проверяет, является ли файл аудио или видео файлом, поддерживаемым ffmpeg
270
-
271
- Параметры:
272
- i: Путь к выходному файлу
273
- Возвращает:
274
- is_audio_video: Булево значение, является ли файл аудио или видео файлом
275
- """
276
- if i:
277
- if isinstance(i, Path):
278
- i = str(i)
279
- if os.path.exists(i):
280
- info = self.get_info(i=i)
281
- if info:
282
- list_streams = list(info.keys())
283
- if len(list_streams) > 0:
284
- if info[0].get("sample_rate") > 0:
285
- return True
286
- else:
287
- return False
288
- else:
289
- return False
290
- else:
291
- return False
292
- else:
293
- raise FileExistsError("Указанного файла не существует")
294
- else:
295
- raise NotInputFileSpecified("Не указан путь к файлу")
296
-
297
- def read(
298
- self,
299
- i: str | os.PathLike | Callable | None = None,
300
- sr: int | None = None,
301
- mono: bool = False,
302
- dtype: DTypeLike = np.float32,
303
- s: int = 0
304
- ) -> tuple[np.ndarray, int, float]:
305
- """
306
- Читает аудио-файл, преобразовывая его в массив с аудио данными напрямую через FFMPEG
307
- Является заменой soundfile.read() и librosa.load()
308
-
309
- Параметры:
310
- i: Путь к выходному файлу
311
- sr: Целевая частота дискретизации (Если не указана, то используется частота дискретизации входного файла)
312
- mono: Конвертация в моно (по умолчанию отключена)
313
- dtype: Тип данных (поддерживаются типы: int16, int32, float32, float64; по умолчанию - float32)
314
- s: Номер аудиопотока (по умолчанию 0)
315
- Возвращает:
316
- audio_array: Массив с аудио данными
317
- sr: Частота дискретизации массива
318
- duration: Длительность аудио (количество сэмплов / частота дискретизации)
319
- """
320
- output_format = self.dtypes_dict.get(dtype, None)
321
- if not output_format:
322
- raise NotSupportedDataType(f"Этот тип данных не поддерживается {dtype}")
323
- if i:
324
- if isinstance(i, Path):
325
- i = str(i)
326
- if os.path.exists(i):
327
- audio_info = self.get_info(i=i)
328
- list_streams = list(audio_info.keys())
329
- if audio_info.get(s, False):
330
- stream = s
331
- else:
332
- if len(list_streams) > 0:
333
- stream = 0
334
- else:
335
- raise FileIsNotAudio("В входном файле нет аудио потоков")
336
-
337
- sample_rate_input = audio_info[stream]["sample_rate"]
338
- if sample_rate_input == 0:
339
- raise FileIsNotAudio( входном файле нет аудио потоков")
340
-
341
- cmd = [
342
- self.ffmpeg_path,
343
- "-i", i,
344
- "-map", f"0:a:{stream}", "-vn",
345
- "-f", output_format,
346
- "-ac", "1" if mono else "2",
347
- ]
348
-
349
- if sr:
350
- cmd.extend(["-ar", str(sr)])
351
- else:
352
- sr = sample_rate_input
353
-
354
- cmd.append("pipe:1")
355
-
356
- process = subprocess.Popen(
357
- cmd,
358
- stdout=subprocess.PIPE,
359
- stderr=subprocess.PIPE,
360
- bufsize=10**8
361
- )
362
-
363
- try:
364
-
365
- raw_audio, stderr = process.communicate(timeout=300)
366
-
367
- if process.returncode != 0:
368
- raise ErrorDecode(f"FFmpeg error: {stderr.decode()}")
369
-
370
- except subprocess.TimeoutExpired:
371
- process.kill()
372
- raise ErrorDecode("FFmpeg timeout при чтении файла")
373
-
374
- audio_array = np.frombuffer(raw_audio, dtype=dtype)
375
-
376
- channels = 1 if mono else 2
377
- audio_array = audio_array.reshape((-1, channels)).T
378
- if audio_array.ndim > 1 and channels == 1:
379
- audio_array = np.mean(audio_array, axis=tuple(range(audio_array.ndim - 1)))
380
-
381
- len_samples = float(audio_array.shape[-1])
382
-
383
- duration = len_samples / sr
384
-
385
- print(f"Частота дискретизации: {sr}")
386
-
387
- return audio_array.copy(), sr, duration
388
- else:
389
- raise FileExistsError("Указанного файла не существует")
390
-
391
- else:
392
- raise NotInputFileSpecified("Не указан путь к файлу")
393
-
394
- def write(
395
- self,
396
- o: str | os.PathLike | Callable | None = None,
397
- array: np.ndarray = np.array([], dtype=np.float32),
398
- sr: int = 44100,
399
- of: str | Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] | None = None,
400
- br: str | int | None = None
401
- ) -> str:
402
- """
403
- Записывает numpy-массив с аудио данными в файл напрямую через ffmpeg.
404
- Является заменой soundfile.write()
405
-
406
- Параметры:
407
- o: Путь к выходному файлу
408
- array: Массив с аудио данными (поддерживаются типы: int16, int32, float32, float64)
409
- sr: Частота дискретизации массива
410
- of: Формат вывода (по умолчанию mp3)
411
- br: Битрейт для кодеков, сжимающих аудио с потерями
412
- Возвращает:
413
- o: Путь к выходному файлу
414
- """
415
- if isinstance(array, np.ndarray):
416
-
417
- if len(array.shape) == 1:
418
- array = array.reshape(-1, 1)
419
- elif len(array.shape) == 2:
420
- if array.shape[0] == 2:
421
- array = array.T
422
- else:
423
- raise ValueError("numpy-массив должен быть либо одномерным, либо двухмерным")
424
-
425
- if array.dtype == np.int16:
426
- input_format = "s16le"
427
- elif array.dtype == np.int32:
428
- input_format = "s32le"
429
- elif array.dtype == np.float32:
430
- input_format = "f32le"
431
- elif array.dtype == np.float64:
432
- input_format = "f64le"
433
- else:
434
- raise NotSupportedDataType(f"Этот тип данных не поддерживается {array.dtype}")
435
-
436
- if array.shape[1] == 1:
437
- audio_bytes = array.tobytes()
438
-
439
- channels = 1
440
-
441
- elif array.shape[1] == 2:
442
- audio_bytes = array.tobytes()
443
-
444
- channels = 2
445
- else:
446
- raise ValueError("numpy-массив должен содержать 1 или 2 канала")
447
-
448
- else:
449
- raise ValueError("Вход должен быть numpy-массивом")
450
-
451
- if o:
452
- if isinstance(o, Path):
453
- o = str(o)
454
- output_dir = os.path.dirname(o)
455
- output_base = os.path.basename(o)
456
- output_name, output_ext = os.path.splitext(output_base)
457
- if output_dir != "":
458
- os.makedirs(output_dir, exist_ok=True)
459
- if output_ext == "":
460
- if of:
461
- o += f".{of}"
462
- else:
463
- o += f".mp3"
464
- elif output_ext == ".":
465
- if of:
466
- o += f"{of}"
467
- else:
468
- o += f"mp3"
469
- else:
470
- raise NotOutputFileSpecified("Не указан путь к выходному файлу")
471
-
472
- if of:
473
- if of in self.output_formats:
474
- output_name, output_ext = os.path.splitext(o)
475
- if output_ext == f".{of}":
476
- pass
477
- else:
478
- o = f"{os.path.join(output_dir, output_name)}.{of}"
479
- else:
480
- raise NotSupportedFormat(f"Неподдерживаемый формат: {of}")
481
- else:
482
- of = os.path.splitext(o)[1].strip(".")
483
- if of in self.output_formats:
484
- pass
485
- else:
486
- raise NotSupportedFormat(f"Неподдерживаемый формат: {of}")
487
-
488
- if sr:
489
- if isinstance(sr, int):
490
- sample_rate_fixed = self.fit_sr(f=of, sr=sr)
491
- elif isinstance(sr, float):
492
- sr = int(sr)
493
- sample_rate_fixed = self.fit_sr(f=of, sr=sr)
494
- else:
495
- raise SampleRateError(f"Частота дискретизации должна быть числом\n\nЗначение: {sr}\nТип: {type(sr)}")
496
- else:
497
- raise SampleRateError("Не указана частота дискретизации")
498
-
499
- bitrate_fixed = "320k"
500
-
501
- if of not in ["wav", "flac", "aiff"]:
502
- if br:
503
- if isinstance(br, int):
504
- bitrate_fixed = self.fit_br(f=of, br=br)
505
- elif isinstance(br, float):
506
- bitrate_fixed = self.fit_br(f=of, br=int(br))
507
- elif isinstance(br, str):
508
- bitrate_fixed = self.fit_br(f=of, br=int(br.strip("k").strip("K")))
509
- else:
510
- bitrate_fixed = self.fit_br(f=of, br=320)
511
- else:
512
- bitrate_fixed = self.fit_br(of, 320)
513
-
514
- format_settings = {
515
- "wav": [
516
- "-c:a",
517
- "pcm_f32le",
518
- "-sample_fmt",
519
- "flt",
520
- ],
521
- "aiff": [
522
- "-c:a",
523
- "pcm_f32be",
524
- "-sample_fmt",
525
- "flt",
526
- ],
527
- "flac": [
528
- "-c:a",
529
- "flac",
530
- "-compression_level",
531
- "12",
532
- "-sample_fmt",
533
- "s32",
534
- ],
535
- "mp3": [
536
- "-c:a",
537
- "libmp3lame",
538
- "-b:a",
539
- f"{bitrate_fixed}k",
540
- ],
541
- "ogg": [
542
- "-c:a",
543
- "libvorbis",
544
- "-b:a",
545
- f"{bitrate_fixed}k",
546
- ],
547
- "opus": [
548
- "-c:a",
549
- "libopus",
550
- "-b:a",
551
- f"{bitrate_fixed}k",
552
- ],
553
- "m4a": [
554
- "-c:a",
555
- "aac",
556
- "-b:a",
557
- f"{bitrate_fixed}k",
558
- ],
559
- "aac": [
560
- "-c:a",
561
- "aac",
562
- "-b:a",
563
- f"{bitrate_fixed}k",
564
- ],
565
- "ac3": [
566
- "-c:a",
567
- "ac3",
568
- "-b:a",
569
- f"{bitrate_fixed}k",
570
- ],
571
- }
572
-
573
- cmd = [
574
- self.ffmpeg_path,
575
- "-y",
576
- "-f",
577
- input_format,
578
- "-ar",
579
- str(sr),
580
- "-ac",
581
- str(channels),
582
- "-i",
583
- "pipe:0",
584
- "-ac",
585
- str(channels),
586
- ]
587
-
588
- cmd.extend(["-ar", str(sample_rate_fixed)])
589
- cmd.extend(format_settings[of])
590
- o_dir, o_base = os.path.split(o)
591
- o_base_n, o_base_ext = os.path.splitext(o_base)
592
- o_base_n = self.sanitize(o_base_n)
593
- o_base_n = self.short(o_base_n)
594
- o = os.path.join(o_dir, f"{o_base_n}{o_base_ext}")
595
- o = self.iter(o)
596
- cmd.append(o)
597
-
598
- process = subprocess.Popen(
599
- cmd,
600
- stdin=subprocess.PIPE,
601
- stdout=subprocess.PIPE,
602
- stderr=subprocess.PIPE,
603
- )
604
-
605
- try:
606
- stdout, stderr = process.communicate(input=audio_bytes, timeout=300)
607
- except subprocess.TimeoutExpired:
608
- process.kill()
609
- raise ErrorEncode("FFmpeg timeout: операция заняла слишком много времени")
610
-
611
- if process.returncode != 0:
612
- raise ErrorEncode(f"FFmpeg завершился с ошибкой (код: {process.returncode})")
613
-
614
- return os.path.abspath(o)
615
-
616
- class Inverter(Audio):
617
- def __init__(self):
618
- super().__init__()
619
- self.test = "test"
620
- self.w_types = [
621
- "boxcar", # Прямоугольное окно
622
- "triang", # Треугольное окно
623
- "blackman", # Окно Блэкмана
624
- "hamming", # Окно Хэмминга
625
- "hann", # Окно Ханна
626
- "bartlett", # Окно Бартлетта
627
- "flattop", # Окно с плоской вершиной
628
- "parzen", # Окно Парзена
629
- "bohman", # Окно Бохмана
630
- "blackmanharris", # Окно Блэкмана-Харриса
631
- "nuttall", # Окно Нуттала
632
- "barthann", # Окно Бартлетта-Ханна
633
- "cosine", # Косинусное окно
634
- "exponential", # Экспоненциальное окно
635
- "tukey", # Окно Туки
636
- "taylor", # Окно Тейлора
637
- "lanczos", # Окно Ланцоша
638
- ]
639
-
640
- def load_audio(self, filepath):
641
- """Загрузка аудиофайла с помощью librosa"""
642
- try:
643
- y, sr, _ = self.read(i=filepath, sr=None, mono=False)
644
- return y, sr
645
- except Exception as e:
646
- print(f"Ошибка загрузки аудио: {e}")
647
- return None, None
648
-
649
- def process_channel(self, y1_ch, y2_ch, sr, method, w_size=2048, overlap=2, w_type="hann"):
650
- """Обработка одного аудиоканала"""
651
- HOP_LENGTH = w_size // overlap
652
- if method == "waveform":
653
- return y1_ch - y2_ch
654
-
655
- elif method == "spectrogram":
656
- # Вычисляем спектрограммы
657
- S1 = librosa.stft(
658
- y1_ch, n_fft=w_size, hop_length=HOP_LENGTH, win_length=w_size
659
- )
660
- S2 = librosa.stft(
661
- y2_ch, n_fft=w_size, hop_length=HOP_LENGTH, win_length=w_size
662
- )
663
-
664
- # Амплитудные спектрограммы
665
- mag1 = np.abs(S1)
666
- mag2 = np.abs(S2)
667
-
668
- # Спектральное вычитание
669
- mag_result = np.maximum(mag1 - mag2, 0)
670
-
671
- # Сохраняем фазовую информацию исходного сигнала
672
- phase = np.angle(S1)
673
-
674
- # Комбинируем амплитуду результата с фазой
675
- S_result = mag_result * np.exp(1j * phase)
676
-
677
- # Обратное преобразование
678
- return librosa.istft(
679
- S_result,
680
- n_fft=w_size,
681
- hop_length=HOP_LENGTH,
682
- win_length=w_size,
683
- length=len(y1_ch),
684
- )
685
-
686
- def process_audio(self, audio1_path, audio2_path, out_format, method, output_path="./inverted.mp3", w_size=2048, overlap=2, w_type="hann"):
687
- # Загрузка аудиофайлов
688
- y1, sr1 = self.load_audio(audio1_path)
689
- y2, sr2 = self.load_audio(audio2_path)
690
-
691
- if sr1 is None or sr2 is None:
692
- raise Exception("Произошла ошибка при чтении файлов")
693
-
694
- # Определяем количество каналов
695
- channels1 = 1 if y1.ndim == 1 else y1.shape[0]
696
- channels2 = 1 if y2.ndim == 1 else y2.shape[0]
697
-
698
- # Преобразование в форму (samples, channels)
699
- if channels1 > 1:
700
- y1 = y1.T # (channels, samples) -> (samples, channels)
701
- else:
702
- y1 = y1.reshape(-1, 1)
703
-
704
- if channels2 > 1:
705
- y2 = y2.T # (channels, samples) -> (samples, channels)
706
- else:
707
- y2 = y2.reshape(-1, 1)
708
-
709
- if sr1 != sr2:
710
- if channels2 > 1:
711
- # Ресемплинг для каждого канала отдельно
712
- y2_resampled_list = []
713
- for c in range(channels2):
714
- channel_resampled = librosa.resample(
715
- y2[:, c], orig_sr=sr2, target_sr=sr1
716
- )
717
- y2_resampled_list.append(channel_resampled)
718
-
719
- # Находим минимальную длину среди всех каналов
720
- min_channel_length = min(len(ch) for ch in y2_resampled_list)
721
-
722
- # Обрезаем все каналы до одинаковой длины и собираем в массив
723
- y2_resampled = np.zeros((min_channel_length, channels2), dtype=np.float32)
724
- for c, channel in enumerate(y2_resampled_list):
725
- y2_resampled[:, c] = channel[:min_channel_length]
726
-
727
- y2 = y2_resampled
728
- else:
729
- y2 = librosa.resample(y2[:, 0], orig_sr=sr2, target_sr=sr1)
730
- y2 = y2.reshape(-1, 1)
731
- sr2 = sr1
732
-
733
- # Приводим к одинаковой длине
734
- min_len = min(len(y1), len(y2))
735
- y1 = y1[:min_len]
736
- y2 = y2[:min_len]
737
-
738
- # Обрабатываем каждый канал отдельно
739
- result_channels = []
740
-
741
- # Если основной сигнал моно, а удаляемый стерео - преобразуем удаляемый в моно
742
- if channels1 == 1 and channels2 > 1:
743
- y2 = y2.mean(axis=1, keepdims=True)
744
- channels2 = 1
745
-
746
- for c in range(channels1):
747
- # Выбираем канал для основного сигнала
748
- y1_ch = y1[:, c]
749
-
750
- # Выбираем канал для удаляемого сигнала
751
- if channels2 == 1:
752
- y2_ch = y2[:, 0]
753
- else:
754
- # Если каналов удаляемого сигнала больше, используем соответствующий канал
755
- y2_ch = y2[:, min(c, channels2 - 1)]
756
-
757
- # Обрабатываем канал
758
- result_ch = self.process_channel(y1_ch, y2_ch, sr1, method, w_size=w_size, overlap=overlap, w_type=w_type)
759
- result_channels.append(result_ch)
760
-
761
- # Собираем каналы в один массив
762
- if len(result_channels) > 1:
763
- result = np.column_stack(result_channels)
764
- else:
765
- result = np.array(result_channels[0])
766
-
767
- # Нормализация (предотвращение клиппинга)
768
- if result.ndim > 1:
769
- # Для многоканального аудио нормализуем каждый канал отдельно
770
- for c in range(result.shape[1]):
771
- channel = result[:, c]
772
- max_val = np.max(np.abs(channel))
773
- if max_val > 0:
774
- result[:, c] = channel * 0.9 / max_val
775
- else:
776
- max_val = np.max(np.abs(result))
777
- if max_val > 0:
778
- result = result * 0.9 / max_val
779
-
780
- inverted = self.write(o=output_path, array=result.T, sr=sr1, of=out_format, br="320k")
781
- return inverted
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ import sys
4
+ import json
5
+ import subprocess
6
+ import numpy as np
7
+ from typing import Literal
8
+ from collections.abc import Callable
9
+ from pathlib import Path
10
+ from numpy.typing import DTypeLike
11
+ import tempfile
12
+ import librosa
13
+
14
+ if not __package__:
15
+ from namer import Namer
16
+ else:
17
+ from .namer import Namer
18
+
19
+
20
+ class NotInputFileSpecified(Exception):
21
+ pass
22
+
23
+
24
+ class NotOutputFileSpecified(Exception):
25
+ pass
26
+
27
+
28
+ class NotSupportedDataType(Exception):
29
+ pass
30
+
31
+
32
+ class ErrorDecode(Exception):
33
+ pass
34
+
35
+
36
+ class ErrorEncode(Exception):
37
+ pass
38
+
39
+
40
+ class NotSupportedFormat(Exception):
41
+ pass
42
+
43
+
44
+ class SampleRateError(Exception):
45
+ pass
46
+
47
+
48
+ class FileIsNotAudio(Exception):
49
+ pass
50
+
51
+
52
+ class Audio(Namer):
53
+ def __init__(self):
54
+ super().__init__()
55
+ self.ffmpeg_path = os.environ.get("MVSEPLESS_FFMPEG", "ffmpeg")
56
+ self.ffprobe_path = os.environ.get("MVSEPLESS_FFPROBE", "ffprobe")
57
+ self.output_formats = (
58
+ "mp3",
59
+ "wav",
60
+ "flac",
61
+ "ogg",
62
+ "opus",
63
+ "m4a",
64
+ "aac",
65
+ "ac3",
66
+ "aiff",
67
+ )
68
+ self.input_formats = (
69
+ "mp3",
70
+ "wav",
71
+ "flac",
72
+ "ogg",
73
+ "opus",
74
+ "m4a",
75
+ "aac",
76
+ "ac3",
77
+ "aiff",
78
+ "mp4",
79
+ "mkv",
80
+ "webm",
81
+ "avi",
82
+ "mov",
83
+ "ts",
84
+ )
85
+ self.supported_dtypes = ("int16", "int32", "float32", "float64")
86
+ self.dtypes_dict = {
87
+ "int16": "s16le",
88
+ "int32": "s32le",
89
+ "float32": "f32le",
90
+ "float64": "f64le",
91
+ np.int16: "s16le",
92
+ np.int32: "s32le",
93
+ np.float32: "f32le",
94
+ np.float64: "f64le",
95
+ }
96
+ self.bitrate_limit = {
97
+ "mp3": {"min": 8, "max": 320},
98
+ "aac": {"min": 8, "max": 512},
99
+ "m4a": {"min": 8, "max": 512},
100
+ "ac3": {"min": 32, "max": 640},
101
+ "ogg": {"min": 64, "max": 500},
102
+ "opus": {"min": 6, "max": 512},
103
+ }
104
+ self.sample_rates = {
105
+ "mp3": {
106
+ "supported": (
107
+ 44100,
108
+ 48000,
109
+ 32000,
110
+ 22050,
111
+ 24000,
112
+ 16000,
113
+ 11025,
114
+ 12000,
115
+ 8000,
116
+ )
117
+ },
118
+ "opus": {"supported": (48000, 24000, 16000, 12000, 8000)},
119
+ "m4a": {
120
+ "supported": (
121
+ 96000,
122
+ 88200,
123
+ 64000,
124
+ 48000,
125
+ 44100,
126
+ 32000,
127
+ 24000,
128
+ 22050,
129
+ 16000,
130
+ 12000,
131
+ 11025,
132
+ 8000,
133
+ 7350,
134
+ )
135
+ },
136
+ "aac": {
137
+ "supported": (
138
+ 96000,
139
+ 88200,
140
+ 64000,
141
+ 48000,
142
+ 44100,
143
+ 32000,
144
+ 24000,
145
+ 22050,
146
+ 16000,
147
+ 12000,
148
+ 11025,
149
+ 8000,
150
+ 7350,
151
+ )
152
+ },
153
+ "ac3": {
154
+ "supported": (
155
+ 48000,
156
+ 44100,
157
+ 32000,
158
+ )
159
+ },
160
+ "ogg": {"min": 6, "max": 192000},
161
+ "wav": {"min": 0, "max": float("inf")},
162
+ "aiff": {"min": 0, "max": float("inf")},
163
+ "flac": {"min": 0, "max": 192000},
164
+ }
165
+ self.check_ffmpeg()
166
+ self.check_ffprobe()
167
+
168
+ def check_ffmpeg(self):
169
+ try:
170
+ ffmpeg_version_output = subprocess.check_output(
171
+ [self.ffmpeg_path, "-version"], text=True
172
+ )
173
+ except FileNotFoundError:
174
+ if "PYTEST_CURRENT_TEST" not in os.environ:
175
+ raise FileNotFoundError(
176
+ "FFMPEG не установлен. Укажите путь к установленному FFMPEG через переменную окружения MVSEPLESS_FFMPEG"
177
+ )
178
+
179
+ def check_ffprobe(self):
180
+ try:
181
+ ffmpeg_version_output = subprocess.check_output(
182
+ [self.ffprobe_path, "-version"], text=True
183
+ )
184
+ except FileNotFoundError:
185
+ if "PYTEST_CURRENT_TEST" not in os.environ:
186
+ raise FileNotFoundError(
187
+ "FFPROBE не установлен. Укажите путь к установленному FFPROBE через переменную окружения MVSEPLESS_FFPROBE"
188
+ )
189
+
190
+ def fit_sr(
191
+ self,
192
+ f: (
193
+ str
194
+ | Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"]
195
+ ) = "mp3",
196
+ sr: int = 44100,
197
+ ) -> int:
198
+ format_info = self.sample_rates.get(f.lower())
199
+
200
+ if not format_info:
201
+ return None
202
+
203
+ if "supported" in format_info:
204
+ supported_rates = format_info["supported"]
205
+ if sr in supported_rates:
206
+ return sr
207
+
208
+ return min(supported_rates, key=lambda x: abs(x - sr))
209
+
210
+ elif "min" in format_info and "max" in format_info:
211
+ min_rate = format_info["min"]
212
+ max_rate = format_info["max"]
213
+
214
+ if sr < min_rate:
215
+ return min_rate
216
+ elif sr > max_rate:
217
+ return max_rate
218
+ else:
219
+ return sr
220
+
221
+ return None
222
+
223
+ def fit_br(
224
+ self,
225
+ f: (
226
+ str
227
+ | Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"]
228
+ ) = "mp3",
229
+ br: int = 320,
230
+ ) -> int:
231
+ if f not in self.bitrate_limit:
232
+ raise NotSupportedFormat(f"Формат {f} не поддерживается")
233
+
234
+ limits = self.bitrate_limit[f]
235
+
236
+ if br < limits["min"]:
237
+ return limits["min"]
238
+ elif br > limits["max"]:
239
+ return limits["max"]
240
+ else:
241
+ return br
242
+
243
+ def get_info(
244
+ self,
245
+ i: str | os.PathLike | Callable | None = None,
246
+ ) -> dict[int, dict[int, float]]:
247
+ audio_info = {}
248
+ if i:
249
+ if isinstance(i, Path):
250
+ i = str(i)
251
+ if os.path.exists(i):
252
+ cmd = [
253
+ self.ffprobe_path,
254
+ "-i",
255
+ i,
256
+ "-v",
257
+ "quiet",
258
+ "-hide_banner",
259
+ "-show_entries",
260
+ "stream=index,sample_rate,duration",
261
+ "-select_streams",
262
+ "a",
263
+ "-of",
264
+ "json",
265
+ ]
266
+
267
+ process = subprocess.Popen(
268
+ cmd,
269
+ stdin=subprocess.PIPE,
270
+ stdout=subprocess.PIPE,
271
+ stderr=subprocess.PIPE,
272
+ )
273
+
274
+ stdout, stderr = process.communicate()
275
+
276
+ if process.returncode != 0:
277
+ print(f"STDERR: {stderr.decode('utf-8')}")
278
+ print(f"STDOUT: {stdout.decode('utf-8')}")
279
+
280
+ json_output = json.loads(stdout)
281
+ streams = json_output["streams"]
282
+ if not streams:
283
+ pass
284
+
285
+ else:
286
+ for a, stream in enumerate(streams):
287
+ audio_info[a] = {
288
+ "sample_rate": int(stream.get("sample_rate", 0)),
289
+ "duration": float(stream.get("duration", 0)),
290
+ }
291
+
292
+ return audio_info
293
+
294
+ else:
295
+ raise FileExistsError("Указанного файла не существует")
296
+
297
+ else:
298
+ raise NotInputFileSpecified("Не указан путь к файлу")
299
+
300
+ def check(self, i: str | os.PathLike | Callable | None = None) -> bool:
301
+ if i:
302
+ if isinstance(i, Path):
303
+ i = str(i)
304
+ if os.path.exists(i):
305
+ info = self.get_info(i=i)
306
+ if info:
307
+ list_streams = list(info.keys())
308
+ if len(list_streams) > 0:
309
+ if info[0].get("sample_rate") > 0:
310
+ return True
311
+ else:
312
+ return False
313
+ else:
314
+ return False
315
+ else:
316
+ return False
317
+ else:
318
+ raise FileExistsError("Указанного файла не существует")
319
+ else:
320
+ raise NotInputFileSpecified("Не указан путь к файлу")
321
+
322
+ def read(
323
+ self,
324
+ i: str | os.PathLike | Callable | None = None,
325
+ sr: int | None = None,
326
+ mono: bool = False,
327
+ dtype: DTypeLike = np.float32,
328
+ s: int = 0,
329
+ ) -> tuple[np.ndarray, int, float]:
330
+ output_format = self.dtypes_dict.get(dtype, None)
331
+ if not output_format:
332
+ raise NotSupportedDataType(f"Этот тип данных не поддерживается {dtype}")
333
+ if i:
334
+ if isinstance(i, Path):
335
+ i = str(i)
336
+ if os.path.exists(i):
337
+ audio_info = self.get_info(i=i)
338
+ list_streams = list(audio_info.keys())
339
+ if audio_info.get(s, False):
340
+ stream = s
341
+ else:
342
+ if len(list_streams) > 0:
343
+ stream = 0
344
+ else:
345
+ raise FileIsNotAudio("В входном файле нет аудио потоков")
346
+
347
+ sample_rate_input = audio_info[stream]["sample_rate"]
348
+ if sample_rate_input == 0:
349
+ raise FileIsNotAudio("В входном файле нет аудио потоков")
350
+
351
+ cmd = [
352
+ self.ffmpeg_path,
353
+ "-i",
354
+ i,
355
+ "-map",
356
+ f"0:a:{stream}",
357
+ "-vn",
358
+ "-f",
359
+ output_format,
360
+ "-ac",
361
+ "1" if mono else "2",
362
+ ]
363
+
364
+ if sr:
365
+ cmd.extend(["-ar", str(sr)])
366
+ else:
367
+ sr = sample_rate_input
368
+
369
+ cmd.append("pipe:1")
370
+
371
+ process = subprocess.Popen(
372
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=10**8
373
+ )
374
+
375
+ try:
376
+
377
+ raw_audio, stderr = process.communicate(timeout=300)
378
+
379
+ if process.returncode != 0:
380
+ raise ErrorDecode(f"FFmpeg error: {stderr.decode()}")
381
+
382
+ except subprocess.TimeoutExpired:
383
+ process.kill()
384
+ raise ErrorDecode("FFmpeg timeout при чтении файла")
385
+
386
+ audio_array = np.frombuffer(raw_audio, dtype=dtype)
387
+
388
+ channels = 1 if mono else 2
389
+ audio_array = audio_array.reshape((-1, channels)).T
390
+ if audio_array.ndim > 1 and channels == 1:
391
+ audio_array = np.mean(
392
+ audio_array, axis=tuple(range(audio_array.ndim - 1))
393
+ )
394
+
395
+ len_samples = float(audio_array.shape[-1])
396
+
397
+ duration = len_samples / sr
398
+
399
+ print(f"Частота дискретизации: {sr}")
400
+
401
+ return audio_array.copy(), sr, duration
402
+ else:
403
+ raise FileExistsError("Указанного файла не существует")
404
+
405
+ else:
406
+ raise NotInputFileSpecified("Не указан путь к файлу")
407
+
408
+ def write(
409
+ self,
410
+ o: str | os.PathLike | Callable | None = None,
411
+ array: np.ndarray = np.array([], dtype=np.float32),
412
+ sr: int = 44100,
413
+ of: (
414
+ str
415
+ | Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"]
416
+ | None
417
+ ) = None,
418
+ br: str | int | None = None,
419
+ ) -> str:
420
+ if isinstance(array, np.ndarray):
421
+
422
+ if len(array.shape) == 1:
423
+ array = array.reshape(-1, 1)
424
+ elif len(array.shape) == 2:
425
+ if array.shape[0] == 2:
426
+ array = array.T
427
+ else:
428
+ raise ValueError(
429
+ "numpy-массив должен быть либо одномерным, либо двухмерным"
430
+ )
431
+
432
+ if array.dtype == np.int16:
433
+ input_format = "s16le"
434
+ elif array.dtype == np.int32:
435
+ input_format = "s32le"
436
+ elif array.dtype == np.float32:
437
+ input_format = "f32le"
438
+ elif array.dtype == np.float64:
439
+ input_format = "f64le"
440
+ else:
441
+ raise NotSupportedDataType(
442
+ f"Этот тип данных не поддерживается {array.dtype}"
443
+ )
444
+
445
+ if array.shape[1] == 1:
446
+ audio_bytes = array.tobytes()
447
+
448
+ channels = 1
449
+
450
+ elif array.shape[1] == 2:
451
+ audio_bytes = array.tobytes()
452
+
453
+ channels = 2
454
+ else:
455
+ raise ValueError("numpy-массив должен содержать 1 или 2 канала")
456
+
457
+ else:
458
+ raise ValueError("Вход должен быть numpy-массивом")
459
+
460
+ if o:
461
+ if isinstance(o, Path):
462
+ o = str(o)
463
+ output_dir = os.path.dirname(o)
464
+ output_base = os.path.basename(o)
465
+ output_name, output_ext = os.path.splitext(output_base)
466
+ if output_dir != "":
467
+ os.makedirs(output_dir, exist_ok=True)
468
+ if output_ext == "":
469
+ if of:
470
+ o += f".{of}"
471
+ else:
472
+ o += f".mp3"
473
+ elif output_ext == ".":
474
+ if of:
475
+ o += f"{of}"
476
+ else:
477
+ o += f"mp3"
478
+ else:
479
+ raise NotOutputFileSpecified("Не указан путь к выходному файлу")
480
+
481
+ if of:
482
+ if of in self.output_formats:
483
+ output_name, output_ext = os.path.splitext(o)
484
+ if output_ext == f".{of}":
485
+ pass
486
+ else:
487
+ o = f"{os.path.join(output_dir, output_name)}.{of}"
488
+ else:
489
+ raise NotSupportedFormat(f"Неподдерживаемый формат: {of}")
490
+ else:
491
+ of = os.path.splitext(o)[1].strip(".")
492
+ if of in self.output_formats:
493
+ pass
494
+ else:
495
+ raise NotSupportedFormat(f"Неподдерживаемый формат: {of}")
496
+
497
+ if sr:
498
+ if isinstance(sr, int):
499
+ sample_rate_fixed = self.fit_sr(f=of, sr=sr)
500
+ elif isinstance(sr, float):
501
+ sr = int(sr)
502
+ sample_rate_fixed = self.fit_sr(f=of, sr=sr)
503
+ else:
504
+ raise SampleRateError(
505
+ f"Частота дискретизации должна быть числом\n\nЗначение: {sr}\nТип: {type(sr)}"
506
+ )
507
+ else:
508
+ raise SampleRateError("Не указана частота дискретизации")
509
+
510
+ bitrate_fixed = "320k"
511
+
512
+ if of not in ["wav", "flac", "aiff"]:
513
+ if br:
514
+ if isinstance(br, int):
515
+ bitrate_fixed = self.fit_br(f=of, br=br)
516
+ elif isinstance(br, float):
517
+ bitrate_fixed = self.fit_br(f=of, br=int(br))
518
+ elif isinstance(br, str):
519
+ bitrate_fixed = self.fit_br(f=of, br=int(br.strip("k").strip("K")))
520
+ else:
521
+ bitrate_fixed = self.fit_br(f=of, br=320)
522
+ else:
523
+ bitrate_fixed = self.fit_br(of, 320)
524
+
525
+ format_settings = {
526
+ "wav": [
527
+ "-c:a",
528
+ "pcm_f32le",
529
+ "-sample_fmt",
530
+ "flt",
531
+ ],
532
+ "aiff": [
533
+ "-c:a",
534
+ "pcm_f32be",
535
+ "-sample_fmt",
536
+ "flt",
537
+ ],
538
+ "flac": [
539
+ "-c:a",
540
+ "flac",
541
+ "-compression_level",
542
+ "12",
543
+ "-sample_fmt",
544
+ "s32",
545
+ ],
546
+ "mp3": [
547
+ "-c:a",
548
+ "libmp3lame",
549
+ "-b:a",
550
+ f"{bitrate_fixed}k",
551
+ ],
552
+ "ogg": [
553
+ "-c:a",
554
+ "libvorbis",
555
+ "-b:a",
556
+ f"{bitrate_fixed}k",
557
+ ],
558
+ "opus": [
559
+ "-c:a",
560
+ "libopus",
561
+ "-b:a",
562
+ f"{bitrate_fixed}k",
563
+ ],
564
+ "m4a": [
565
+ "-c:a",
566
+ "aac",
567
+ "-b:a",
568
+ f"{bitrate_fixed}k",
569
+ ],
570
+ "aac": [
571
+ "-c:a",
572
+ "aac",
573
+ "-b:a",
574
+ f"{bitrate_fixed}k",
575
+ ],
576
+ "ac3": [
577
+ "-c:a",
578
+ "ac3",
579
+ "-b:a",
580
+ f"{bitrate_fixed}k",
581
+ ],
582
+ }
583
+
584
+ cmd = [
585
+ self.ffmpeg_path,
586
+ "-y",
587
+ "-f",
588
+ input_format,
589
+ "-ar",
590
+ str(sr),
591
+ "-ac",
592
+ str(channels),
593
+ "-i",
594
+ "pipe:0",
595
+ "-ac",
596
+ str(channels),
597
+ ]
598
+
599
+ cmd.extend(["-ar", str(sample_rate_fixed)])
600
+ cmd.extend(format_settings[of])
601
+ o_dir, o_base = os.path.split(o)
602
+ o_base_n, o_base_ext = os.path.splitext(o_base)
603
+ o_base_n = self.sanitize(o_base_n)
604
+ o_base_n = self.short(o_base_n)
605
+ o = os.path.join(o_dir, f"{o_base_n}{o_base_ext}")
606
+ o = self.iter(o)
607
+ cmd.append(o)
608
+
609
+ process = subprocess.Popen(
610
+ cmd,
611
+ stdin=subprocess.PIPE,
612
+ stdout=subprocess.PIPE,
613
+ stderr=subprocess.PIPE,
614
+ )
615
+
616
+ try:
617
+ stdout, stderr = process.communicate(input=audio_bytes, timeout=300)
618
+ except subprocess.TimeoutExpired:
619
+ process.kill()
620
+ raise ErrorEncode("FFmpeg timeout: операция заняла слишком много времени")
621
+
622
+ if process.returncode != 0:
623
+ raise ErrorEncode(
624
+ f"FFmpeg завершился с ошибкой (код: {process.returncode})"
625
+ )
626
+
627
+ return os.path.abspath(o)
628
+
629
+
630
+ class Inverter(Audio):
631
+ def __init__(self):
632
+ super().__init__()
633
+ self.test = "test"
634
+ self.w_types = [
635
+ "boxcar",
636
+ "triang",
637
+ "blackman",
638
+ "hamming",
639
+ "hann",
640
+ "bartlett",
641
+ "flattop",
642
+ "parzen",
643
+ "bohman",
644
+ "blackmanharris",
645
+ "nuttall",
646
+ "barthann",
647
+ "cosine",
648
+ "exponential",
649
+ "tukey",
650
+ "taylor",
651
+ "lanczos",
652
+ ]
653
+
654
+ def load_audio(self, filepath):
655
+ try:
656
+ y, sr, _ = self.read(i=filepath, sr=None, mono=False)
657
+ return y, sr
658
+ except Exception as e:
659
+ print(f"Ошибка загрузки аудио: {e}")
660
+ return None, None
661
+
662
+ def process_channel(
663
+ self, y1_ch, y2_ch, sr, method, w_size=2048, overlap=2, w_type="hann"
664
+ ):
665
+ HOP_LENGTH = w_size // overlap
666
+ if method == "waveform":
667
+ return y1_ch - y2_ch
668
+
669
+ elif method == "spectrogram":
670
+ S1 = librosa.stft(
671
+ y1_ch, n_fft=w_size, hop_length=HOP_LENGTH, win_length=w_size
672
+ )
673
+ S2 = librosa.stft(
674
+ y2_ch, n_fft=w_size, hop_length=HOP_LENGTH, win_length=w_size
675
+ )
676
+
677
+ mag1 = np.abs(S1)
678
+ mag2 = np.abs(S2)
679
+
680
+ mag_result = np.maximum(mag1 - mag2, 0)
681
+
682
+ phase = np.angle(S1)
683
+
684
+ S_result = mag_result * np.exp(1j * phase)
685
+
686
+ return librosa.istft(
687
+ S_result,
688
+ n_fft=w_size,
689
+ hop_length=HOP_LENGTH,
690
+ win_length=w_size,
691
+ length=len(y1_ch),
692
+ )
693
+
694
+ def process_audio(
695
+ self,
696
+ audio1_path,
697
+ audio2_path,
698
+ out_format,
699
+ method,
700
+ output_path="./inverted.mp3",
701
+ w_size=2048,
702
+ overlap=2,
703
+ w_type="hann",
704
+ ):
705
+ y1, sr1 = self.load_audio(audio1_path)
706
+ y2, sr2 = self.load_audio(audio2_path)
707
+
708
+ if sr1 is None or sr2 is None:
709
+ raise Exception("Произошла ошибка при чтении файлов")
710
+
711
+ channels1 = 1 if y1.ndim == 1 else y1.shape[0]
712
+ channels2 = 1 if y2.ndim == 1 else y2.shape[0]
713
+
714
+ if channels1 > 1:
715
+ y1 = y1.T
716
+ else:
717
+ y1 = y1.reshape(-1, 1)
718
+
719
+ if channels2 > 1:
720
+ y2 = y2.T
721
+ else:
722
+ y2 = y2.reshape(-1, 1)
723
+
724
+ if sr1 != sr2:
725
+ if channels2 > 1:
726
+ y2_resampled_list = []
727
+ for c in range(channels2):
728
+ channel_resampled = librosa.resample(
729
+ y2[:, c], orig_sr=sr2, target_sr=sr1
730
+ )
731
+ y2_resampled_list.append(channel_resampled)
732
+
733
+ min_channel_length = min(len(ch) for ch in y2_resampled_list)
734
+
735
+ y2_resampled = np.zeros(
736
+ (min_channel_length, channels2), dtype=np.float32
737
+ )
738
+ for c, channel in enumerate(y2_resampled_list):
739
+ y2_resampled[:, c] = channel[:min_channel_length]
740
+
741
+ y2 = y2_resampled
742
+ else:
743
+ y2 = librosa.resample(y2[:, 0], orig_sr=sr2, target_sr=sr1)
744
+ y2 = y2.reshape(-1, 1)
745
+ sr2 = sr1
746
+
747
+ min_len = min(len(y1), len(y2))
748
+ y1 = y1[:min_len]
749
+ y2 = y2[:min_len]
750
+
751
+ result_channels = []
752
+
753
+ if channels1 == 1 and channels2 > 1:
754
+ y2 = y2.mean(axis=1, keepdims=True)
755
+ channels2 = 1
756
+
757
+ for c in range(channels1):
758
+ y1_ch = y1[:, c]
759
+
760
+ if channels2 == 1:
761
+ y2_ch = y2[:, 0]
762
+ else:
763
+ y2_ch = y2[:, min(c, channels2 - 1)]
764
+
765
+ result_ch = self.process_channel(
766
+ y1_ch, y2_ch, sr1, method, w_size=w_size, overlap=overlap, w_type=w_type
767
+ )
768
+ result_channels.append(result_ch)
769
+
770
+ if len(result_channels) > 1:
771
+ result = np.column_stack(result_channels)
772
+ else:
773
+ result = np.array(result_channels[0])
774
+
775
+ if result.ndim > 1:
776
+ for c in range(result.shape[1]):
777
+ channel = result[:, c]
778
+ max_val = np.max(np.abs(channel))
779
+ if max_val > 0:
780
+ result[:, c] = channel * 0.9 / max_val
781
+ else:
782
+ max_val = np.max(np.abs(result))
783
+ if max_val > 0:
784
+ result = result * 0.9 / max_val
785
+
786
+ inverted = self.write(
787
+ o=output_path, array=result.T, sr=sr1, of=out_format, br="320k"
788
+ )
789
+ return inverted
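
As a usage note for the Audio class shown in this diff, the sketch below round-trips a file through its ffmpeg pipe API. It assumes ffmpeg/ffprobe are on PATH (or set via MVSEPLESS_FFMPEG / MVSEPLESS_FFPROBE) and that "input.wav" exists; both file names are placeholders:

from mvsepless.audio import Audio

audio = Audio()
# read() decodes via ffmpeg into a float32 array shaped (channels, samples);
# with mono=False the output is always two channels.
array, sr, duration = audio.read(i="input.wav", mono=False)
print(f"{duration:.2f} s at {sr} Hz, shape {array.shape}")
# write() accepts a (2, samples) array directly and transposes it internally;
# "flac" skips the lossy-bitrate path entirely.
out_path = audio.write(o="output", array=array, sr=sr, of="flac")
print(out_path)  # absolute path, possibly adjusted by the Namer helpers
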
mvsepless/downloader.py CHANGED
@@ -1,92 +1,90 @@
1
- import os
2
- import yt_dlp
3
- import time
4
- from tqdm import tqdm
5
- import urllib.request
6
-
7
- DOWNLOAD_DIR = os.environ.get(
8
- "MVSEPLESS_DOWNLOAD_DIR", os.path.join(os.getcwd(), "downloaded")
9
- )
10
-
11
- def dw_file(url_model: str, local_path: str, retries: int = 180):
12
- dir_name = os.path.dirname(local_path)
13
- if dir_name != "":
14
- os.makedirs(dir_name, exist_ok=True)
15
-
16
- class TqdmUpTo(tqdm):
17
- def update_to(self, b=1, bsize=1, tsize=None):
18
- if tsize is not None:
19
- self.total = tsize
20
- self.update(b * bsize - self.n)
21
-
22
- for attempt in range(retries):
23
- try:
24
- with TqdmUpTo(
25
- unit="B",
26
- unit_scale=True,
27
- unit_divisor=1024,
28
- miniters=1,
29
- desc=os.path.basename(local_path),
30
- ) as t:
31
- urllib.request.urlretrieve(
32
- url_model, local_path, reporthook=t.update_to
33
- )
34
- # Если дошли сюда - загрузка успешна, прерываем цикл
35
- break
36
- except Exception as e:
37
- print(f"Попытка {attempt + 1}/{retries} не удалась. Ошибка: {e}")
38
- if attempt < retries - 1: # Если это не последняя попытка
39
- print("Повторная попытка...")
40
- time.sleep(2) # Небольшая задержка перед повторной попыткой
41
- else:
42
- print("Все попытки загрузки завершились неудачно")
43
- raise # Пробрасываем исключение дальше
44
-
45
- def dw_yt_dlp(
46
- url,
47
- output_dir=None,
48
- cookie=None,
49
- output_format="mp3",
50
- output_bitrate="320",
51
- title=None,
52
- ):
53
- # Подготовка шаблона имени файла
54
- outtmpl = "%(title)s.%(ext)s" if title is None else f"{title}.%(ext)s"
55
-
56
- ydl_opts = {
57
- "format": "bestaudio/best",
58
- "outtmpl": os.path.join(DOWNLOAD_DIR if not output_dir else output_dir, outtmpl),
59
- "postprocessors": [
60
- {
61
- "key": "FFmpegExtractAudio",
62
- "preferredcodec": output_format,
63
- "preferredquality": output_bitrate,
64
- }
65
- ],
66
- "noplaylist": True, # Скачивать только одно видео, не плейлист
67
- "quiet": True, # Отключить вывод в консоль
68
- "no_warnings": True, # Скрыть предупреждения
69
- }
70
-
71
- # Добавляем cookies если указаны
72
- if cookie and os.path.exists(cookie):
73
- ydl_opts["cookiefile"] = cookie
74
-
75
- with yt_dlp.YoutubeDL(ydl_opts) as ydl:
76
- try:
77
- info = ydl.extract_info(url, download=True)
78
- if "_type" in info and info["_type"] == "playlist":
79
- # Для плейлистов берем первое видео
80
- entry = info["entries"][0]
81
- filename = ydl.prepare_filename(entry)
82
- else:
83
- # Для одиночного видео
84
- filename = ydl.prepare_filename(info)
85
-
86
- # Заменяем оригинальное расширение на выбранный формат
87
- base, _ = os.path.splitext(filename)
88
- audio_file = base + f".{output_format}"
89
-
90
- return os.path.join(DOWNLOAD_DIR, audio_file)
91
- except Exception as e:
92
- return None
 
+import os
+import yt_dlp
+import time
+from tqdm import tqdm
+import urllib.request
+
+DOWNLOAD_DIR = os.environ.get(
+    "MVSEPLESS_DOWNLOAD_DIR", os.path.join(os.getcwd(), "downloaded")
+)
+
+
+def dw_file(url_model: str, local_path: str, retries: int = 180):
+    dir_name = os.path.dirname(local_path)
+    if dir_name != "":
+        os.makedirs(dir_name, exist_ok=True)
+
+    class TqdmUpTo(tqdm):
+        def update_to(self, b=1, bsize=1, tsize=None):
+            if tsize is not None:
+                self.total = tsize
+            self.update(b * bsize - self.n)
+
+    for attempt in range(retries):
+        try:
+            with TqdmUpTo(
+                unit="B",
+                unit_scale=True,
+                unit_divisor=1024,
+                miniters=1,
+                desc=os.path.basename(local_path),
+            ) as t:
+                urllib.request.urlretrieve(
+                    url_model, local_path, reporthook=t.update_to
+                )
+            break
+        except Exception as e:
+            print(f"Попытка {attempt + 1}/{retries} не удалась. Ошибка: {e}")
+            if attempt < retries - 1:
+                print("Повторная попытка...")
+                time.sleep(2)
+            else:
+                print("Все попытки загрузки завершились неудачно")
+                raise
+
+
+def dw_yt_dlp(
+    url,
+    output_dir=None,
+    cookie=None,
+    output_format="mp3",
+    output_bitrate="320",
+    title=None,
+):
+    outtmpl = "%(title)s.%(ext)s" if title is None else f"{title}.%(ext)s"
+
+    ydl_opts = {
+        "format": "bestaudio/best",
+        "outtmpl": os.path.join(
+            DOWNLOAD_DIR if not output_dir else output_dir, outtmpl
+        ),
+        "postprocessors": [
+            {
+                "key": "FFmpegExtractAudio",
+                "preferredcodec": output_format,
+                "preferredquality": output_bitrate,
+            }
+        ],
+        "noplaylist": True,
+        "quiet": True,
+        "no_warnings": True,
+    }
+
+    if cookie and os.path.exists(cookie):
+        ydl_opts["cookiefile"] = cookie
+
+    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+        try:
+            info = ydl.extract_info(url, download=True)
+            if "_type" in info and info["_type"] == "playlist":
+                entry = info["entries"][0]
+                filename = ydl.prepare_filename(entry)
+            else:
+                filename = ydl.prepare_filename(info)
+
+            base, _ = os.path.splitext(filename)
+            audio_file = base + f".{output_format}"
+
+            return os.path.join(DOWNLOAD_DIR, audio_file)
+        except Exception as e:
+            return None
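
For reference, a minimal usage sketch of the refactored downloader above. The URLs and target paths here are placeholders, and yt-dlp plus ffmpeg are assumed to be installed:

    from mvsepless.downloader import dw_file, dw_yt_dlp

    # Checkpoint download using the tqdm progress hook and retry loop shown above.
    dw_file("https://example.com/model.ckpt", "checkpoints/model.ckpt", retries=3)

    # Extract a track as 320 kbps MP3; returns the output path, or None on failure.
    track = dw_yt_dlp("https://www.youtube.com/watch?v=PLACEHOLDER", output_format="mp3")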
 
 
mvsepless/ensemble.py CHANGED
@@ -1,224 +1,206 @@
-# coding: utf-8
-__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
-
-import os
-import sys
-import librosa
-import tempfile
-import numpy as np
-import argparse
-if not __package__:
-    from audio import Audio
-    from namer import Namer
-else:
-    from .audio import Audio
-    from .namer import Namer
-
-audio = Audio()
-
-def stft(wave, nfft, hl):
-    wave_left = np.asfortranarray(wave[0])
-    wave_right = np.asfortranarray(wave[1])
-    spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
-    spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
-    spec = np.asfortranarray([spec_left, spec_right])
-    return spec
-
-
-def istft(spec, hl, length):
-    spec_left = np.asfortranarray(spec[0])
-    spec_right = np.asfortranarray(spec[1])
-    wave_left = librosa.istft(spec_left, hop_length=hl, length=length)
-    wave_right = librosa.istft(spec_right, hop_length=hl, length=length)
-    wave = np.asfortranarray([wave_left, wave_right])
-    return wave
-
-
-def absmax(a, *, axis):
-    dims = list(a.shape)
-    dims.pop(axis)
-    indices = np.ogrid[tuple(slice(0, d) for d in dims)]
-    argmax = np.abs(a).argmax(axis=axis)
-    # Convert indices to list before insertion
-    indices = list(indices)
-    indices.insert(axis % len(a.shape), argmax)
-    return a[tuple(indices)]
-
-
-def absmin(a, *, axis):
-    dims = list(a.shape)
-    dims.pop(axis)
-    indices = np.ogrid[tuple(slice(0, d) for d in dims)]
-    argmax = np.abs(a).argmin(axis=axis)
-    indices.insert((len(a.shape) + axis) % len(a.shape), argmax)
-    return a[tuple(indices)]
-
-
-def lambda_max(arr, axis=None, key=None, keepdims=False):
-    idxs = np.argmax(key(arr), axis)
-    if axis is not None:
-        idxs = np.expand_dims(idxs, axis)
-        result = np.take_along_axis(arr, idxs, axis)
-        if not keepdims:
-            result = np.squeeze(result, axis=axis)
-        return result
-    else:
-        return arr.flatten()[idxs]
-
-
-def lambda_min(arr, axis=None, key=None, keepdims=False):
-    idxs = np.argmin(key(arr), axis)
-    if axis is not None:
-        idxs = np.expand_dims(idxs, axis)
-        result = np.take_along_axis(arr, idxs, axis)
-        if not keepdims:
-            result = np.squeeze(result, axis=axis)
-        return result
-    else:
-        return arr.flatten()[idxs]
-
-
-def average_waveforms(pred_track, weights, algorithm):
-    """
-    :param pred_track: shape = (num, channels, length)
-    :param weights: shape = (num, )
-    :param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft
-    :return: averaged waveform in shape (channels, length)
-    """
-
-    pred_track = np.array(pred_track)
-    final_length = pred_track.shape[-1]
-
-    mod_track = []
-    for i in range(pred_track.shape[0]):
-        if algorithm == "avg_wave":
-            mod_track.append(pred_track[i] * weights[i])
-        elif algorithm in ["median_wave", "min_wave", "max_wave"]:
-            mod_track.append(pred_track[i])
-        elif algorithm in ["avg_fft", "min_fft", "max_fft", "median_fft"]:
-            spec = stft(pred_track[i], nfft=2048, hl=1024)
-            if algorithm in ["avg_fft"]:
-                mod_track.append(spec * weights[i])
-            else:
-                mod_track.append(spec)
-    pred_track = np.array(mod_track)
-
-    if algorithm in ["avg_wave"]:
-        pred_track = pred_track.sum(axis=0)
-        pred_track /= np.array(weights).sum().T
-    elif algorithm in ["median_wave"]:
-        pred_track = np.median(pred_track, axis=0)
-    elif algorithm in ["min_wave"]:
-        pred_track = np.array(pred_track)
-        pred_track = lambda_min(pred_track, axis=0, key=np.abs)
-    elif algorithm in ["max_wave"]:
-        pred_track = np.array(pred_track)
-        pred_track = lambda_max(pred_track, axis=0, key=np.abs)
-    elif algorithm in ["avg_fft"]:
-        pred_track = pred_track.sum(axis=0)
-        pred_track /= np.array(weights).sum()
-        pred_track = istft(pred_track, 1024, final_length)
-    elif algorithm in ["min_fft"]:
-        pred_track = np.array(pred_track)
-        pred_track = lambda_min(pred_track, axis=0, key=np.abs)
-        pred_track = istft(pred_track, 1024, final_length)
-    elif algorithm in ["max_fft"]:
-        pred_track = np.array(pred_track)
-        pred_track = absmax(pred_track, axis=0)
-        pred_track = istft(pred_track, 1024, final_length)
-    elif algorithm in ["median_fft"]:
-        pred_track = np.array(pred_track)
-        pred_track = np.median(pred_track, axis=0)
-        pred_track = istft(pred_track, 1024, final_length)
-    return pred_track
-
-
-def ensemble_audio_files(
-    files, output="res.wav", ensemble_type="avg_wave", weights=None, out_format="wav", add_wav=False
-) -> str | tuple[str, str]:
-    """
-    Main function for merging audio files
-
-    :param files: list of paths to audio files
-    :param output: path for saving the result
-    :param ensemble_type: merging method (avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft)
-    :param weights: list of weights for each file (None for equal weights)
-    :return: None
-    """
-    print("Алгоритм склеивания: {}".format(ensemble_type))
-    print("Количество входных файлов: {}".format(len(files)))
-    if weights is not None:
-        weights = np.array(weights)
-    else:
-        weights = np.ones(len(files))
-    print("Весы: {}".format(weights))
-    print("Имя выходного файла: {}".format(output))
-
-    data = []
-    sr = None
-    max_length = 0
-    max_channels = 0
-
-    # First pass: determine the maximum length and channel count
-    for f in files:
-        if not os.path.isfile(f):
-            print("Не удается найти файл: {}. Check paths.".format(f))
-            exit()
-        print("Читается файл: {}".format(f))
-        wav, current_sr, _ = audio.read(i=f, sr=None, mono=False)
-        if sr is None:
-            sr = current_sr
-        elif sr != current_sr:
-            print("Частота дискретизации на всех файлах должна быть одинаковой")
-            exit()
-
-        # Determine the number of channels
-        if wav.ndim == 1:
-            channels = 1
-            length = len(wav)
-        else:
-            channels = wav.shape[0]
-            length = wav.shape[1]
-
-        max_length = max(max_length, length)
-        max_channels = max(max_channels, channels)
-        print("Форма сигнала: {} частота дискретизации: {}".format(wav.shape, sr))
-
-    # Second pass: process and align the files
-    for f in files:
-        wav, current_sr, _ = audio.read(i=f, sr=None, mono=False)
-
-        # Channel handling
-        if wav.ndim == 1:
-            # Mono -> stereo
-            wav = np.vstack([wav, wav])
-        elif wav.shape[0] == 1:
-            # Single channel -> stereo
-            wav = np.vstack([wav[0], wav[0]])
-        elif wav.shape[0] > 2:
-            # More than 2 channels -> take the first two
-            wav = wav[:2, :]
-
-        # Length alignment
-        if wav.shape[1] < max_length:
-            pad_width = ((0, 0), (0, max_length - wav.shape[1]))
-            wav = np.pad(wav, pad_width, mode="constant")
-        elif wav.shape[1] > max_length:
-            wav = wav[:, :max_length]
-
-        data.append(wav)
-
-    data = np.array(data)
-    res = average_waveforms(data, weights, ensemble_type)
-    print("Форма результата: {}".format(res.shape))
-
-    output_wav = f"{output}_orig.wav"
-    output = f"{output}.{out_format}"
-
-    output = audio.write(o=output, array=res.T, sr=sr, of=out_format, br="320k")
-    if add_wav:
-        output_wav = audio.write(o=output_wav, array=res.T, sr=sr, of="wav")
-        return output, output_wav
-
-    else:
-        return output
 
+__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
+
+import os
+import sys
+import librosa
+import tempfile
+import numpy as np
+import argparse
+
+if not __package__:
+    from audio import Audio
+    from namer import Namer
+else:
+    from .audio import Audio
+    from .namer import Namer
+
+audio = Audio()
+
+
+def stft(wave, nfft, hl):
+    wave_left = np.asfortranarray(wave[0])
+    wave_right = np.asfortranarray(wave[1])
+    spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
+    spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
+    spec = np.asfortranarray([spec_left, spec_right])
+    return spec
+
+
+def istft(spec, hl, length):
+    spec_left = np.asfortranarray(spec[0])
+    spec_right = np.asfortranarray(spec[1])
+    wave_left = librosa.istft(spec_left, hop_length=hl, length=length)
+    wave_right = librosa.istft(spec_right, hop_length=hl, length=length)
+    wave = np.asfortranarray([wave_left, wave_right])
+    return wave
+
+
+def absmax(a, *, axis):
+    dims = list(a.shape)
+    dims.pop(axis)
+    indices = np.ogrid[tuple(slice(0, d) for d in dims)]
+    argmax = np.abs(a).argmax(axis=axis)
+    indices = list(indices)
+    indices.insert(axis % len(a.shape), argmax)
+    return a[tuple(indices)]
+
+
+def absmin(a, *, axis):
+    dims = list(a.shape)
+    dims.pop(axis)
+    indices = np.ogrid[tuple(slice(0, d) for d in dims)]
+    argmax = np.abs(a).argmin(axis=axis)
+    indices.insert((len(a.shape) + axis) % len(a.shape), argmax)
+    return a[tuple(indices)]
+
+
+def lambda_max(arr, axis=None, key=None, keepdims=False):
+    idxs = np.argmax(key(arr), axis)
+    if axis is not None:
+        idxs = np.expand_dims(idxs, axis)
+        result = np.take_along_axis(arr, idxs, axis)
+        if not keepdims:
+            result = np.squeeze(result, axis=axis)
+        return result
+    else:
+        return arr.flatten()[idxs]
+
+
+def lambda_min(arr, axis=None, key=None, keepdims=False):
+    idxs = np.argmin(key(arr), axis)
+    if axis is not None:
+        idxs = np.expand_dims(idxs, axis)
+        result = np.take_along_axis(arr, idxs, axis)
+        if not keepdims:
+            result = np.squeeze(result, axis=axis)
+        return result
+    else:
+        return arr.flatten()[idxs]
+
+
+def average_waveforms(pred_track, weights, algorithm):
+
+    pred_track = np.array(pred_track)
+    final_length = pred_track.shape[-1]
+
+    mod_track = []
+    for i in range(pred_track.shape[0]):
+        if algorithm == "avg_wave":
+            mod_track.append(pred_track[i] * weights[i])
+        elif algorithm in ["median_wave", "min_wave", "max_wave"]:
+            mod_track.append(pred_track[i])
+        elif algorithm in ["avg_fft", "min_fft", "max_fft", "median_fft"]:
+            spec = stft(pred_track[i], nfft=2048, hl=1024)
+            if algorithm in ["avg_fft"]:
+                mod_track.append(spec * weights[i])
+            else:
+                mod_track.append(spec)
+    pred_track = np.array(mod_track)
+
+    if algorithm in ["avg_wave"]:
+        pred_track = pred_track.sum(axis=0)
+        pred_track /= np.array(weights).sum().T
+    elif algorithm in ["median_wave"]:
+        pred_track = np.median(pred_track, axis=0)
+    elif algorithm in ["min_wave"]:
+        pred_track = np.array(pred_track)
+        pred_track = lambda_min(pred_track, axis=0, key=np.abs)
+    elif algorithm in ["max_wave"]:
+        pred_track = np.array(pred_track)
+        pred_track = lambda_max(pred_track, axis=0, key=np.abs)
+    elif algorithm in ["avg_fft"]:
+        pred_track = pred_track.sum(axis=0)
+        pred_track /= np.array(weights).sum()
+        pred_track = istft(pred_track, 1024, final_length)
+    elif algorithm in ["min_fft"]:
+        pred_track = np.array(pred_track)
+        pred_track = lambda_min(pred_track, axis=0, key=np.abs)
+        pred_track = istft(pred_track, 1024, final_length)
+    elif algorithm in ["max_fft"]:
+        pred_track = np.array(pred_track)
+        pred_track = absmax(pred_track, axis=0)
+        pred_track = istft(pred_track, 1024, final_length)
+    elif algorithm in ["median_fft"]:
+        pred_track = np.array(pred_track)
+        pred_track = np.median(pred_track, axis=0)
+        pred_track = istft(pred_track, 1024, final_length)
+    return pred_track
+
+
+def ensemble_audio_files(
+    files,
+    output="res.wav",
+    ensemble_type="avg_wave",
+    weights=None,
+    out_format="wav",
+    add_wav=False,
+) -> str | tuple[str, str]:
+    print("Алгоритм склеивания: {}".format(ensemble_type))
+    print("Количество входных файлов: {}".format(len(files)))
+    if weights is not None:
+        weights = np.array(weights)
+    else:
+        weights = np.ones(len(files))
+    print("Весы: {}".format(weights))
+    print("Имя выходного файла: {}".format(output))
+
+    data = []
+    sr = None
+    max_length = 0
+    max_channels = 0
+
+    for f in files:
+        if not os.path.isfile(f):
+            print("Не удается найти файл: {}. Check paths.".format(f))
+            exit()
+        print("Читается файл: {}".format(f))
+        wav, current_sr, _ = audio.read(i=f, sr=None, mono=False)
+        if sr is None:
+            sr = current_sr
+        elif sr != current_sr:
+            print("Частота дискретизации на всех файлах должна быть одинаковой")
+            exit()
+
+        if wav.ndim == 1:
+            channels = 1
+            length = len(wav)
+        else:
+            channels = wav.shape[0]
+            length = wav.shape[1]
+
+        max_length = max(max_length, length)
+        max_channels = max(max_channels, channels)
+        print("Форма сигнала: {} частота дискретизации: {}".format(wav.shape, sr))
+
+    for f in files:
+        wav, current_sr, _ = audio.read(i=f, sr=None, mono=False)
+
+        if wav.ndim == 1:
+            wav = np.vstack([wav, wav])
+        elif wav.shape[0] == 1:
+            wav = np.vstack([wav[0], wav[0]])
+        elif wav.shape[0] > 2:
+            wav = wav[:2, :]
+
+        if wav.shape[1] < max_length:
+            pad_width = ((0, 0), (0, max_length - wav.shape[1]))
+            wav = np.pad(wav, pad_width, mode="constant")
+        elif wav.shape[1] > max_length:
+            wav = wav[:, :max_length]
+
+        data.append(wav)
+
+    data = np.array(data)
+    res = average_waveforms(data, weights, ensemble_type)
+    print("Форма результата: {}".format(res.shape))
+
+    output_wav = f"{output}_orig.wav"
+    output = f"{output}.{out_format}"
+
+    output = audio.write(o=output, array=res.T, sr=sr, of=out_format, br="320k")
+    if add_wav:
+        output_wav = audio.write(o=output_wav, array=res.T, sr=sr, of="wav")
+        return output, output_wav
+
+    else:
+        return output
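
A short usage sketch for `ensemble_audio_files` as reformatted above (the input stems are placeholder file names; `avg_wave` computes the weighted average of the aligned waveforms):

    from mvsepless.ensemble import ensemble_audio_files

    # Blend two vocal stems 2:1; writes res.mp3 (plus res_orig.wav when
    # add_wav=True) and returns the written path(s).
    out_mp3, out_wav = ensemble_audio_files(
        files=["vocals_model_a.wav", "vocals_model_b.wav"],
        output="res",
        ensemble_type="avg_wave",
        weights=[2, 1],
        out_format="mp3",
        add_wav=True,
    )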
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mvsepless/infer.py CHANGED
@@ -20,17 +20,13 @@ from namer import Namer
 namer = Namer()
 audio = Audio()
 
-from infer_utils import (
-    prefer_target_instrument,
-    demix,
-    get_model_from_config
-)
+from infer_utils import prefer_target_instrument, demix, get_model_from_config
 
 
 def normalize_peak(audio, peak):
     current_peak = np.max(np.abs(audio))
     if current_peak == 0:
-        return audio  # avoid division by zero
+        return audio
     scale_factor = peak / current_peak
     return audio * scale_factor
 
@@ -57,10 +53,18 @@ def cleanup_model(model):
             torch.cuda.ipc_collect()
 
         gc.collect()
-        sys.stdout.write(json.dumps({"cleanup": "Модель выгружена из памяти"}, ensure_ascii=False) + '\n')
+        sys.stdout.write(
+            json.dumps({"cleanup": "Модель выгружена из памяти"}, ensure_ascii=False)
+            + "\n"
+        )
         sys.stdout.flush()
     except Exception as e:
-        sys.stdout.write(json.dumps({"error": f"Ошибка при выгрузке модели: {str(e)}"}, ensure_ascii=False) + '\n')
+        sys.stdout.write(
+            json.dumps(
+                {"error": f"Ошибка при выгрузке модели: {str(e)}"}, ensure_ascii=False
+            )
+            + "\n"
+        )
         sys.stdout.flush()
 
 
@@ -87,22 +91,32 @@ def once_inference(
     model_id: int = 0,
 ):
     results = []
-    sys.stdout.write(json.dumps({"reading": path}, ensure_ascii=False) + '\n')
+    sys.stdout.write(json.dumps({"reading": path}, ensure_ascii=False) + "\n")
     sys.stdout.flush()
-    sys.stdout.write(json.dumps({"selected_stems": selected_instruments}, ensure_ascii=False) + '\n')
+    sys.stdout.write(
+        json.dumps({"selected_stems": selected_instruments}, ensure_ascii=False) + "\n"
+    )
     sys.stdout.flush()
-    sys.stdout.write(json.dumps({"stems": list(instruments)}, ensure_ascii=False) + '\n')
+    sys.stdout.write(
+        json.dumps({"stems": list(instruments)}, ensure_ascii=False) + "\n"
+    )
     sys.stdout.flush()
 
     if config.training.target_instrument is not None:
-        sys.stdout.write(json.dumps({"target_instrument": config.training.target_instrument}, ensure_ascii=False) + '\n')
+        sys.stdout.write(
+            json.dumps(
+                {"target_instrument": config.training.target_instrument},
+                ensure_ascii=False,
+            )
+            + "\n"
+        )
         sys.stdout.flush()
 
     try:
         mix, sr, _ = audio.read(i=path, sr=sample_rate, mono=False)
     except Exception as e:
         error_msg = f"Не удалось прочитать аудио: {path}\nОшибка: {e}"
-        sys.stdout.write(json.dumps({"error": error_msg}, ensure_ascii=False) + '\n')
+        sys.stdout.write(json.dumps({"error": error_msg}, ensure_ascii=False) + "\n")
         sys.stdout.flush()
         return results
 
@@ -128,13 +142,19 @@
 
             full_result.append(waveforms)
     except Exception as e:
-        sys.stdout.write(json.dumps({"error": f"Ошибка при демиксе: {e}"}, ensure_ascii=False) + '\n')
+        sys.stdout.write(
+            json.dumps({"error": f"Ошибка при демиксе: {e}"}, ensure_ascii=False)
+            + "\n"
+        )
         sys.stdout.flush()
         del m
         gc.collect()
 
     if not full_result:
-        sys.stdout.write(json.dumps({"error": "Пустой результат демикса."}, ensure_ascii=False) + '\n')
+        sys.stdout.write(
+            json.dumps({"error": "Пустой результат демикса."}, ensure_ascii=False)
+            + "\n"
+        )
         sys.stdout.flush()
         return results
 
@@ -151,9 +171,7 @@
     for el in waveforms:
         waveforms[el] /= len(full_result)
 
-    if (
-        extract_instrumental and config.training.target_instrument is not None
-    ):  # If "Extract Instrumental / Извлечь инструментал" is enabled and a target instrument was found
+    if extract_instrumental and config.training.target_instrument is not None:
         second_stem = [
             s
             for s in config.training.instruments
@@ -169,8 +187,7 @@
         extract_instrumental
        and selected_instruments
         and config.training.target_instrument is None
-    ):  # If "Extract Instrumental / Извлечь инструментал" is enabled and stems are selected, the "inverted -" and "inverted +" stems are created (when no target instrument is defined)
-
+    ):
 
         all_instruments = config.training.instruments
         if len(all_instruments) > 2:
@@ -178,21 +195,19 @@
             waveforms["inverted -"] = mix_orig.copy()
             for instr in instruments:
                 if instr in waveforms:
-                    waveforms["inverted -"] -= waveforms[
-                        instr
-                    ]  # "inverted -" stem: subtract the selected stem from the original signal (not always good)
+                    waveforms["inverted -"] -= waveforms[instr]
 
             if "inverted -" not in instruments:
                 instruments.append("inverted -")
 
-            unselected_stems = [s for s in all_instruments if s not in selected_instruments]
+            unselected_stems = [
+                s for s in all_instruments if s not in selected_instruments
+            ]
             if unselected_stems:
                 waveforms["inverted +"] = np.zeros_like(mix_orig)
                 for stem in unselected_stems:
                     if stem in waveforms:
-                        waveforms["inverted +"] += waveforms[
-                            stem
-                        ]  # "inverted +" stem: sum the unselected instruments into one stem
+                        waveforms["inverted +"] += waveforms[stem]
                 if "inverted +" not in instruments:
                     instruments.append("inverted +")
 
@@ -216,9 +231,7 @@
     ):
 
         waveforms["instrumental -"] = mix_orig.copy()
-        waveforms["instrumental -"] -= waveforms[
-            "vocals"
-        ]  # "inverted -" stem: subtract the selected stem from the original signal (not always good)
+        waveforms["instrumental -"] -= waveforms["vocals"]
 
         if "instrumental -" not in instruments:
             instruments.append("instrumental -")
@@ -229,9 +242,7 @@
         waveforms["instrumental +"] = np.zeros_like(mix_orig)
         for stem in non_vocal_stems:
             if stem in waveforms:
-                waveforms["instrumental +"] += waveforms[
-                    stem
-                ]  # "inverted +" stem: sum the unselected instruments into one stem
+                waveforms["instrumental +"] += waveforms[stem]
         if "instrumental +" not in instruments:
             instruments.append("instrumental +")
 
@@ -249,25 +260,40 @@
             estimates = estimates * std + mean
 
             file_name = os.path.splitext(os.path.basename(path))[0]
-            file_name_shorted = namer.short_input_name_template(template, STEM=instr, MODEL=model_name, ID=model_id, NAME=file_name)
+            file_name_shorted = namer.short_input_name_template(
+                template, STEM=instr, MODEL=model_name, ID=model_id, NAME=file_name
+            )
             custom_name = namer.template(
-                template, STEM=instr, MODEL=model_name, ID=model_id, NAME=file_name_shorted
+                template,
+                STEM=instr,
+                MODEL=model_name,
+                ID=model_id,
+                NAME=file_name_shorted,
             )
             output_path = os.path.join(store_dir, f"{custom_name}.{output_format}")
 
-            sys.stdout.write(json.dumps({"writing": output_path}, ensure_ascii=False) + '\n')
+            sys.stdout.write(
+                json.dumps({"writing": output_path}, ensure_ascii=False) + "\n"
+            )
            sys.stdout.flush()
 
             output_path = audio.write(
-                o=output_path, array=estimates, sr=sr, of=output_format, br=output_bitrate
-            )  # write the stem to an audio file using the universal helper
+                o=output_path,
+                array=estimates,
+                sr=sr,
+                of=output_format,
+                br=output_bitrate,
+            )
 
-            results.append(
-                (instr, output_path)
-            )  # record the separation info: (stem name, file path)
+            results.append((instr, output_path))
             del estimates
         except Exception as e:
-            sys.stdout.write(json.dumps({"error": f"Ошибка при обработке {instr}: {e}"}, ensure_ascii=False) + '\n')
+            sys.stdout.write(
+                json.dumps(
+                    {"error": f"Ошибка при обработке {instr}: {e}"}, ensure_ascii=False
+                )
+                + "\n"
+            )
             sys.stdout.flush()
     gc.collect()
 
@@ -307,7 +333,15 @@
     instruments = prefer_target_instrument(config)
 
     if config.training.target_instrument is not None:
-        sys.stdout.write(json.dumps({"info": "Целевой инструмент найден в конфигурации модели. Выбранные стемы будут проигнорированы."}, ensure_ascii=False) + '\n')
+        sys.stdout.write(
+            json.dumps(
+                {
+                    "info": "Целевой инструмент найден в конфигурации модели. Выбранные стемы будут проигнорированы."
+                },
+                ensure_ascii=False,
+            )
+            + "\n"
+        )
         sys.stdout.flush()
     else:
         if selected_instruments is not None and selected_instruments != []:
@@ -315,7 +349,10 @@
                 instr for instr in instruments if instr in selected_instruments
            ]
         if verbose:
-            sys.stdout.write(json.dumps({"selected_stems": instruments}, ensure_ascii=False) + '\n')
+            sys.stdout.write(
+                json.dumps({"selected_stems": instruments}, ensure_ascii=False)
+                + "\n"
+            )
             sys.stdout.flush()
 
     os.makedirs(store_dir, exist_ok=True)
@@ -345,9 +382,11 @@
 
     time.sleep(1)
     time_taken = time.time() - start_time
-    sys.stdout.write(json.dumps({"time": f"{time_taken:.2f} сек."}, ensure_ascii=False) + '\n')
+    sys.stdout.write(
+        json.dumps({"time": f"{time_taken:.2f} сек."}, ensure_ascii=False) + "\n"
+    )
     sys.stdout.flush()
-    sys.stdout.write(json.dumps({"done": results}, ensure_ascii=False) + '\n')
+    sys.stdout.write(json.dumps({"done": results}, ensure_ascii=False) + "\n")
     sys.stdout.flush()
     return results
 
@@ -357,7 +396,15 @@ def load_model(model_type, config_path, start_check_point, device_ids, force_cpu
     if force_cpu:
         device = "cpu"
     elif torch.cuda.is_available():
-        sys.stdout.write(json.dumps({"info": "Разделение выполняется на ядрах CUDA. Для выполнения на процессоре установите force_cpu=True."}, ensure_ascii=False) + '\n')
+        sys.stdout.write(
+            json.dumps(
+                {
+                    "info": "Разделение выполняется на ядрах CUDA. Для выполнения на процессоре установите force_cpu=True."
+                },
+                ensure_ascii=False,
+            )
+            + "\n"
+        )
         sys.stdout.flush()
         device = "cuda"
 
@@ -372,7 +419,7 @@ def load_model(model_type, config_path, start_check_point, device_ids, force_cpu
     elif torch.backends.mps.is_available():
         device = "mps"
 
-    sys.stdout.write(json.dumps({"device": device}, ensure_ascii=False) + '\n')
+    sys.stdout.write(json.dumps({"device": device}, ensure_ascii=False) + "\n")
     sys.stdout.flush()
 
     model_load_start_time = time.time()
@@ -382,29 +429,30 @@ def load_model(model_type, config_path, start_check_point, device_ids, force_cpu
 
     if model_type == "vr":
         model.load_checkpoint(start_check_point, device)
-        model.settings(enable_post_process=False,
-                       post_process_threshold=config.inference.post_process_threshold,
-                       batch_size=config.inference.batch_size,
-                       window_size=config.inference.window_size,
-                       high_end_process=config.inference.high_end_process,
+        model.settings(
+            enable_post_process=False,
+            post_process_threshold=config.inference.post_process_threshold,
+            batch_size=config.inference.batch_size,
+            window_size=config.inference.window_size,
+            high_end_process=config.inference.high_end_process,
             primary_stem=config.training.instruments[0],
-            secondary_stem=config.training.instruments[1]
+            secondary_stem=config.training.instruments[1],
        )
         return model, config, device
 
     elif model_type == "mdxnet":
         if start_check_point != "":
-            sys.stdout.write(json.dumps({"checkpoint": start_check_point}) + '\n')
+            sys.stdout.write(json.dumps({"checkpoint": start_check_point}) + "\n")
             sys.stdout.flush()
             model.init_onnx_session(start_check_point, device)
 
         return model, config, device
-
+
     else:
         if start_check_point != "":
-            sys.stdout.write(json.dumps({"checkpoint": start_check_point}) + '\n')
+            sys.stdout.write(json.dumps({"checkpoint": start_check_point}) + "\n")
            sys.stdout.flush()
-
+
         if model_type in ["htdemucs", "apollo"]:
             state_dict = torch.load(
                 start_check_point, map_location=device, weights_only=False
@@ -428,7 +476,10 @@ def load_model(model_type, config_path, start_check_point, device_ids, force_cpu
         except RuntimeError:
             model.load_state_dict(state_dict, strict=False)
 
-    sys.stdout.write(json.dumps({"stems": list(config.training.instruments)}, ensure_ascii=False) + '\n')
+    sys.stdout.write(
+        json.dumps({"stems": list(config.training.instruments)}, ensure_ascii=False)
+        + "\n"
+    )
     sys.stdout.flush()
 
     if (
@@ -443,7 +494,10 @@ def load_model(model_type, config_path, start_check_point, device_ids, force_cpu
 
     load_time = time.time() - model_load_start_time
 
-    sys.stdout.write(json.dumps({"model_load_time": f"{load_time:.2f} сек."}, ensure_ascii=False) + '\n')
+    sys.stdout.write(
+        json.dumps({"model_load_time": f"{load_time:.2f} сек."}, ensure_ascii=False)
+        + "\n"
+    )
     sys.stdout.flush()
 
     return model, config, device
@@ -503,13 +557,11 @@
         description="Модифицированный Music-Source-Separation-Training для разделения аудио на источники"
     )
 
-    # Required arguments
     parser.add_argument("--input", type=str, help="Путь к входному файлу или папке")
     parser.add_argument(
         "--store_dir", type=str, required=True, help="Путь для сохранения результатов"
    )
 
-    # Main model parameters
     parser.add_argument(
         "--model_type",
         type=str,
@@ -523,7 +575,7 @@
             "bandit",
             "bandit_v2",
            "mdxnet",
-            "vr"
+            "vr",
        ],
         help="Тип модели (по умолчанию: htdemucs)",
     )
@@ -537,7 +589,6 @@
         "--start_check_point", type=str, required=True, help="Путь к чекпоинту модели"
     )
 
-    # Output parameters
     parser.add_argument(
         "--output_format",
        type=str,
@@ -620,4 +671,4 @@
 
 
 if __name__ == "__main__":
-    main()
+    main()
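
One plausible invocation of the refactored CLI above, matching the notebook's `python mvsepless/...` convention. The paths are placeholders, and `--config_path` is an assumed name for the config flag defined in the part of `parse_args` not shown in these hunks:

    python mvsepless/infer.py \
        --input song.wav \
        --store_dir ./separated \
        --model_type mel_band_roformer \
        --config_path configs/model_config.yaml \
        --start_check_point checkpoints/model.ckpt \
        --output_format wav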
mvsepless/infer_utils.py CHANGED
@@ -1,4 +1,3 @@
-# coding: utf-8
 __author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
 
 import sys
@@ -15,9 +14,6 @@ from typing import Dict, List, Tuple, Any, List, Optional
 
 
 def load_config(model_type: str, config_path: str) -> Any:
-    """
-    Load the configuration from the specified path based on the model type.
-    """
     try:
         with open(config_path, "r") as f:
             if model_type == "htdemucs":
@@ -32,9 +28,6 @@ def load_config(model_type: str, config_path: str) -> Any:
 
 
 def get_model_from_config(model_type: str, config_path: str) -> Tuple:
-    """
-    Load the model specified by the model type and configuration file.
-    """
     config = load_config(model_type, config_path)
 
     if model_type == "mdx23c":
@@ -47,11 +40,9 @@ def get_model_from_config(model_type: str, config_path: str) -> Tuple:
 
         model = MDXNet(**dict(config.model))
 
-    # In get_model_from_config, add:
-
     elif model_type == "vr":
         from models.vr_arch import VRNet
-        # Pass instruments from config.training into the model
+
         model = VRNet(**dict(config.model))
 
    elif model_type == "htdemucs":
@@ -60,7 +51,7 @@ def get_model_from_config(model_type: str, config_path: str) -> Tuple:
         model = get_model(config)
 
     elif model_type == "mel_band_roformer":
-        if hasattr(config, "windowed"):  # This does not break compatibility with regular Mel-Band Roformer models
+        if hasattr(config, "windowed"):
             from models.windowed_roformer.model import MelBandRoformerWSA
 
             model = MelBandRoformerWSA(**dict(config.model))
@@ -114,10 +105,8 @@ def get_model_from_config(model_type: str, config_path: str) -> Tuple:
 
     return model, config
 
+
 def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor:
-    """
-    Generate a windowing array with a linear fade-in at the beginning and a fade-out at the end.
-    """
     fadein = torch.linspace(0, 1, fade_size)
     fadeout = torch.linspace(1, 0, fade_size)
 
@@ -126,6 +115,7 @@ def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor:
     window[:fade_size] = fadein
     return window
 
+
 def demix_mdxnet(
     config: Any,
     model: Any,
@@ -133,18 +123,17 @@ def demix_mdxnet(
     device: torch.device,
     pbar: bool = False,
 ) -> Dict[str, np.ndarray]:
-    """
-    MDX-Net specific demixing function with overlap support
-    """
     mix_tensor = torch.tensor(mix, dtype=torch.float32)
     inv_mix_tensor = torch.tensor(-mix, dtype=torch.float32)
-
+
     num_overlap = config.inference.num_overlap
     denoise = config.inference.denoise
     stem_name = model.primary_stem
     if denoise:
         processed_wav = model.process_wave(mix_tensor, device, num_overlap, pbar=pbar)
-        inv_processed_wav = model.process_wave(inv_mix_tensor, device, num_overlap, pbar=pbar)
+        inv_processed_wav = model.process_wave(
+            inv_mix_tensor, device, num_overlap, pbar=pbar
+        )
         result = processed_wav.cpu().numpy()
         inv_result = inv_processed_wav.cpu().numpy()
         result_separation = (result + -inv_result) * 0.5
@@ -152,9 +141,12 @@
         processed_wav = model.process_wave(mix_tensor, device, num_overlap, pbar=pbar)
         result_separation = processed_wav.cpu().numpy()
 
-    result_separation = np.nan_to_num(result_separation, nan=0.0, posinf=0.0, neginf=0.0)
+    result_separation = np.nan_to_num(
+        result_separation, nan=0.0, posinf=0.0, neginf=0.0
+    )
 
-    return {stem_name: result_separation}  # Move back to the CPU for the return
+    return {stem_name: result_separation}
 
+
 def demix_vr(
     config: Any,
@@ -163,12 +155,10 @@
     device: torch.device,
     pbar: bool = False,
 ) -> Dict[str, np.ndarray]:
-    """
-    VR-specific demixing function that processes the entire audio at once
-    since VR architecture doesn't support chunk-based processing
-    """
-    # Convert to tensor and add batch dimension
-    return model.demix(mix, config.audio.sample_rate, device, config.inference.aggression)
+    return model.demix(
+        mix, config.audio.sample_rate, device, config.inference.aggression
+    )
+
 
 def demix_demucs(config, model, mix, device, pbar=False):
     mix = torch.tensor(mix, dtype=torch.float32)
@@ -176,8 +166,8 @@ def demix_demucs(config, model, mix, device, pbar=False):
     num_instruments = len(config.training.instruments)
     num_overlap = config.inference.num_overlap
     step = chunk_size // num_overlap
-    fade_size = chunk_size // 10  # Add fade_size for the windowing function
-    windowing_array = _getWindowingArray(chunk_size, fade_size)  # Create the window
+    fade_size = chunk_size // 10
+    windowing_array = _getWindowingArray(chunk_size, fade_size)
 
     batch_size = config.inference.batch_size
     use_amp = getattr(config.training, "use_amp", True)
@@ -209,9 +199,9 @@ def demix_demucs(config, model, mix, device, pbar=False):
                     x = model(arr)
 
                     window = windowing_array.clone()
-                    if i - step == 0:  # First chunk, no fade-in
+                    if i - step == 0:
                         window[:fade_size] = 1
-                    elif i >= mix.shape[1]:  # Last chunk, no fade-out
+                    elif i >= mix.shape[1]:
                         window[-fade_size:] = 1
 
                     for j, (start, seg_len) in enumerate(batch_locations):
@@ -220,10 +210,14 @@ def demix_demucs(config, model, mix, device, pbar=False):
                         )
                        counter[..., start : start + seg_len] += window[..., :seg_len]
 
-                    # Output progress
                     processed = min(i, mix.shape[1])
                     total = mix.shape[1]
-                    sys.stdout.write(json.dumps({"processing": {"processed": processed, "total": total}}) + '\n')
+                    sys.stdout.write(
+                        json.dumps(
+                            {"processing": {"processed": processed, "total": total}}
+                        )
+                        + "\n"
+                    )
                     sys.stdout.flush()
 
                     batch_data.clear()
@@ -239,6 +233,7 @@ def demix_demucs(config, model, mix, device, pbar=False):
     instruments = config.training.instruments
     return {k: v for k, v in zip(instruments, estimated_sources)}
 
+
 def demix_generic(
     config: ConfigDict,
     model: torch.nn.Module,
@@ -246,9 +241,6 @@ def demix_generic(
     device: torch.device,
     pbar: bool = False,
 ) -> Dict[str, np.ndarray]:
-    """
-    Generic demixing function for models that support chunk-based processing
-    """
     mix = torch.tensor(mix, dtype=torch.float32)
     chunk_size = config.audio.chunk_size
     instruments = prefer_target_instrument(config)
@@ -260,8 +252,7 @@ def demix_generic(
     border = chunk_size - step
     length_init = mix.shape[-1]
     windowing_array = _getWindowingArray(chunk_size, fade_size)
-
-    # Add padding to handle edge artifacts
+
     if length_init > 2 * border and border > 0:
         mix = nn.functional.pad(mix, (border, border), mode="reflect")
 
@@ -270,7 +261,6 @@ def demix_generic(
 
     with torch.cuda.amp.autocast(enabled=use_amp):
         with torch.inference_mode():
-            # Initialize result and counter tensors
             req_shape = (num_instruments,) + mix.shape
             result = torch.zeros(req_shape, dtype=torch.float32)
             counter = torch.zeros(req_shape, dtype=torch.float32)
@@ -280,10 +270,9 @@ def demix_generic(
             batch_locations = []
 
             while i < mix.shape[1]:
-                # Extract chunk and apply padding if necessary
                 part = mix[:, i : i + chunk_size].to(device)
                 chunk_len = part.shape[-1]
-
+
                 pad_mode = "reflect" if chunk_len > chunk_size // 2 else "constant"
                 part = nn.functional.pad(
                     part, (0, chunk_size - chunk_len), mode=pad_mode, value=0
@@ -293,15 +282,14 @@ def demix_generic(
                 batch_locations.append((i, chunk_len))
                 i += step
 
-                # Process batch if it's full or the end is reached
                 if len(batch_data) >= batch_size or i >= mix.shape[1]:
                    arr = torch.stack(batch_data, dim=0)
                     x = model(arr)
 
                     window = windowing_array.clone()
-                    if i - step == 0:  # First audio chunk, no fadein
+                    if i - step == 0:
                         window[:fade_size] = 1
-                    elif i >= mix.shape[1]:  # Last audio chunk, no fadeout
+                    elif i >= mix.shape[1]:
                         window[-fade_size:] = 1
 
                     for j, (start, seg_len) in enumerate(batch_locations):
@@ -310,27 +298,30 @@ def demix_generic(
                         )
                         counter[..., start : start + seg_len] += window[..., :seg_len]
 
-                    # Output progress
                     processed = min(i, mix.shape[1])
                    total = mix.shape[1]
-                    sys.stdout.write(json.dumps({"processing": {"processed": processed, "total": total}}, ensure_ascii=False) + '\n')
+                    sys.stdout.write(
+                        json.dumps(
+                            {"processing": {"processed": processed, "total": total}},
+                            ensure_ascii=False,
+                        )
+                        + "\n"
+                    )
                     sys.stdout.flush()
 
                     batch_data.clear()
                     batch_locations.clear()
 
-            # Compute final estimated sources
             estimated_sources = result / counter
            estimated_sources = estimated_sources.cpu().numpy()
             np.nan_to_num(estimated_sources, copy=False, nan=0.0)
 
-    # Remove padding
    if length_init > 2 * border and border > 0:
         estimated_sources = estimated_sources[..., border:-border]
 
-    # Return the result as a dictionary
     return {k: v for k, v in zip(instruments, estimated_sources)}
 
+
 def demix(
     config: ConfigDict,
     model: torch.nn.Module,
@@ -339,28 +330,17 @@ def demix(
     model_type: str,
     pbar: bool = False,
 ) -> Dict[str, np.ndarray]:
-    """
-    Unified function for audio source separation with support for multiple processing modes.
-    """
-    # Handle different model types
     if model_type == "vr":
         return demix_vr(config, model, mix, device, pbar)
     elif model_type == "mdxnet":
         return demix_mdxnet(config, model, mix, device, pbar)
     elif model_type == "htdemucs":
-        # HTDemucs uses its own processing
         return demix_demucs(config, model, mix, device, pbar)
     else:
-        # Generic processing for other models
         return demix_generic(config, model, mix, device, pbar)
 
 
 def prefer_target_instrument(config: ConfigDict) -> List[str]:
-    """
-    Return the list of target instruments based on the configuration.
-    If a specific target instrument is specified in the configuration,
-    it returns a list with that instrument. Otherwise, it returns the list of instruments.
-    """
     if config.training.get("target_instrument"):
         return [config.training.target_instrument]
     else:
@@ -370,21 +350,13 @@ def prefer_target_instrument(config: ConfigDict) -> List[str]:
 def prefer_target_instrument_test(
     config: ConfigDict, selected_instruments: Optional[List[str]] = None
 ) -> List[str]:
-    """
-    Return the list of target instruments based on the configuration and selected instruments.
-    If selected_instruments is specified, returns the intersection with available instruments.
-    Otherwise, if a target instrument is specified, returns it, else returns all instruments.
-    """
     available_instruments = config.training.instruments
 
     if selected_instruments is not None:
-        # Return only selected instruments that are available
         return [
            instr for instr in selected_instruments if instr in available_instruments
         ]
     elif config.training.get("target_instrument"):
-        # Default behavior if no selection - return target instrument
         return [config.training.target_instrument]
     else:
-        # If no target and no selection, return all instruments
-        return available_instruments
+        return available_instruments
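
A standalone sketch of the overlap-add window that `_getWindowingArray` builds (the same construction as above, shown with a toy size so the fade shape is visible):

    import torch

    def get_windowing_array(window_size: int, fade_size: int) -> torch.Tensor:
        # Ones with a linear fade-in at the start and a fade-out at the end.
        window = torch.ones(window_size)
        window[-fade_size:] = torch.linspace(1, 0, fade_size)
        window[:fade_size] = torch.linspace(0, 1, fade_size)
        return window

    print(get_windowing_array(8, 2))
    # tensor([0., 1., 1., 1., 1., 1., 1., 0.])
    # Overlapping chunks weighted by this window cross-fade smoothly; the
    # accumulated "counter" of window values then normalizes the overlap-add
    # result, exactly as in demix_demucs and demix_generic above.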
 
mvsepless/model_manager.py CHANGED
@@ -1,609 +1,682 @@
1
- import os
2
- import sys
3
- import json
4
- import yaml
5
- from tabulate import tabulate
6
- import shutil
7
- from tqdm import tqdm
8
- import urllib.request
9
- import gdown
10
- import requests
11
- import zipfile
12
- import tempfile
13
- import secrets
14
- import string
15
- import argparse
16
- from typing import Dict, Any
17
- script_dir = os.path.dirname(os.path.abspath(__file__))
18
- if not __package__:
19
- from downloader import dw_file
20
- else:
21
- from .downloader import dw_file
22
-
23
- def generate_secure_random(length=10):
24
- """Генерирует криптографически безопасную случайную строку"""
25
- characters = string.ascii_letters + string.digits
26
- return ''.join(secrets.choice(characters) for _ in range(length))
27
-
28
- class MvseplessModelManager:
29
- def __init__(
30
- self,
31
- models_info_path=os.path.join(script_dir, "models.json"),
32
- cache_dir=os.path.join(script_dir, "mvsepless_models_cache"),
33
- ):
34
- self.models_cache_dir = cache_dir
35
- self.models_info_path = models_info_path
36
- with open(self.models_info_path, "r", encoding="utf-8") as f:
37
- models_info = json.load(f)
38
- self.models_info = models_info
39
-
40
- def get_mt(self):
41
- return list(self.models_info.keys())
42
-
43
- def get_mn(self, model_type):
44
- try:
45
- mt = self.models_info.get(model_type, None)
46
- if mt:
47
- return list(self.models_info[model_type].keys())
48
- return []
49
- except (KeyError, TypeError):
50
- return []
51
-
52
- def get_stems(self, model_type, model_name):
53
- try:
54
- mt = self.models_info.get(model_type, None)
55
- if mt:
56
- mn = self.models_info[model_type].get(model_name, None)
57
- if mn and "stems" in self.models_info[model_type][model_name]:
58
- return self.models_info[model_type][model_name]["stems"]
59
- return []
60
- except (KeyError, TypeError):
61
- return []
62
-
63
- def get_id(self, model_type, model_name):
64
- try:
65
- mt = self.models_info.get(model_type, None)
66
- if mt:
67
- mn = self.models_info[model_type].get(model_name, None)
68
- if mn and "id" in self.models_info[model_type][model_name]:
69
- return self.models_info[model_type][model_name]["id"]
70
- return 0
71
- except (KeyError, TypeError):
72
- return 0
73
-
74
- def get_tgt_inst(self, model_type, model_name):
75
- try:
76
- mt = self.models_info.get(model_type, None)
77
- if mt:
78
- mn = self.models_info[model_type].get(model_name, None)
79
- if mn and "target_instrument" in self.models_info[model_type][model_name]:
80
- return self.models_info[model_type][model_name]["target_instrument"]
81
- return None
82
- except (KeyError, TypeError):
83
- return None
84
-
85
- def display_models_info(self, filter: str = None):
86
- # Собираем данные для таблицы
87
- table_data = []
88
- headers = [
89
- "Тип модели",
90
- "ID",
91
- "Имя модели",
92
- "Стемы",
93
- "Целевой инструмент",
94
- ]
95
-
96
- for model_type, models in self.models_info.items():
97
- for model_name, model_info in models.items():
98
- try:
99
- stems_list = model_info.get("stems", [])
100
- id = model_info.get("id", "н/д")
101
- # Применяем фильтр (регистронезависимо)
102
- if filter:
103
- filter_lower = filter.lower()
104
- if not any(filter_lower == s.lower() for s in stems_list):
105
- continue
106
-
107
- # Подготавливаем данные для строки таблицы
108
- row = [
109
- model_type,
110
- id,
111
- model_name,
112
- ", ".join(stems_list) or "н/д",
113
- model_info.get("target_instrument", "н/д"),
114
- ]
115
- table_data.append(row)
116
- except (KeyError, TypeError, AttributeError) as e:
117
- print(f"Ошибка при обработке модели {model_type}/{model_name}: {e}")
118
- continue
119
-
120
- # Выводим результат
121
- if table_data:
122
- print(tabulate(table_data, headers=headers, tablefmt="grid"))
123
- else:
124
- print("Нет моделей, которые содержат указанный стем")
125
-
126
- def download_model(
127
- self, model_paths, model_name, model_type, ckpt_url, conf_url
128
- ):
129
- model_dir = os.path.join(model_paths, model_type)
130
- os.makedirs(model_dir, exist_ok=True)
131
-
132
- config_path = None
133
- checkpoint_path = None
134
-
135
- if model_type == "mel_band_roformer":
136
- config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
137
- checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt")
138
-
139
- elif model_type == "vr":
140
- config_path = os.path.join(model_dir, f"{model_name}.yaml")
141
- checkpoint_path = os.path.join(model_dir, f"{model_name}.pth")
142
-
143
- elif model_type == "mdxnet":
144
- config_path = os.path.join(model_dir, f"{model_name}.yaml")
145
- checkpoint_path = os.path.join(model_dir, f"{model_name}.onnx")
146
-
147
- elif model_type == "bs_roformer":
148
- config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
149
- checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt")
150
-
151
- elif model_type == "mdx23c":
152
- config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
153
- checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt")
154
-
155
- elif model_type == "scnet":
156
- config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
157
- checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt")
158
-
159
- elif model_type == "bandit":
160
- config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
161
- checkpoint_path = os.path.join(model_dir, f"{model_name}.chpt")
162
-
163
- elif model_type == "bandit_v2":
164
- config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
165
- checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt")
166
-
167
- elif model_type == "htdemucs":
168
- config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
169
- checkpoint_path = os.path.join(model_dir, f"{model_name}.th")
170
-
171
- else:
172
- raise ValueError(
173
- f"{self.I18N_helper.t('error_unsupported_model_type')}: {model_type}"
174
- )
175
-
176
- # Sanity check: both paths must have been set
177
- if config_path is None or checkpoint_path is None:
178
- raise RuntimeError()
179
-
180
- # If both files already exist, skip the download
181
- if os.path.exists(checkpoint_path) and os.path.exists(config_path):
182
- if os.path.getsize(checkpoint_path) == 0 or os.path.getsize(config_path) == 0:
183
- for local_path, url_model in [
184
- (checkpoint_path, ckpt_url),
185
- (config_path, conf_url),
186
- ]:
187
- if not os.path.exists(local_path) or os.path.getsize(local_path) == 0:
188
-
189
- dw_file(url_model, local_path)
190
- else:
191
- pass
192
- else:
193
- for local_path, url_model in [
194
- (checkpoint_path, ckpt_url),
195
- (config_path, conf_url),
196
- ]:
197
- if not os.path.exists(local_path):
198
-
199
- dw_file(url_model, local_path)
200
-
201
- return config_path, checkpoint_path
202
-
203
- def conf_editor(self, config_path, mdx_denoise, vr_aggr, model_type):
204
-
205
- class IndentDumper(yaml.Dumper):
206
- def increase_indent(self, flow=False, indentless=False):
207
- return super(IndentDumper, self).increase_indent(flow, False)
208
-
209
- def tuple_constructor(loader, node):
210
- # Load the sequence of values from the YAML node
211
- values = loader.construct_sequence(node)
212
- # Return a tuple constructed from the sequence
213
- return tuple(values)
214
-
215
- # Register the constructor with PyYAML
216
- yaml.SafeLoader.add_constructor(
217
- "tag:yaml.org,2002:python/tuple", tuple_constructor
218
- )
219
-
220
- def conf_edit(config_path, mdx_denoise, vr_aggr, model_type):
221
- with open(config_path, "r") as f:
222
- data = yaml.load(f, Loader=yaml.SafeLoader)
223
-
224
- # handle cases where 'use_amp' is missing from config:
225
- if "use_amp" not in data.keys():
226
- data["training"]["use_amp"] = True
227
-
228
- if model_type != "vr":
229
- if data["inference"]["num_overlap"] != 2:
230
- data["inference"]["num_overlap"] = 2
231
-
232
- if data["inference"]["batch_size"] != 1:
233
- data["inference"]["batch_size"] = 1
234
-
235
- if model_type == "mdxnet":
236
- data["inference"]["denoise"] = mdx_denoise
237
-
238
- elif model_type == "vr":
239
- data["inference"]["aggression"] = vr_aggr
240
-
241
- with open(config_path, "w") as f:
242
- yaml.dump(
243
- data,
244
- f,
245
- default_flow_style=False,
246
- sort_keys=False,
247
- Dumper=IndentDumper,
248
- allow_unicode=True,
249
- )
250
-
251
- conf_edit(config_path, mdx_denoise, vr_aggr, model_type)
252
-
253
- class VbachModelManager:
254
- def __init__(self):
255
- self.rmvpe_path = os.path.join(script_dir, "predictors", "rmvpe.pt")
256
- self.fcpe_path = os.path.join(script_dir, "predictors", "fcpe.pt")
257
- self.custom_fairseq_huberts_dir = os.path.join(script_dir, "custom_fairseq_embedders")
258
- self.custom_transformers_huberts_dir = os.path.join(script_dir, "custom_transformers_embedders")
259
- self.huberts_fairseq_dict = {
260
- "hubert_base": {
261
- "url": "https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/hubert_base.pt",
262
- "local_path": os.path.join(self.custom_fairseq_huberts_dir, "hubert_base.pt")
263
- },
264
- "contentvec_base": {
265
- "url": "https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/contentvec_base.pt",
266
- "local_path": os.path.join(self.custom_fairseq_huberts_dir, "contentvec_base.pt")
267
- },
268
- "korean_hubert_base": {
269
- "url": "https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/korean_hubert_base.pt",
270
- "local_path": os.path.join(self.custom_fairseq_huberts_dir, "korean_hubert_base.pt")
271
- },
272
- "chinese_hubert_base": {
273
- "url": "https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/chinese_hubert_base.pt",
274
- "local_path": os.path.join(self.custom_fairseq_huberts_dir, "chinese_hubert_base.pt")
275
- },
276
- "portuguese_hubert_base": {
277
- "url": "https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/portuguese_hubert_base.pt",
278
- "local_path": os.path.join(self.custom_fairseq_huberts_dir, "portuguese_hubert_base.pt")
279
- },
280
- "japanese_hubert_base": {
281
- "url": "https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/japanese_hubert_base.pt",
282
- "local_path": os.path.join(self.custom_fairseq_huberts_dir, "japanese_hubert_base.pt")
283
- }
284
- }
285
- self.huberts_transformers_dict = {
286
- "contentvec": {
287
- "base_dir": os.path.join(self.custom_transformers_huberts_dir, "contentvec"),
288
- "url_bin": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/contentvec/pytorch_model.bin",
289
- "url_json": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/contentvec/config.json",
290
- "local_bin": os.path.join(self.custom_transformers_huberts_dir, "contentvec", "pytorch_model.bin"),
291
- "local_json": os.path.join(self.custom_transformers_huberts_dir, "contentvec", "config.json")
292
- },
293
- "spin": {
294
- "base_dir": os.path.join(self.custom_transformers_huberts_dir, "spin"),
295
- "url_bin": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/spin/pytorch_model.bin",
296
- "url_json": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/spin/config.json",
297
- "local_bin": os.path.join(self.custom_transformers_huberts_dir, "spin", "pytorch_model.bin"),
298
- "local_json": os.path.join(self.custom_transformers_huberts_dir, "spin", "config.json")
299
- },
300
- "spin-v2": {
301
- "base_dir": os.path.join(self.custom_transformers_huberts_dir, "spinv2"),
302
- "url_bin": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/spin-v2/pytorch_model.bin",
303
- "url_json": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/spin-v2/config.json",
304
- "local_bin": os.path.join(self.custom_transformers_huberts_dir, "spinv2", "pytorch_model.bin"),
305
- "local_json": os.path.join(self.custom_transformers_huberts_dir, "spinv2", "config.json")
306
- },
307
- "chinese-hubert-base": {
308
- "base_dir": os.path.join(self.custom_transformers_huberts_dir, "chinese_hubert_base"),
309
- "url_bin": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/chinese_hubert_base/pytorch_model.bin",
310
- "url_json": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/chinese_hubert_base/config.json",
311
- "local_bin": os.path.join(self.custom_transformers_huberts_dir, "chinese_hubert_base", "pytorch_model.bin"),
312
- "local_json": os.path.join(self.custom_transformers_huberts_dir, "chinese_hubert_base", "config.json")
313
- },
314
- "japanese-hubert-base": {
315
- "base_dir": os.path.join(self.custom_transformers_huberts_dir, "japanese_hubert_base"),
316
- "url_bin": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/japanese_hubert_base/pytorch_model.bin",
317
- "url_json": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/japanese_hubert_base/config.json",
318
- "local_bin": os.path.join(self.custom_transformers_huberts_dir, "japanese_hubert_base", "pytorch_model.bin"),
319
- "local_json": os.path.join(self.custom_transformers_huberts_dir, "japanese_hubert_base", "config.json")
320
- },
321
- "korean-hubert-base": {
322
- "base_dir": os.path.join(self.custom_transformers_huberts_dir, "korean_hubert_base"),
323
- "url_bin": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/korean_hubert_base/pytorch_model.bin",
324
- "url_json": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/korean_hubert_base/config.json",
325
- "local_bin": os.path.join(self.custom_transformers_huberts_dir, "korean_hubert_base", "pytorch_model.bin"),
326
- "local_json": os.path.join(self.custom_transformers_huberts_dir, "korean_hubert_base", "config.json")
327
- }
328
- }
329
- self.requirements = [["https://huggingface.co/Politrees/RVC_resources/resolve/main/predictors/rmvpe.pt", self.rmvpe_path], ["https://huggingface.co/Politrees/RVC_resources/resolve/main/predictors/fcpe.pt", self.fcpe_path]]
330
- self.voicemodels_dir = os.path.join(script_dir, "vbach_models_cache")
331
- os.makedirs(self.voicemodels_dir, exist_ok=True)
332
- self.voicemodels_info = os.path.join(self.voicemodels_dir, "vbach_models.json")
333
- self.voicemodels: Dict[str, Dict[str, str]] = {}
334
- self.download_requirements()
335
- self.check_hubert("hubert_base")
336
- self.check_and_load()
337
- pass
338
-
339
- def check_hubert(self, embedder_name):
340
- if embedder_name in self.huberts_fairseq_dict:
341
- if not os.path.exists(self.huberts_fairseq_dict[embedder_name]["local_path"]):
342
- dw_file(self.huberts_fairseq_dict[embedder_name]["url"], self.huberts_fairseq_dict[embedder_name]["local_path"])
343
- return self.huberts_fairseq_dict[embedder_name]["local_path"]
344
- else:
345
- return None
346
-
347
- def check_hubert_transformers(self, embedder_name):
348
- if embedder_name in self.huberts_transformers_dict:
349
- os.makedirs(self.huberts_transformers_dict[embedder_name]["base_dir"], exist_ok=True)
350
- if not os.path.exists(self.huberts_transformers_dict[embedder_name]["local_bin"]) or not os.path.exists(self.huberts_transformers_dict[embedder_name]["local_json"]):
351
- dw_file(self.huberts_transformers_dict[embedder_name]["url_bin"], self.huberts_transformers_dict[embedder_name]["local_bin"])
352
- dw_file(self.huberts_transformers_dict[embedder_name]["url_json"], self.huberts_transformers_dict[embedder_name]["local_json"])
353
- return self.huberts_transformers_dict[embedder_name]["base_dir"]
354
- else:
355
- return None
356
-
357
- def write_voicemodels_info(self):
358
- with open(self.voicemodels_info, "w") as f:
359
- json.dump(self.voicemodels, f, indent=4)
360
-
361
- def load_voicemodels_info(self):
362
- with open(self.voicemodels_info, "r") as f:
363
- return json.load(f)
364
-
365
- def add_voice_model(
366
- self,
367
- name,
368
- pth_path,
369
- index_path,
370
- ):
371
- self.voicemodels[name] = {"pth": pth_path, "index": index_path}
372
- self.write_voicemodels_info()
373
-
374
- def del_voice_model(
375
- self, name
376
- ):
377
- if name in self.parse_voice_models():
378
- pth = self.voicemodels[name].get("pth", None)
379
- index = self.voicemodels[name].get("index", None)
380
- if index:
381
- os.remove(index)
382
- if pth:
383
- os.remove(pth)
384
- del self.voicemodels[name]
385
- self.write_voicemodels_info()
386
- return f"Модель {name} удалена"
387
- else:
388
- return f"Модель не была удалена, как так её не существует"
389
-
390
- def parse_voice_models(self):
391
- list_models = list(self.voicemodels.keys())
392
- return list_models
393
-
394
- def parse_pth_and_index(self, name):
395
- pth = self.voicemodels[name].get("pth", None)
396
- index = self.voicemodels[name].get("index", None)
397
- return pth, index
398
-
399
- def check_and_load(self):
400
- if os.path.exists(self.voicemodels_info):
401
- self.voicemodels = self.load_voicemodels_info()
402
- else:
403
- self.write_voicemodels_info()
404
-
405
- def clear_voicemodels_info(self):
406
- self.voicemodels: Dict[str, Dict[str, str]] = {}
407
- self.write_voicemodels_info()
408
-
409
- def download_requirements(self):
410
- for url, file in self.requirements:
411
- if not os.path.exists(file):
412
- dw_file(url, file)
413
-
414
- def download_voice_model_file(self, url, zip_name):
415
- try:
416
- if "drive.google.com" in url:
417
- self.download_from_google_drive(url, zip_name)
418
- elif "pixeldrain.com" in url:
419
- self.download_from_pixeldrain(url, zip_name)
420
- elif "disk.yandex.ru" in url or "yadi.sk" in url:
421
- self.download_from_yandex(url, zip_name)
422
- else:
423
- dw_file(url, zip_name)
424
- except Exception as e:
425
- print(e)
426
-
427
- def download_from_google_drive(self, url, zip_name):
428
- file_id = (
429
- url.split("file/d/")[1].split("/")[0]
430
- if "file/d/" in url
431
- else url.split("id=")[1].split("&")[0]
432
- )
433
- gdown.download(id=file_id, output=str(zip_name), quiet=False)
434
-
435
- def download_from_pixeldrain(self, url, zip_name):
436
- file_id = url.split("pixeldrain.com/u/")[1]
437
- response = requests.get(f"https://pixeldrain.com/api/file/{file_id}")
438
- with open(zip_name, "wb") as f:
439
- f.write(response.content)
440
-
441
- def download_from_yandex(self, url, zip_name):
442
- yandex_public_key = f"download?public_key={url}"
443
- yandex_api_url = f"https://cloud-api.yandex.net/v1/disk/public/resources/{yandex_public_key}"
444
- response = requests.get(yandex_api_url)
445
- if response.status_code == 200:
446
- download_link = response.json().get("href")
447
- urllib.request.urlretrieve(download_link, zip_name)
448
- else:
449
- print(response.status_code)
450
-
451
- def extract_zip(self, zip_name, model_name):
452
- model_dir = os.path.join(self.voicemodels_dir, f"{model_name}_{generate_secure_random(17)}")
453
- os.makedirs(model_dir, exist_ok=True)
454
- try:
455
- with zipfile.ZipFile(zip_name, "r") as zip_ref:
456
- zip_ref.extractall(model_dir)
457
- os.remove(zip_name)
458
-
459
- added_voice_models = []
460
-
461
- index_filepath, model_filepaths = None, []
462
- for root, _, files in os.walk(model_dir):
463
- for name in files:
464
- file_path = os.path.join(root, name)
465
- if name.endswith(".index") and os.stat(file_path).st_size > 1024 * 100:
466
- index_filepath = file_path
467
- if name.endswith(".pth") and os.stat(file_path).st_size > 1024 * 1024 * 20:
468
- model_filepaths.append(file_path)
469
-
470
- if len(model_filepaths) == 1:
471
- self.add_voice_model(model_name, model_filepaths[0], index_filepath)
472
- added_voice_models.append(model_name)
473
- else:
474
- for i, pth in enumerate(model_filepaths):
475
- self.add_voice_model(f"{model_name}_{i + 1}", pth, index_filepath)
476
- added_voice_models.append(f"{model_name}_{i + 1}")
477
- list_models_str = '\n'.join(added_voice_models)
478
- return f"Добавленные модели:\n{list_models_str}"
479
- except Exception as e:
480
- return f"Произошла ошибка при загрузке модели: {e}"
481
-
482
- def install_model_zip(self, zip, model_name, mode="url"):
483
- if model_name in self.parse_voice_models():
484
- print("Эта модель уже есть в списке установленных моделей. Она будут перезаписана")
485
- if mode == "url":
486
- with tempfile.TemporaryDirectory(prefix="vbach_temp_model", ignore_cleanup_errors=True) as tmp:
487
- zip_path = os.path.join(tmp, "model.zip")
488
- self.download_voice_model_file(zip, zip_path)
489
- status = self.extract_zip(zip_path, model_name)
490
- if mode == "local":
491
- status = self.extract_zip(zip, model_name)
492
- return status
493
-
494
- def install_model_files(self, index, pth, model_name, mode="url"):
495
- if model_name in self.parse_voice_models():
496
- print("Эта модель уже есть в списке установленных моделей. Она будут перезаписана")
497
- model_dir = os.path.join(self.voicemodels_dir, f"{model_name}_{generate_secure_random(17)}")
498
- os.makedirs(model_dir, exist_ok=True)
499
- local_index_path = None
500
- local_pth_path = None
501
- try:
502
- if mode == "url":
503
- if index:
504
- local_index_path = os.path.join(model_dir, "model.index")
505
- self.download_voice_model_file(index, local_index_path)
506
- if pth:
507
- local_pth_path = os.path.join(model_dir, "model.pth")
508
- self.download_voice_model_file(pth, local_pth_path)
509
-
510
- if mode == "local":
511
- if index:
512
- if os.path.exists(index):
513
- local_index_path = os.path.join(model_dir, os.path.basename(index))
514
- shutil.copy(index, local_index_path)
515
- if pth:
516
- if os.path.exists(pth):
517
- local_pth_path = os.path.join(model_dir, os.path.basename(pth))
518
- shutil.copy(pth, local_pth_path)
519
-
520
- self.add_voice_model(model_name, local_pth_path, local_index_path)
521
- return f"Модель {model_name} добавлена"
522
- except Exception as e:
523
- return f"Произошла ошибка при загрузке модели: {e}"
524
-
525
-
526
- if __name__ == "__main__":
527
- parser = argparse.ArgumentParser(description="Менеджер моделей")
528
- subparsers = parser.add_subparsers(title="subcommands", dest="command", required=True)
529
-
530
- # Mvsepless subcommand
531
- mvsepless_parser = subparsers.add_parser("mvsepless", help="Скачивание моделей в MVSepLess")
532
- mvsepless_parser.add_argument("--model_type", required=True, help="Тип модели")
533
- mvsepless_parser.add_argument("--model_name", required=True, help="Имя модели")
534
-
535
- # Vbach subcommand
536
- vbach_parser = subparsers.add_parser("vbach", help="Установка голосовых моделей в Vbach")
537
- vbach_subparsers = vbach_parser.add_subparsers(title="vbach_commands", dest="vbach_command", required=True)
538
-
539
- # Vbach install_local
540
- install_local_parser = vbach_subparsers.add_parser("install_local", help="Установка голосовой модели по локальным файлам")
541
- install_local_parser.add_argument("--model_name", required=True, help="Имя голосовой модели")
542
- install_local_parser.add_argument("--pth", required=True, help="Путь к *.pth файлу")
543
- install_local_parser.add_argument("--index", required=False, help="Путь к *.index файлу")
544
-
545
- # Vbach install_url_zip
546
- install_url_zip_parser = vbach_subparsers.add_parser("install_url_zip", help="Установка голосовой модели по URL (архив с файлами)")
547
- install_url_zip_parser.add_argument("--model_name", required=True, help="Имя голосовой модели")
548
- install_url_zip_parser.add_argument("--url", required=True, help="URL *.zip файла")
549
-
550
- # Vbach install_url_files
551
- install_url_files_parser = vbach_subparsers.add_parser("install_url_files", help="Установка голосовой модели по URL (отдельные файлы)")
552
- install_url_files_parser.add_argument("--model_name", required=True, help="Имя голосовой модели")
553
- install_url_files_parser.add_argument("--pth_url", required=True, help="URL *.pth файла")
554
- install_url_files_parser.add_argument("--index_url", required=False, help="URL *.index файла")
555
-
556
- # Vbach list
557
- list_parser = vbach_subparsers.add_parser("list", help="List installed voice models")
558
-
559
- args = parser.parse_args()
560
-
561
- if args.command == "mvsepless":
562
-
563
- _model_manager = MvseplessModelManager()
564
- info = _model_manager.models_info[args.model_type].get(args.model_name, None)
565
- if not info:
566
- raise ValueError(f"Модель {args.model_name} не найдена для типа {args.model_type}")
567
- conf, ckpt = _model_manager.download_model(
568
- _model_manager.models_cache_dir,
569
- args.model_name,
570
- args.model_type,
571
- info["checkpoint_url"],
572
- info["config_url"],
573
- )
574
-
575
- elif args.command == "vbach":
576
- model_manager = VbachModelManager()
577
-
578
- if args.vbach_command == "install_local":
579
- status = model_manager.install_model_files(
580
- args.index, args.pth, args.model_name, mode="local"
581
- )
582
- print(status)
583
-
584
- elif args.vbach_command == "install_url_zip":
585
- status = model_manager.install_model_zip(
586
- args.url, args.model_name, mode="url"
587
- )
588
- print(status)
589
-
590
- elif args.vbach_command == "install_url_files":
591
- status = model_manager.install_model_files(
592
- args.index_url, args.pth_url, args.model_name, mode="url"
593
- )
594
- print(status)
595
-
596
- elif args.vbach_command == "list":
597
- models = model_manager.parse_voice_models()
598
- if models:
599
- print("Установленные модели:")
600
- for model in models:
601
- print(f" - {model}")
602
- else:
603
- print("Нет установленных моделей")
604
-
605
-
606
-
607
-
608
-
609
-
1
+ import os
2
+ import sys
3
+ import json
4
+ import yaml
5
+ from tabulate import tabulate
6
+ import shutil
7
+ from tqdm import tqdm
8
+ import urllib.request
9
+ import gdown
10
+ import requests
11
+ import zipfile
12
+ import tempfile
13
+ import secrets
14
+ import string
15
+ import argparse
16
+ from typing import Dict, Any
17
+
18
+ script_dir = os.path.dirname(os.path.abspath(__file__))
19
+ if not __package__:
20
+ from downloader import dw_file
21
+ else:
22
+ from .downloader import dw_file
23
+
24
+
25
+ def generate_secure_random(length=10):
26
+ characters = string.ascii_letters + string.digits
27
+ return "".join(secrets.choice(characters) for _ in range(length))
28
+
29
+
30
+ class MvseplessModelManager:
31
+ def __init__(
32
+ self,
33
+ models_info_path=os.path.join(script_dir, "models.json"),
34
+ cache_dir=os.path.join(script_dir, "mvsepless_models_cache"),
35
+ ):
36
+ self.models_cache_dir = cache_dir
37
+ self.models_info_path = models_info_path
38
+ with open(self.models_info_path, "r", encoding="utf-8") as f:
39
+ models_info = json.load(f)
40
+ self.models_info = models_info
41
+
42
+ def get_mt(self):
43
+ return [mt for mt in self.models_info]
44
+
45
+ def get_mn(self, model_type):
46
+ return [mn for mn in self.models_info.get(model_type, [])]
47
+
48
+ def get_stems(self, model_type, model_name):
49
+ return [
50
+ stem
51
+ for stem in self.models_info.get(model_type, {})
52
+ .get(model_name, {})
53
+ .get("stems", [])
54
+ ]
55
+
56
+ def get_id(self, model_type, model_name):
57
+ return self.models_info.get(model_type).get(model_name).get("id", 0)
58
+
59
+ def get_tgt_inst(self, model_type, model_name):
60
+ return (
61
+ self.models_info.get(model_type, {})
62
+ .get(model_name, {})
63
+ .get("target_instrument", None)
64
+ )
65
+
66
+ def get_category(self, model_type, model_name):
67
+ return self.models_info.get(model_type).get(model_name).get("category", "")
68
+
69
+ def display_models_info(self, filter: str = None):
70
+ table_data = []
71
+ headers = [
72
+ "Тип модели",
73
+ "ID",
74
+ "Имя модели",
75
+ "Стемы",
76
+ "Целевой инструмент",
77
+ ]
78
+
79
+ for model_type, models in self.models_info.items():
80
+ for model_name, model_info in models.items():
81
+ try:
82
+ stems_list = model_info.get("stems", [])
83
+ id = model_info.get("id", "н/д")
84
+ if filter:
85
+ filter_lower = filter.lower()
86
+ if not any(filter_lower == s.lower() for s in stems_list):
87
+ continue
88
+
89
+ row = [
90
+ model_type,
91
+ id,
92
+ model_name,
93
+ ", ".join(stems_list) or "н",
94
+ model_info.get("target_instrument", "н/д"),
95
+ ]
96
+ table_data.append(row)
97
+ except (KeyError, TypeError, AttributeError) as e:
98
+ print(f"Ошибка при обработке модели {model_type}/{model_name}: {e}")
99
+ continue
100
+
101
+ if table_data:
102
+ print(tabulate(table_data, headers=headers, tablefmt="grid"))
103
+ else:
104
+ print("Нет моделей, которые содержат указанный стем")
105
+
106
+ def download_model(self, model_paths, model_name, model_type, ckpt_url, conf_url):
107
+ model_dir = os.path.join(model_paths, model_type)
108
+ os.makedirs(model_dir, exist_ok=True)
109
+
110
+ config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
111
+ checkpoint_path = os.path.join(
112
+ model_dir,
113
+ f"{model_name}.onnx" if model_type == "mdxnet" else f"{model_name}.ckpt",
114
+ )
115
+
116
+ if config_path is None or checkpoint_path is None:
117
+ raise RuntimeError("config/checkpoint path could not be resolved")
118
+
119
+ if os.path.exists(checkpoint_path) and os.path.exists(config_path):
120
+ if (
121
+ os.path.getsize(checkpoint_path) == 0
122
+ or os.path.getsize(config_path) == 0
123
+ ):
124
+ for local_path, url_model in [
125
+ (checkpoint_path, ckpt_url),
126
+ (config_path, conf_url),
127
+ ]:
128
+ if not os.path.exists(local_path) or os.path.getsize(local_path) == 0:
129
+
130
+ dw_file(url_model, local_path)
131
+ else:
132
+ pass
133
+ else:
134
+ for local_path, url_model in [
135
+ (checkpoint_path, ckpt_url),
136
+ (config_path, conf_url),
137
+ ]:
138
+ if not os.path.exists(local_path):
139
+
140
+ dw_file(url_model, local_path)
141
+
142
+ return config_path, checkpoint_path
143
+
144
+ def conf_editor(self, config_path, mdx_denoise, vr_aggr, model_type):
145
+
146
+ class IndentDumper(yaml.Dumper):
147
+ def increase_indent(self, flow=False, indentless=False):
148
+ return super(IndentDumper, self).increase_indent(flow, False)
149
+
150
+ def tuple_constructor(loader, node):
151
+ values = loader.construct_sequence(node)
152
+ return tuple(values)
153
+
154
+ yaml.SafeLoader.add_constructor(
155
+ "tag:yaml.org,2002:python/tuple", tuple_constructor
156
+ )
157
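# Editor's note: registering this constructor lets yaml.SafeLoader parse config
# entries tagged "!!python/tuple" (e.g. a hypothetical "num_bands: !!python/tuple [8, 8]");
# plain yaml.safe_load would raise a ConstructorError on such tags.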
+
158
+ def conf_edit(config_path, mdx_denoise, vr_aggr, model_type):
159
+ with open(config_path, "r") as f:
160
+ data = yaml.load(f, Loader=yaml.SafeLoader)
161
+
162
+ if "use_amp" not in data.keys():
163
+ data["training"]["use_amp"] = True
164
+
165
+ if model_type != "vr":
166
+ if data["inference"]["num_overlap"] != 2:
167
+ data["inference"]["num_overlap"] = 2
168
+
169
+ if data["inference"]["batch_size"] != 1:
170
+ data["inference"]["batch_size"] = 1
171
+
172
+ if model_type == "mdxnet":
173
+ data["inference"]["denoise"] = mdx_denoise
174
+
175
+ elif model_type == "vr":
176
+ data["inference"]["aggression"] = vr_aggr
177
+
178
+ with open(config_path, "w") as f:
179
+ yaml.dump(
180
+ data,
181
+ f,
182
+ default_flow_style=False,
183
+ sort_keys=False,
184
+ Dumper=IndentDumper,
185
+ allow_unicode=True,
186
+ )
187
+
188
+ conf_edit(config_path, mdx_denoise, vr_aggr, model_type)
189
+
190
+
191
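# --- Editor's note: a minimal usage sketch of MvseplessModelManager (not part of
# --- the commit); the model type/name and flag values below are illustrative only.
# mm = MvseplessModelManager()
# mm.display_models_info(filter="vocals")  # tabulate every model that outputs a "vocals" stem
# mt = "mel_band_roformer"
# mn = mm.get_mn(mt)[0]
# info = mm.models_info[mt][mn]
# conf, ckpt = mm.download_model(mm.models_cache_dir, mn, mt,
#                                info["checkpoint_url"], info["config_url"])
# mm.conf_editor(conf, mdx_denoise=False, vr_aggr=5, model_type=mt)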
+ class VbachModelManager:
192
+ def __init__(self):
193
+ self.rmvpe_path = os.path.join(script_dir, "predictors", "rmvpe.pt")
194
+ self.fcpe_path = os.path.join(script_dir, "predictors", "fcpe.pt")
195
+ self.custom_fairseq_huberts_dir = os.path.join(
196
+ script_dir, "custom_fairseq_embedders"
197
+ )
198
+ self.custom_transformers_huberts_dir = os.path.join(
199
+ script_dir, "custom_transformers_embedders"
200
+ )
201
+ self.huberts_fairseq_dict = {
202
+ "hubert_base": {
203
+ "url": "https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/hubert_base.pt",
204
+ "local_path": os.path.join(
205
+ self.custom_fairseq_huberts_dir, "hubert_base.pt"
206
+ ),
207
+ },
208
+ "contentvec_base": {
209
+ "url": "https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/contentvec_base.pt",
210
+ "local_path": os.path.join(
211
+ self.custom_fairseq_huberts_dir, "contentvec_base.pt"
212
+ ),
213
+ },
214
+ "korean_hubert_base": {
215
+ "url": "https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/korean_hubert_base.pt",
216
+ "local_path": os.path.join(
217
+ self.custom_fairseq_huberts_dir, "korean_hubert_base.pt"
218
+ ),
219
+ },
220
+ "chinese_hubert_base": {
221
+ "url": "https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/chinese_hubert_base.pt",
222
+ "local_path": os.path.join(
223
+ self.custom_fairseq_huberts_dir, "chinese_hubert_base.pt"
224
+ ),
225
+ },
226
+ "portuguese_hubert_base": {
227
+ "url": "https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/portuguese_hubert_base.pt",
228
+ "local_path": os.path.join(
229
+ self.custom_fairseq_huberts_dir, "portuguese_hubert_base.pt"
230
+ ),
231
+ },
232
+ "japanese_hubert_base": {
233
+ "url": "https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/japanese_hubert_base.pt",
234
+ "local_path": os.path.join(
235
+ self.custom_fairseq_huberts_dir, "japanese_hubert_base.pt"
236
+ ),
237
+ },
238
+ }
239
+ self.huberts_transformers_dict = {
240
+ "contentvec": {
241
+ "base_dir": os.path.join(
242
+ self.custom_transformers_huberts_dir, "contentvec"
243
+ ),
244
+ "url_bin": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/contentvec/pytorch_model.bin",
245
+ "url_json": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/contentvec/config.json",
246
+ "local_bin": os.path.join(
247
+ self.custom_transformers_huberts_dir,
248
+ "contentvec",
249
+ "pytorch_model.bin",
250
+ ),
251
+ "local_json": os.path.join(
252
+ self.custom_transformers_huberts_dir, "contentvec", "config.json"
253
+ ),
254
+ },
255
+ "spin": {
256
+ "base_dir": os.path.join(self.custom_transformers_huberts_dir, "spin"),
257
+ "url_bin": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/spin/pytorch_model.bin",
258
+ "url_json": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/spin/config.json",
259
+ "local_bin": os.path.join(
260
+ self.custom_transformers_huberts_dir, "spin", "pytorch_model.bin"
261
+ ),
262
+ "local_json": os.path.join(
263
+ self.custom_transformers_huberts_dir, "spin", "config.json"
264
+ ),
265
+ },
266
+ "spin-v2": {
267
+ "base_dir": os.path.join(
268
+ self.custom_transformers_huberts_dir, "spinv2"
269
+ ),
270
+ "url_bin": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/spin-v2/pytorch_model.bin",
271
+ "url_json": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/spin-v2/config.json",
272
+ "local_bin": os.path.join(
273
+ self.custom_transformers_huberts_dir, "spinv2", "pytorch_model.bin"
274
+ ),
275
+ "local_json": os.path.join(
276
+ self.custom_transformers_huberts_dir, "spinv2", "config.json"
277
+ ),
278
+ },
279
+ "chinese-hubert-base": {
280
+ "base_dir": os.path.join(
281
+ self.custom_transformers_huberts_dir, "chinese_hubert_base"
282
+ ),
283
+ "url_bin": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/chinese_hubert_base/pytorch_model.bin",
284
+ "url_json": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/chinese_hubert_base/config.json",
285
+ "local_bin": os.path.join(
286
+ self.custom_transformers_huberts_dir,
287
+ "chinese_hubert_base",
288
+ "pytorch_model.bin",
289
+ ),
290
+ "local_json": os.path.join(
291
+ self.custom_transformers_huberts_dir,
292
+ "chinese_hubert_base",
293
+ "config.json",
294
+ ),
295
+ },
296
+ "japanese-hubert-base": {
297
+ "base_dir": os.path.join(
298
+ self.custom_transformers_huberts_dir, "japanese_hubert_base"
299
+ ),
300
+ "url_bin": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/japanese_hubert_base/pytorch_model.bin",
301
+ "url_json": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/japanese_hubert_base/config.json",
302
+ "local_bin": os.path.join(
303
+ self.custom_transformers_huberts_dir,
304
+ "japanese_hubert_base",
305
+ "pytorch_model.bin",
306
+ ),
307
+ "local_json": os.path.join(
308
+ self.custom_transformers_huberts_dir,
309
+ "japanese_hubert_base",
310
+ "config.json",
311
+ ),
312
+ },
313
+ "korean-hubert-base": {
314
+ "base_dir": os.path.join(
315
+ self.custom_transformers_huberts_dir, "korean_hubert_base"
316
+ ),
317
+ "url_bin": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/korean_hubert_base/pytorch_model.bin",
318
+ "url_json": "https://huggingface.co/IAHispano/Applio/resolve/main/Resources/embedders/korean_hubert_base/config.json",
319
+ "local_bin": os.path.join(
320
+ self.custom_transformers_huberts_dir,
321
+ "korean_hubert_base",
322
+ "pytorch_model.bin",
323
+ ),
324
+ "local_json": os.path.join(
325
+ self.custom_transformers_huberts_dir,
326
+ "korean_hubert_base",
327
+ "config.json",
328
+ ),
329
+ },
330
+ }
331
+ self.requirements = [
332
+ [
333
+ "https://huggingface.co/Politrees/RVC_resources/resolve/main/predictors/rmvpe.pt",
334
+ self.rmvpe_path,
335
+ ],
336
+ [
337
+ "https://huggingface.co/Politrees/RVC_resources/resolve/main/predictors/fcpe.pt",
338
+ self.fcpe_path,
339
+ ],
340
+ ]
341
+ self.voicemodels_dir = os.path.join(script_dir, "vbach_models_cache")
342
+ os.makedirs(self.voicemodels_dir, exist_ok=True)
343
+ self.voicemodels_info = os.path.join(self.voicemodels_dir, "vbach_models.json")
344
+ self.voicemodels: Dict[str, Dict[str, str]] = {}
345
+ self.download_requirements()
346
+ self.check_hubert("hubert_base")
347
+ self.check_and_load()
348
+ pass
349
+
350
+ def check_hubert(self, embedder_name):
351
+ if embedder_name in self.huberts_fairseq_dict:
352
+ if not os.path.exists(
353
+ self.huberts_fairseq_dict[embedder_name]["local_path"]
354
+ ):
355
+ dw_file(
356
+ self.huberts_fairseq_dict[embedder_name]["url"],
357
+ self.huberts_fairseq_dict[embedder_name]["local_path"],
358
+ )
359
+ return self.huberts_fairseq_dict[embedder_name]["local_path"]
360
+ else:
361
+ return None
362
+
363
+ def check_hubert_transformers(self, embedder_name):
364
+ if embedder_name in self.huberts_transformers_dict:
365
+ os.makedirs(
366
+ self.huberts_transformers_dict[embedder_name]["base_dir"], exist_ok=True
367
+ )
368
+ if not os.path.exists(
369
+ self.huberts_transformers_dict[embedder_name]["local_bin"]
370
+ ) or not os.path.exists(
371
+ self.huberts_transformers_dict[embedder_name]["local_json"]
372
+ ):
373
+ dw_file(
374
+ self.huberts_transformers_dict[embedder_name]["url_bin"],
375
+ self.huberts_transformers_dict[embedder_name]["local_bin"],
376
+ )
377
+ dw_file(
378
+ self.huberts_transformers_dict[embedder_name]["url_json"],
379
+ self.huberts_transformers_dict[embedder_name]["local_json"],
380
+ )
381
+ return self.huberts_transformers_dict[embedder_name]["base_dir"]
382
+ else:
383
+ return None
384
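# Editor's note: a small sketch of embedder resolution via the two dicts built in
# __init__ (embedder keys come from those dicts; files are fetched on first use):
# vm = VbachModelManager()
# pt_path = vm.check_hubert("contentvec_base")         # fairseq-style -> path to a .pt file
# hf_dir = vm.check_hubert_transformers("contentvec")  # transformers-style -> model directory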
+
385
+ def write_voicemodels_info(self):
386
+ with open(self.voicemodels_info, "w") as f:
387
+ json.dump(self.voicemodels, f, indent=4)
388
+
389
+ def load_voicemodels_info(self):
390
+ with open(self.voicemodels_info, "r") as f:
391
+ return json.load(f)
392
+
393
+ def add_voice_model(
394
+ self,
395
+ name,
396
+ pth_path,
397
+ index_path,
398
+ ):
399
+ self.voicemodels[name] = {"pth": pth_path, "index": index_path}
400
+ self.write_voicemodels_info()
401
+
402
+ def del_voice_model(self, name):
403
+ if name in self.parse_voice_models():
404
+ pth = self.voicemodels[name].get("pth", None)
405
+ index = self.voicemodels[name].get("index", None)
406
+ if index:
407
+ os.remove(index)
408
+ if pth:
409
+ os.remove(pth)
410
+ del self.voicemodels[name]
411
+ self.write_voicemodels_info()
412
+ return f"Модель {name} удалена"
413
+ else:
414
+ return f"Модель не была удалена, как так её не существует"
415
+
416
+ def parse_voice_models(self):
417
+ list_models = list(self.voicemodels.keys())
418
+ return list_models
419
+
420
+ def parse_pth_and_index(self, name):
421
+ pth = self.voicemodels[name].get("pth", None)
422
+ index = self.voicemodels[name].get("index", None)
423
+ return pth, index
424
+
425
+ def check_and_load(self):
426
+ if os.path.exists(self.voicemodels_info):
427
+ self.voicemodels = self.load_voicemodels_info()
428
+ else:
429
+ self.write_voicemodels_info()
430
+
431
+ def clear_voicemodels_info(self):
432
+ self.voicemodels: Dict[str, Dict[str, str]] = {}
433
+ self.write_voicemodels_info()
434
+
435
+ def download_requirements(self):
436
+ for url, file in self.requirements:
437
+ if not os.path.exists(file):
438
+ dw_file(url, file)
439
+
440
+ def download_voice_model_file(self, url, zip_name):
441
+ try:
442
+ if "drive.google.com" in url:
443
+ self.download_from_google_drive(url, zip_name)
444
+ elif "pixeldrain.com" in url:
445
+ self.download_from_pixeldrain(url, zip_name)
446
+ elif "disk.yandex.ru" in url or "yadi.sk" in url:
447
+ self.download_from_yandex(url, zip_name)
448
+ else:
449
+ dw_file(url, zip_name)
450
+ except Exception as e:
451
+ print(e)
452
+
453
+ def download_from_google_drive(self, url, zip_name):
454
+ file_id = (
455
+ url.split("file/d/")[1].split("/")[0]
456
+ if "file/d/" in url
457
+ else url.split("id=")[1].split("&")[0]
458
+ )
459
+ gdown.download(id=file_id, output=str(zip_name), quiet=False)
460
+
461
+ def download_from_pixeldrain(self, url, zip_name):
462
+ file_id = url.split("pixeldrain.com/u/")[1]
463
+ response = requests.get(f"https://pixeldrain.com/api/file/{file_id}")
464
+ with open(zip_name, "wb") as f:
465
+ f.write(response.content)
466
+
467
+ def download_from_yandex(self, url, zip_name):
468
+ yandex_public_key = f"download?public_key={url}"
469
+ yandex_api_url = (
470
+ f"https://cloud-api.yandex.net/v1/disk/public/resources/{yandex_public_key}"
471
+ )
472
+ response = requests.get(yandex_api_url)
473
+ if response.status_code == 200:
474
+ download_link = response.json().get("href")
475
+ urllib.request.urlretrieve(download_link, zip_name)
476
+ else:
477
+ print(response.status_code)
478
+
479
+ def extract_zip(self, zip_name, model_name):
480
+ model_dir = os.path.join(
481
+ self.voicemodels_dir, f"{model_name}_{generate_secure_random(17)}"
482
+ )
483
+ os.makedirs(model_dir, exist_ok=True)
484
+ try:
485
+ with zipfile.ZipFile(zip_name, "r") as zip_ref:
486
+ zip_ref.extractall(model_dir)
487
+ os.remove(zip_name)
488
+
489
+ added_voice_models = []
490
+
491
+ index_filepath, model_filepaths = None, []
492
+ for root, _, files in os.walk(model_dir):
493
+ for name in files:
494
+ file_path = os.path.join(root, name)
495
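# Editor's note: the size checks below are heuristics -- .index files under
# ~100 KB and .pth files under ~20 MB are skipped as likely junk; this reading
# is an inference from the thresholds, not documented behaviour.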
+ if (
496
+ name.endswith(".index")
497
+ and os.stat(file_path).st_size > 1024 * 100
498
+ ):
499
+ index_filepath = file_path
500
+ if (
501
+ name.endswith(".pth")
502
+ and os.stat(file_path).st_size > 1024 * 1024 * 20
503
+ ):
504
+ model_filepaths.append(file_path)
505
+
506
+ if len(model_filepaths) == 1:
507
+ self.add_voice_model(model_name, model_filepaths[0], index_filepath)
508
+ added_voice_models.append(model_name)
509
+ else:
510
+ for i, pth in enumerate(model_filepaths):
511
+ self.add_voice_model(f"{model_name}_{i + 1}", pth, index_filepath)
512
+ added_voice_models.append(f"{model_name}_{i + 1}")
513
+ list_models_str = "\n".join(added_voice_models)
514
+ return f"Добавленные модели:\n{list_models_str}"
515
+ except Exception as e:
516
+ return f"Произошла ошибка при загрузке модели: {e}"
517
+
518
+ def install_model_zip(self, zip, model_name, mode="url"):
519
+ if model_name in self.parse_voice_models():
520
+ print(
521
+ "Эта модель уже есть в списке установленных моделей. Она будут перезаписана"
522
+ )
523
+ if mode == "url":
524
+ with tempfile.TemporaryDirectory(
525
+ prefix="vbach_temp_model", ignore_cleanup_errors=True
526
+ ) as tmp:
527
+ zip_path = os.path.join(tmp, "model.zip")
528
+ self.download_voice_model_file(zip, zip_path)
529
+ status = self.extract_zip(zip_path, model_name)
530
+ if mode == "local":
531
+ status = self.extract_zip(zip, model_name)
532
+ return status
533
+
534
+ def install_model_files(self, index, pth, model_name, mode="url"):
535
+ if model_name in self.parse_voice_models():
536
+ print(
537
+ "Эта модель уже есть в списке установленных моделей. Она будут перезаписана"
538
+ )
539
+ model_dir = os.path.join(
540
+ self.voicemodels_dir, f"{model_name}_{generate_secure_random(17)}"
541
+ )
542
+ os.makedirs(model_dir, exist_ok=True)
543
+ local_index_path = None
544
+ local_pth_path = None
545
+ try:
546
+ if mode == "url":
547
+ if index:
548
+ local_index_path = os.path.join(model_dir, "model.index")
549
+ self.download_voice_model_file(index, local_index_path)
550
+ if pth:
551
+ local_pth_path = os.path.join(model_dir, "model.pth")
552
+ self.download_voice_model_file(pth, local_pth_path)
553
+
554
+ if mode == "local":
555
+ if index:
556
+ if os.path.exists(index):
557
+ local_index_path = os.path.join(
558
+ model_dir, os.path.basename(index)
559
+ )
560
+ shutil.copy(index, local_index_path)
561
+ if pth:
562
+ if os.path.exists(pth):
563
+ local_pth_path = os.path.join(model_dir, os.path.basename(pth))
564
+ shutil.copy(pth, local_pth_path)
565
+
566
+ self.add_voice_model(model_name, local_pth_path, local_index_path)
567
+ return f"Модель {model_name} добавлена"
568
+ except Exception as e:
569
+ return f"Произошла ошибка при загрузке модели: {e}"
570
+
571
+
572
+ if __name__ == "__main__":
573
+ parser = argparse.ArgumentParser(description="Менеджер моделей")
574
+ subparsers = parser.add_subparsers(
575
+ title="subcommands", dest="command", required=True
576
+ )
577
+
578
+ mvsepless_parser = subparsers.add_parser(
579
+ "mvsepless", help="Скачивание моделей в MVSepLess"
580
+ )
581
+ mvsepless_parser.add_argument("--model_type", required=True, help="Тип модели")
582
+ mvsepless_parser.add_argument("--model_name", required=True, help="Имя модели")
583
+
584
+ vbach_parser = subparsers.add_parser(
585
+ "vbach", help="Установка голосовых моделей в Vbach"
586
+ )
587
+ vbach_subparsers = vbach_parser.add_subparsers(
588
+ title="vbach_commands", dest="vbach_command", required=True
589
+ )
590
+
591
+ install_local_parser = vbach_subparsers.add_parser(
592
+ "install_local", help="Установка голосовой модели по локальным файлам"
593
+ )
594
+ install_local_parser.add_argument(
595
+ "--model_name", required=True, help="Имя голосовой модели"
596
+ )
597
+ install_local_parser.add_argument("--pth", required=True, help="Путь к *.pth файлу")
598
+ install_local_parser.add_argument(
599
+ "--index", required=False, help="Путь к *.index файлу"
600
+ )
601
+
602
+ install_url_zip_parser = vbach_subparsers.add_parser(
603
+ "install_url_zip", help="Установка голосовой модели по URL (архив с файлами)"
604
+ )
605
+ install_url_zip_parser.add_argument(
606
+ "--model_name", required=True, help="Имя голосовой модели"
607
+ )
608
+ install_url_zip_parser.add_argument("--url", required=True, help="URL *.zip файла")
609
+
610
+ install_url_files_parser = vbach_subparsers.add_parser(
611
+ "install_url_files", help="Установка голосовой модели по URL (отдельные файлы)"
612
+ )
613
+ install_url_files_parser.add_argument(
614
+ "--model_name", required=True, help="Имя голосовой модели"
615
+ )
616
+ install_url_files_parser.add_argument(
617
+ "--pth_url", required=True, help="URL *.pth файла"
618
+ )
619
+ install_url_files_parser.add_argument(
620
+ "--index_url", required=False, help="URL *.index файла"
621
+ )
622
+
623
+ list_parser = vbach_subparsers.add_parser(
624
+ "list", help="Список установленных моделей"
625
+ )
626
+
627
+ remove_voice_model = vbach_subparsers.add_parser("remove", help="Удаление модели")
628
+ remove_voice_model.add_argument(
629
+ "--model_name", required=True, help="Имя голосовой модели"
630
+ )
631
+
632
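# --- Editor's note: example invocations of this CLI, mirroring the subparsers
# --- defined above (model names, paths and URLs are placeholders):
#   python mvsepless/model_manager.py mvsepless --model_type mel_band_roformer --model_name <model>
#   python mvsepless/model_manager.py vbach install_local --model_name MyVoice --pth model.pth --index model.index
#   python mvsepless/model_manager.py vbach install_url_zip --model_name MyVoice --url https://example.com/voice.zip
#   python mvsepless/model_manager.py vbach install_url_files --model_name MyVoice --pth_url <url> --index_url <url>
#   python mvsepless/model_manager.py vbach list
#   python mvsepless/model_manager.py vbach remove --model_name MyVoice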
+ args = parser.parse_args()
633
+
634
+ if args.command == "mvsepless":
635
+
636
+ _model_manager = MvseplessModelManager()
637
+ info = _model_manager.models_info.get(args.model_type, {}).get(args.model_name, None)
638
+ if not info:
639
+ raise ValueError(
640
+ f"Модель {args.model_name} не найдена для типа {args.model_type}"
641
+ )
642
+ conf, ckpt = _model_manager.download_model(
643
+ _model_manager.models_cache_dir,
644
+ args.model_name,
645
+ args.model_type,
646
+ info["checkpoint_url"],
647
+ info["config_url"],
648
+ )
649
+
650
+ elif args.command == "vbach":
651
+ model_manager = VbachModelManager()
652
+
653
+ if args.vbach_command == "install_local":
654
+ status = model_manager.install_model_files(
655
+ args.index, args.pth, args.model_name, mode="local"
656
+ )
657
+ print(status)
658
+
659
+ elif args.vbach_command == "install_url_zip":
660
+ status = model_manager.install_model_zip(
661
+ args.url, args.model_name, mode="url"
662
+ )
663
+ print(status)
664
+
665
+ elif args.vbach_command == "install_url_files":
666
+ status = model_manager.install_model_files(
667
+ args.index_url, args.pth_url, args.model_name, mode="url"
668
+ )
669
+ print(status)
670
+
671
+ elif args.vbach_command == "list":
672
+ models = model_manager.parse_voice_models()
673
+ if models:
674
+ print("Установленные модели:")
675
+ for model in models:
676
+ print(f" - {model}")
677
+ else:
678
+ print("Нет установленных моделей")
679
+
680
+ elif args.vbach_command == "remove":
681
+ status = model_manager.del_voice_model(args.model_name)
682
+ print(status)
mvsepless/models.json CHANGED
The diff for this file is too large to render. See raw diff
 
mvsepless/models/bandit/core/__init__.py CHANGED
@@ -1,691 +1,669 @@
1
- import os.path
2
- from collections import defaultdict
3
- from itertools import chain, combinations
4
- from typing import Any, Dict, Iterator, Mapping, Optional, Tuple, Type, TypedDict
5
-
6
- import pytorch_lightning as pl
7
- import torch
8
- import torchaudio as ta
9
- import torchmetrics as tm
10
- from asteroid import losses as asteroid_losses
11
-
12
- # from deepspeed.ops.adam import DeepSpeedCPUAdam
13
- # from geoopt import optim as gooptim
14
- from pytorch_lightning.utilities.types import STEP_OUTPUT
15
- from torch import nn, optim
16
- from torch.optim import lr_scheduler
17
- from torch.optim.lr_scheduler import LRScheduler
18
-
19
- from . import loss, metrics as metrics_, model
20
- from .data._types import BatchedDataDict
21
- from .data.augmentation import BaseAugmentor, StemAugmentor
22
- from .utils import audio as audio_
23
- from .utils.audio import BaseFader
24
-
25
- # from pandas.io.json._normalize import nested_to_record
26
-
27
- ConfigDict = TypedDict("ConfigDict", {"name": str, "kwargs": Dict[str, Any]})
28
-
29
-
30
- class SchedulerConfigDict(ConfigDict):
31
- monitor: str
32
-
33
-
34
- OptimizerSchedulerConfigDict = TypedDict(
35
- "OptimizerSchedulerConfigDict",
36
- {"optimizer": ConfigDict, "scheduler": SchedulerConfigDict},
37
- total=False,
38
- )
39
-
40
-
41
- class LRSchedulerReturnDict(TypedDict, total=False):
42
- scheduler: LRScheduler
43
- monitor: str
44
-
45
-
46
- class ConfigureOptimizerReturnDict(TypedDict, total=False):
47
- optimizer: torch.optim.Optimizer
48
- lr_scheduler: LRSchedulerReturnDict
49
-
50
-
51
- OutputType = Dict[str, Any]
52
- MetricsType = Dict[str, torch.Tensor]
53
-
54
-
55
- def get_optimizer_class(name: str) -> Type[optim.Optimizer]:
56
-
57
- if name == "DeepSpeedCPUAdam":
58
- return DeepSpeedCPUAdam
59
-
60
- for module in [optim]:  # geoopt ("gooptim") import above is commented out
61
- if name in module.__dict__:
62
- return module.__dict__[name]
63
-
64
- raise NameError
65
-
66
-
67
- def parse_optimizer_config(
68
- config: OptimizerSchedulerConfigDict, parameters: Iterator[nn.Parameter]
69
- ) -> ConfigureOptimizerReturnDict:
70
- optim_class = get_optimizer_class(config["optimizer"]["name"])
71
- optimizer = optim_class(parameters, **config["optimizer"]["kwargs"])
72
-
73
- optim_dict: ConfigureOptimizerReturnDict = {
74
- "optimizer": optimizer,
75
- }
76
-
77
- if "scheduler" in config:
78
-
79
- lr_scheduler_class_ = config["scheduler"]["name"]
80
- lr_scheduler_class = lr_scheduler.__dict__[lr_scheduler_class_]
81
- lr_scheduler_dict: LRSchedulerReturnDict = {
82
- "scheduler": lr_scheduler_class(optimizer, **config["scheduler"]["kwargs"])
83
- }
84
-
85
- if lr_scheduler_class_ == "ReduceLROnPlateau":
86
- lr_scheduler_dict["monitor"] = config["scheduler"]["monitor"]
87
-
88
- optim_dict["lr_scheduler"] = lr_scheduler_dict
89
-
90
- return optim_dict
91
-
92
-
93
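# --- Editor's note: a sketch of the config shape parse_optimizer_config expects,
# --- inferred from the TypedDicts above (all concrete values are illustrative):
# optimizer_config = {
#     "optimizer": {"name": "Adam", "kwargs": {"lr": 1e-3}},
#     "scheduler": {
#         "name": "ReduceLROnPlateau",
#         "kwargs": {"factor": 0.5, "patience": 5},
#         "monitor": "val/loss",  # only consumed for ReduceLROnPlateau
#     },
# }
# optim_dict = parse_optimizer_config(optimizer_config, net.parameters())  # net: any nn.Module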
- def parse_model_config(config: ConfigDict) -> Any:
94
- name = config["name"]
95
-
96
- for module in [model]:
97
- if name in module.__dict__:
98
- return module.__dict__[name](**config["kwargs"])
99
-
100
- raise NameError
101
-
102
-
103
- _LEGACY_LOSS_NAMES = ["HybridL1Loss"]
104
-
105
-
106
- def _parse_legacy_loss_config(config: ConfigDict) -> nn.Module:
107
- name = config["name"]
108
-
109
- if name == "HybridL1Loss":
110
- return loss.TimeFreqL1Loss(**config["kwargs"])
111
-
112
- raise NameError
113
-
114
-
115
- def parse_loss_config(config: ConfigDict) -> nn.Module:
116
- name = config["name"]
117
-
118
- if name in _LEGACY_LOSS_NAMES:
119
- return _parse_legacy_loss_config(config)
120
-
121
- for module in [loss, nn.modules.loss, asteroid_losses]:
122
- if name in module.__dict__:
123
- # print(config["kwargs"])
124
- return module.__dict__[name](**config["kwargs"])
125
-
126
- raise NameError
127
-
128
-
129
- def get_metric(config: ConfigDict) -> tm.Metric:
130
- name = config["name"]
131
-
132
- for module in [tm, metrics_]:
133
- if name in module.__dict__:
134
- return module.__dict__[name](**config["kwargs"])
135
- raise NameError
136
-
137
-
138
- def parse_metric_config(config: Dict[str, ConfigDict]) -> tm.MetricCollection:
139
- metrics = {}
140
-
141
- for metric in config:
142
- metrics[metric] = get_metric(config[metric])
143
-
144
- return tm.MetricCollection(metrics)
145
-
146
-
147
- def parse_fader_config(config: ConfigDict) -> BaseFader:
148
- name = config["name"]
149
-
150
- for module in [audio_]:
151
- if name in module.__dict__:
152
- return module.__dict__[name](**config["kwargs"])
153
-
154
- raise NameError
155
-
156
-
157
- class LightningSystem(pl.LightningModule):
158
- _VOX_STEMS = ["speech", "vocals"]
159
- _BG_STEMS = ["background", "effects", "mne"]
160
-
161
- def __init__(
162
- self, config: Dict, loss_adjustment: float = 1.0, attach_fader: bool = False
163
- ) -> None:
164
- super().__init__()
165
- self.optimizer_config = config["optimizer"]
166
- self.model = parse_model_config(config["model"])
167
- self.loss = parse_loss_config(config["loss"])
168
- self.metrics = nn.ModuleDict(
169
- {
170
- stem: parse_metric_config(config["metrics"]["dev"])
171
- for stem in self.model.stems
172
- }
173
- )
174
-
175
- self.metrics.disallow_fsdp = True
176
-
177
- self.test_metrics = nn.ModuleDict(
178
- {
179
- stem: parse_metric_config(config["metrics"]["test"])
180
- for stem in self.model.stems
181
- }
182
- )
183
-
184
- self.test_metrics.disallow_fsdp = True
185
-
186
- self.fs = config["model"]["kwargs"]["fs"]
187
-
188
- self.fader_config = config["inference"]["fader"]
189
- if attach_fader:
190
- self.fader = parse_fader_config(config["inference"]["fader"])
191
- else:
192
- self.fader = None
193
-
194
- self.augmentation: Optional[BaseAugmentor]
195
- if config.get("augmentation", None) is not None:
196
- self.augmentation = StemAugmentor(**config["augmentation"])
197
- else:
198
- self.augmentation = None
199
-
200
- self.predict_output_path: Optional[str] = None
201
- self.loss_adjustment = loss_adjustment
202
-
203
- self.val_prefix = None
204
- self.test_prefix = None
205
-
206
- def configure_optimizers(self) -> Any:
207
- return parse_optimizer_config(
208
- self.optimizer_config, self.trainer.model.parameters()
209
- )
210
-
211
- def compute_loss(
212
- self, batch: BatchedDataDict, output: OutputType
213
- ) -> Dict[str, torch.Tensor]:
214
- return {"loss": self.loss(output, batch)}
215
-
216
- def update_metrics(
217
- self, batch: BatchedDataDict, output: OutputType, mode: str
218
- ) -> None:
219
-
220
- if mode == "test":
221
- metrics = self.test_metrics
222
- else:
223
- metrics = self.metrics
224
-
225
- for stem, metric in metrics.items():
226
-
227
- if stem == "mne:+":
228
- stem = "mne"
229
-
230
- # print(f"matching for {stem}")
231
- if mode == "train":
232
- metric.update(
233
- output["audio"][stem], # .cpu(),
234
- batch["audio"][stem], # .cpu()
235
- )
236
- else:
237
- if stem not in batch["audio"]:
238
- matched = False
239
- if stem in self._VOX_STEMS:
240
- for bstem in self._VOX_STEMS:
241
- if bstem in batch["audio"]:
242
- batch["audio"][stem] = batch["audio"][bstem]
243
- matched = True
244
- break
245
- elif stem in self._BG_STEMS:
246
- for bstem in self._BG_STEMS:
247
- if bstem in batch["audio"]:
248
- batch["audio"][stem] = batch["audio"][bstem]
249
- matched = True
250
- break
251
- else:
252
- matched = True
253
-
254
- # print(batch["audio"].keys())
255
-
256
- if matched:
257
- # print(f"matched {stem}!")
258
- if stem == "mne" and "mne" not in output["audio"]:
259
- output["audio"]["mne"] = (
260
- output["audio"]["music"] + output["audio"]["effects"]
261
- )
262
-
263
- metric.update(
264
- output["audio"][stem], # .cpu(),
265
- batch["audio"][stem], # .cpu(),
266
- )
267
-
268
-                     # print(metric.compute())
-
-     def compute_metrics(self, mode: str = "dev") -> Dict[str, torch.Tensor]:
-
-         if mode == "test":
-             metrics = self.test_metrics
-         else:
-             metrics = self.metrics
-
-         metric_dict = {}
-
-         for stem, metric in metrics.items():
-             md = metric.compute()
-             metric_dict.update({f"{stem}/{k}": v for k, v in md.items()})
-
-         self.log_dict(metric_dict, prog_bar=True, logger=False)
-
-         return metric_dict
-
-     def reset_metrics(self, test_mode: bool = False) -> None:
-
-         if test_mode:
-             metrics = self.test_metrics
-         else:
-             metrics = self.metrics
-
-         for _, metric in metrics.items():
-             metric.reset()
-
-     def forward(self, batch: BatchedDataDict) -> Any:
-         batch, output = self.model(batch)
-
-         return batch, output
-
-     def common_step(self, batch: BatchedDataDict, mode: str) -> Any:
-         batch, output = self.forward(batch)
-         # print(batch)
-         # print(output)
-         loss_dict = self.compute_loss(batch, output)
-
-         with torch.no_grad():
-             self.update_metrics(batch, output, mode=mode)
-
-         if mode == "train":
-             self.log("loss", loss_dict["loss"], prog_bar=True)
-
-         return output, loss_dict
-
-     def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]:
-
-         if self.augmentation is not None:
-             with torch.no_grad():
-                 batch = self.augmentation(batch)
-
-         _, loss_dict = self.common_step(batch, mode="train")
-
-         with torch.inference_mode():
-             self.log_dict_with_prefix(
-                 loss_dict, "train", batch_size=batch["audio"]["mixture"].shape[0]
-             )
-
-         loss_dict["loss"] *= self.loss_adjustment
-
-         return loss_dict
-
-     def on_train_batch_end(
-         self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int
-     ) -> None:
-
-         metric_dict = self.compute_metrics()
-         self.log_dict_with_prefix(metric_dict, "train")
-         self.reset_metrics()
-
-     def validation_step(
-         self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
-     ) -> Dict[str, Any]:
-
-         with torch.inference_mode():
-             curr_val_prefix = f"val{dataloader_idx}" if dataloader_idx > 0 else "val"
-
-             if curr_val_prefix != self.val_prefix:
-                 # print(f"Switching to validation dataloader {dataloader_idx}")
-                 if self.val_prefix is not None:
-                     self._on_validation_epoch_end()
-                 self.val_prefix = curr_val_prefix
-             _, loss_dict = self.common_step(batch, mode="val")
-
-             self.log_dict_with_prefix(
-                 loss_dict,
-                 self.val_prefix,
-                 batch_size=batch["audio"]["mixture"].shape[0],
-                 prog_bar=True,
-                 add_dataloader_idx=False,
-             )
-
-         return loss_dict
-
-     def on_validation_epoch_end(self) -> None:
-         self._on_validation_epoch_end()
-
-     def _on_validation_epoch_end(self) -> None:
-         metric_dict = self.compute_metrics()
-         self.log_dict_with_prefix(
-             metric_dict, self.val_prefix, prog_bar=True, add_dataloader_idx=False
-         )
-         # self.logger.save()
-         # print(self.val_prefix, "Validation metrics:", metric_dict)
-         self.reset_metrics()
-
-     def old_predtest_step(
-         self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
-     ) -> Tuple[BatchedDataDict, OutputType]:
-
-         audio_batch = batch["audio"]["mixture"]
-         track_batch = batch.get("track", ["" for _ in range(len(audio_batch))])
-
-         output_list_of_dicts = [
-             self.fader(audio[None, ...], lambda a: self.test_forward(a, track))
-             for audio, track in zip(audio_batch, track_batch)
-         ]
-
-         output_dict_of_lists = defaultdict(list)
-
-         for output_dict in output_list_of_dicts:
-             for stem, audio in output_dict.items():
-                 output_dict_of_lists[stem].append(audio)
-
-         output = {
-             "audio": {
-                 stem: torch.concat(output_list, dim=0)
-                 for stem, output_list in output_dict_of_lists.items()
-             }
-         }
-
-         return batch, output
-
-     def predtest_step(
-         self, batch: BatchedDataDict, batch_idx: int = -1, dataloader_idx: int = 0
-     ) -> Tuple[BatchedDataDict, OutputType]:
-
-         if getattr(self.model, "bypass_fader", False):
-             batch, output = self.model(batch)
-         else:
-             audio_batch = batch["audio"]["mixture"]
-             output = self.fader(
-                 audio_batch, lambda a: self.test_forward(a, "", batch=batch)
-             )
-
-         return batch, output
-
-     def test_forward(
-         self, audio: torch.Tensor, track: str = "", batch: BatchedDataDict = None
-     ) -> torch.Tensor:
-
-         if self.fader is None:
-             self.attach_fader()
-
-         cond = batch.get("condition", None)
-
-         if cond is not None and cond.shape[0] == 1:
-             cond = cond.repeat(audio.shape[0], 1)
-
-         _, output = self.forward(
-             {
-                 "audio": {"mixture": audio},
-                 "track": track,
-                 "condition": cond,
-             }
-         )  # TODO: support track properly
-
-         return output["audio"]
-
-     def on_test_epoch_start(self) -> None:
-         self.attach_fader(force_reattach=True)
-
-     def test_step(
-         self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
-     ) -> Any:
-         curr_test_prefix = f"test{dataloader_idx}"
-
-         # print(batch["audio"].keys())
-
-         if curr_test_prefix != self.test_prefix:
-             # print(f"Switching to test dataloader {dataloader_idx}")
-             if self.test_prefix is not None:
-                 self._on_test_epoch_end()
-             self.test_prefix = curr_test_prefix
-
-         with torch.inference_mode():
-             _, output = self.predtest_step(batch, batch_idx, dataloader_idx)
-             # print(output)
-             self.update_metrics(batch, output, mode="test")
-
-         return output
-
-     def on_test_epoch_end(self) -> None:
-         self._on_test_epoch_end()
-
-     def _on_test_epoch_end(self) -> None:
-         metric_dict = self.compute_metrics(mode="test")
-         self.log_dict_with_prefix(
-             metric_dict, self.test_prefix, prog_bar=True, add_dataloader_idx=False
-         )
-         # self.logger.save()
-         # print(self.test_prefix, "Test metrics:", metric_dict)
-         self.reset_metrics()
-
-     def predict_step(
-         self,
-         batch: BatchedDataDict,
-         batch_idx: int = 0,
-         dataloader_idx: int = 0,
-         include_track_name: Optional[bool] = None,
-         get_no_vox_combinations: bool = True,
-         get_residual: bool = False,
-         treat_batch_as_channels: bool = False,
-         fs: Optional[int] = None,
-     ) -> Any:
-         assert self.predict_output_path is not None
-
-         batch_size = batch["audio"]["mixture"].shape[0]
-
-         if include_track_name is None:
-             include_track_name = batch_size > 1
-
-         with torch.inference_mode():
-             batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
-             print("Pred test finished...")
-         torch.cuda.empty_cache()
-         metric_dict = {}
-
-         if get_residual:
-             mixture = batch["audio"]["mixture"]
-             extracted = sum([output["audio"][stem] for stem in output["audio"]])
-             residual = mixture - extracted
-             print(extracted.shape, mixture.shape, residual.shape)
-
-             output["audio"]["residual"] = residual
-
-         if get_no_vox_combinations:
-             no_vox_stems = [
-                 stem for stem in output["audio"] if stem not in self._VOX_STEMS
-             ]
-             no_vox_combinations = chain.from_iterable(
-                 combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1)
-             )
-
-             for combination in no_vox_combinations:
-                 combination_ = list(combination)
-                 output["audio"]["+".join(combination_)] = sum(
-                     [output["audio"][stem] for stem in combination_]
-                 )
-
-         if treat_batch_as_channels:
-             for stem in output["audio"]:
-                 output["audio"][stem] = output["audio"][stem].reshape(
-                     1, -1, output["audio"][stem].shape[-1]
-                 )
-             batch_size = 1
-
-         for b in range(batch_size):
-             print("!!", b)
-             for stem in output["audio"]:
-                 print(f"Saving audio for {stem} to {self.predict_output_path}")
-                 track_name = batch["track"][b].split("/")[-1]
-
-                 if batch.get("audio", {}).get(stem, None) is not None:
-                     self.test_metrics[stem].reset()
-                     metrics = self.test_metrics[stem](
-                         batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...]
-                     )
-                     snr = metrics["snr"]
-                     sisnr = metrics["sisnr"]
-                     sdr = metrics["sdr"]
-                     metric_dict[stem] = metrics
-                     print(
-                         track_name,
-                         f"snr={snr:2.2f} dB",
-                         f"sisnr={sisnr:2.2f}",
-                         f"sdr={sdr:2.2f} dB",
-                     )
-                     filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
-                 else:
-                     filename = f"{stem}.wav"
-
-                 if include_track_name:
-                     output_dir = os.path.join(self.predict_output_path, track_name)
-                 else:
-                     output_dir = self.predict_output_path
-
-                 os.makedirs(output_dir, exist_ok=True)
-
-                 if fs is None:
-                     fs = self.fs
-
-                 ta.save(
-                     os.path.join(output_dir, filename),
-                     output["audio"][stem][b, ...].cpu(),
-                     fs,
-                 )
-
-         return metric_dict
-
-     def get_stems(
-         self,
-         batch: BatchedDataDict,
-         batch_idx: int = 0,
-         dataloader_idx: int = 0,
-         include_track_name: Optional[bool] = None,
-         get_no_vox_combinations: bool = True,
-         get_residual: bool = False,
-         treat_batch_as_channels: bool = False,
-         fs: Optional[int] = None,
-     ) -> Any:
-         assert self.predict_output_path is not None
-
-         batch_size = batch["audio"]["mixture"].shape[0]
-
-         if include_track_name is None:
-             include_track_name = batch_size > 1
-
-         with torch.inference_mode():
-             batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
-         torch.cuda.empty_cache()
-         metric_dict = {}
-
-         if get_residual:
-             mixture = batch["audio"]["mixture"]
-             extracted = sum([output["audio"][stem] for stem in output["audio"]])
-             residual = mixture - extracted
-             # print(extracted.shape, mixture.shape, residual.shape)
-
-             output["audio"]["residual"] = residual
-
-         if get_no_vox_combinations:
-             no_vox_stems = [
-                 stem for stem in output["audio"] if stem not in self._VOX_STEMS
-             ]
-             no_vox_combinations = chain.from_iterable(
-                 combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1)
-             )
-
-             for combination in no_vox_combinations:
-                 combination_ = list(combination)
-                 output["audio"]["+".join(combination_)] = sum(
-                     [output["audio"][stem] for stem in combination_]
-                 )
-
-         if treat_batch_as_channels:
-             for stem in output["audio"]:
-                 output["audio"][stem] = output["audio"][stem].reshape(
-                     1, -1, output["audio"][stem].shape[-1]
-                 )
-             batch_size = 1
-
-         result = {}
-         for b in range(batch_size):
-             for stem in output["audio"]:
-                 track_name = batch["track"][b].split("/")[-1]
-
-                 if batch.get("audio", {}).get(stem, None) is not None:
-                     self.test_metrics[stem].reset()
-                     metrics = self.test_metrics[stem](
-                         batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...]
-                     )
-                     snr = metrics["snr"]
-                     sisnr = metrics["sisnr"]
-                     sdr = metrics["sdr"]
-                     metric_dict[stem] = metrics
-                     print(
-                         track_name,
-                         f"snr={snr:2.2f} dB",
-                         f"sisnr={sisnr:2.2f}",
-                         f"sdr={sdr:2.2f} dB",
-                     )
-                     filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
-                 else:
-                     filename = f"{stem}.wav"
-
-                 if include_track_name:
-                     output_dir = os.path.join(self.predict_output_path, track_name)
-                 else:
-                     output_dir = self.predict_output_path
-
-                 os.makedirs(output_dir, exist_ok=True)
-
-                 if fs is None:
-                     fs = self.fs
-
-                 result[stem] = output["audio"][stem][b, ...].cpu().numpy()
-
-         return result
-
-     def load_state_dict(
-         self, state_dict: Mapping[str, Any], strict: bool = False
-     ) -> Any:
-
-         return super().load_state_dict(state_dict, strict=False)
-
-     def set_predict_output_path(self, path: str) -> None:
-         self.predict_output_path = path
-         os.makedirs(self.predict_output_path, exist_ok=True)
-
-         self.attach_fader()
-
-     def attach_fader(self, force_reattach=False) -> None:
-         if self.fader is None or force_reattach:
-             self.fader = parse_fader_config(self.fader_config)
-             self.fader.to(self.device)
-
-     def log_dict_with_prefix(
-         self,
-         dict_: Dict[str, torch.Tensor],
-         prefix: str,
-         batch_size: Optional[int] = None,
-         **kwargs: Any,
-     ) -> None:
-         self.log_dict(
-             {f"{prefix}/{k}": v for k, v in dict_.items()},
-             batch_size=batch_size,
-             logger=True,
-             sync_dist=True,
-             **kwargs,
-         )
 
+ import os.path
+ from collections import defaultdict
+ from itertools import chain, combinations
+ from typing import Any, Dict, Iterator, Mapping, Optional, Tuple, Type, TypedDict
+
+ import pytorch_lightning as pl
+ import torch
+ import torchaudio as ta
+ import torchmetrics as tm
+ from asteroid import losses as asteroid_losses
+
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
+ from torch import nn, optim
+ from torch.optim import lr_scheduler
+ from torch.optim.lr_scheduler import LRScheduler
+
+ from . import loss, metrics as metrics_, model
+ from .data._types import BatchedDataDict
+ from .data.augmentation import BaseAugmentor, StemAugmentor
+ from .utils import audio as audio_
+ from .utils.audio import BaseFader
+
+ # NOTE: DeepSpeedCPUAdam and gooptim are referenced in get_optimizer_class()
+ # below but were never imported in this revision. The guarded imports here
+ # (deepspeed and geoopt are assumed to be the intended sources) keep the
+ # module importable without those optional dependencies.
+ try:
+     from deepspeed.ops.adam import DeepSpeedCPUAdam
+ except ImportError:
+     DeepSpeedCPUAdam = None
+
+ try:
+     from geoopt import optim as gooptim
+ except ImportError:
+     gooptim = optim  # fall back to torch.optim so the lookup loop still works
+
+
+ ConfigDict = TypedDict("ConfigDict", {"name": str, "kwargs": Dict[str, Any]})
+
+
+ class SchedulerConfigDict(ConfigDict):
+     monitor: str
+
+
+ OptimizerSchedulerConfigDict = TypedDict(
+     "OptimizerSchedulerConfigDict",
+     {"optimizer": ConfigDict, "scheduler": SchedulerConfigDict},
+     total=False,
+ )
+
+
+ class LRSchedulerReturnDict(TypedDict, total=False):
+     scheduler: LRScheduler
+     monitor: str
+
+
+ class ConfigureOptimizerReturnDict(TypedDict, total=False):
+     optimizer: torch.optim.Optimizer
+     lr_scheduler: LRSchedulerReturnDict
+
+
+ OutputType = Dict[str, Any]
+ MetricsType = Dict[str, torch.Tensor]
+
+
+ def get_optimizer_class(name: str) -> Type[optim.Optimizer]:
+
+     if name == "DeepSpeedCPUAdam":
+         return DeepSpeedCPUAdam
+
+     for module in [optim, gooptim]:
+         if name in module.__dict__:
+             return module.__dict__[name]
+
+     raise NameError
+
+
+ def parse_optimizer_config(
+     config: OptimizerSchedulerConfigDict, parameters: Iterator[nn.Parameter]
+ ) -> ConfigureOptimizerReturnDict:
+     optim_class = get_optimizer_class(config["optimizer"]["name"])
+     optimizer = optim_class(parameters, **config["optimizer"]["kwargs"])
+
+     optim_dict: ConfigureOptimizerReturnDict = {
+         "optimizer": optimizer,
+     }
+
+     if "scheduler" in config:
+
+         lr_scheduler_class_ = config["scheduler"]["name"]
+         lr_scheduler_class = lr_scheduler.__dict__[lr_scheduler_class_]
+         lr_scheduler_dict: LRSchedulerReturnDict = {
+             "scheduler": lr_scheduler_class(optimizer, **config["scheduler"]["kwargs"])
+         }
+
+         if lr_scheduler_class_ == "ReduceLROnPlateau":
+             lr_scheduler_dict["monitor"] = config["scheduler"]["monitor"]
+
+         optim_dict["lr_scheduler"] = lr_scheduler_dict
+
+     return optim_dict
+
+
+ def parse_model_config(config: ConfigDict) -> Any:
+     name = config["name"]
+
+     for module in [model]:
+         if name in module.__dict__:
+             return module.__dict__[name](**config["kwargs"])
+
+     raise NameError
+
+
+ _LEGACY_LOSS_NAMES = ["HybridL1Loss"]
+
+
+ def _parse_legacy_loss_config(config: ConfigDict) -> nn.Module:
+     name = config["name"]
+
+     if name == "HybridL1Loss":
+         return loss.TimeFreqL1Loss(**config["kwargs"])
+
+     raise NameError
+
+
+ def parse_loss_config(config: ConfigDict) -> nn.Module:
+     name = config["name"]
+
+     if name in _LEGACY_LOSS_NAMES:
+         return _parse_legacy_loss_config(config)
+
+     for module in [loss, nn.modules.loss, asteroid_losses]:
+         if name in module.__dict__:
+             return module.__dict__[name](**config["kwargs"])
+
+     raise NameError
+
+
+ def get_metric(config: ConfigDict) -> tm.Metric:
+     name = config["name"]
+
+     for module in [tm, metrics_]:
+         if name in module.__dict__:
+             return module.__dict__[name](**config["kwargs"])
+     raise NameError
+
+
+ def parse_metric_config(config: Dict[str, ConfigDict]) -> tm.MetricCollection:
+     metrics = {}
+
+     for metric in config:
+         metrics[metric] = get_metric(config[metric])
+
+     return tm.MetricCollection(metrics)
+
+
+ def parse_fader_config(config: ConfigDict) -> BaseFader:
+     name = config["name"]
+
+     for module in [audio_]:
+         if name in module.__dict__:
+             return module.__dict__[name](**config["kwargs"])
+
+     raise NameError
+
+
+ class LightningSystem(pl.LightningModule):
+     _VOX_STEMS = ["speech", "vocals"]
+     _BG_STEMS = ["background", "effects", "mne"]
+
+     def __init__(
+         self, config: Dict, loss_adjustment: float = 1.0, attach_fader: bool = False
+     ) -> None:
+         super().__init__()
+         self.optimizer_config = config["optimizer"]
+         self.model = parse_model_config(config["model"])
+         self.loss = parse_loss_config(config["loss"])
+         self.metrics = nn.ModuleDict(
+             {
+                 stem: parse_metric_config(config["metrics"]["dev"])
+                 for stem in self.model.stems
+             }
+         )
+
+         self.metrics.disallow_fsdp = True
+
+         self.test_metrics = nn.ModuleDict(
+             {
+                 stem: parse_metric_config(config["metrics"]["test"])
+                 for stem in self.model.stems
+             }
+         )
+
+         self.test_metrics.disallow_fsdp = True
+
+         self.fs = config["model"]["kwargs"]["fs"]
+
+         self.fader_config = config["inference"]["fader"]
+         if attach_fader:
+             self.fader = parse_fader_config(config["inference"]["fader"])
+         else:
+             self.fader = None
+
+         self.augmentation: Optional[BaseAugmentor]
+         if config.get("augmentation", None) is not None:
+             self.augmentation = StemAugmentor(**config["augmentation"])
+         else:
+             self.augmentation = None
+
+         self.predict_output_path: Optional[str] = None
+         self.loss_adjustment = loss_adjustment
+
+         self.val_prefix = None
+         self.test_prefix = None
+
+     def configure_optimizers(self) -> Any:
+         return parse_optimizer_config(
+             self.optimizer_config, self.trainer.model.parameters()
+         )
+
+     def compute_loss(
+         self, batch: BatchedDataDict, output: OutputType
+     ) -> Dict[str, torch.Tensor]:
+         return {"loss": self.loss(output, batch)}
+
+     def update_metrics(
+         self, batch: BatchedDataDict, output: OutputType, mode: str
+     ) -> None:
+
+         if mode == "test":
+             metrics = self.test_metrics
+         else:
+             metrics = self.metrics
+
+         for stem, metric in metrics.items():
+
+             if stem == "mne:+":
+                 stem = "mne"
+
+             if mode == "train":
+                 metric.update(
+                     output["audio"][stem],
+                     batch["audio"][stem],
+                 )
+             else:
+                 if stem not in batch["audio"]:
+                     matched = False
+                     if stem in self._VOX_STEMS:
+                         for bstem in self._VOX_STEMS:
+                             if bstem in batch["audio"]:
+                                 batch["audio"][stem] = batch["audio"][bstem]
+                                 matched = True
+                                 break
+                     elif stem in self._BG_STEMS:
+                         for bstem in self._BG_STEMS:
+                             if bstem in batch["audio"]:
+                                 batch["audio"][stem] = batch["audio"][bstem]
+                                 matched = True
+                                 break
+                 else:
+                     matched = True
+
+                 if matched:
+                     if stem == "mne" and "mne" not in output["audio"]:
+                         # "mne" (music + effects) is derived from the two stems
+                         output["audio"]["mne"] = (
+                             output["audio"]["music"] + output["audio"]["effects"]
+                         )
+
+                     metric.update(
+                         output["audio"][stem],
+                         batch["audio"][stem],
+                     )
+
+     def compute_metrics(self, mode: str = "dev") -> Dict[str, torch.Tensor]:
+
+         if mode == "test":
+             metrics = self.test_metrics
+         else:
+             metrics = self.metrics
+
+         metric_dict = {}
+
+         for stem, metric in metrics.items():
+             md = metric.compute()
+             metric_dict.update({f"{stem}/{k}": v for k, v in md.items()})
+
+         self.log_dict(metric_dict, prog_bar=True, logger=False)
+
+         return metric_dict
+
+     def reset_metrics(self, test_mode: bool = False) -> None:
+
+         if test_mode:
+             metrics = self.test_metrics
+         else:
+             metrics = self.metrics
+
+         for _, metric in metrics.items():
+             metric.reset()
+
+     def forward(self, batch: BatchedDataDict) -> Any:
+         batch, output = self.model(batch)
+
+         return batch, output
+
+     def common_step(self, batch: BatchedDataDict, mode: str) -> Any:
+         batch, output = self.forward(batch)
+         loss_dict = self.compute_loss(batch, output)
+
+         with torch.no_grad():
+             self.update_metrics(batch, output, mode=mode)
+
+         if mode == "train":
+             self.log("loss", loss_dict["loss"], prog_bar=True)
+
+         return output, loss_dict
+
+     def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]:
+
+         if self.augmentation is not None:
+             with torch.no_grad():
+                 batch = self.augmentation(batch)
+
+         _, loss_dict = self.common_step(batch, mode="train")
+
+         with torch.inference_mode():
+             self.log_dict_with_prefix(
+                 loss_dict, "train", batch_size=batch["audio"]["mixture"].shape[0]
+             )
+
+         loss_dict["loss"] *= self.loss_adjustment
+
+         return loss_dict
+
+     def on_train_batch_end(
+         self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int
+     ) -> None:
+
+         metric_dict = self.compute_metrics()
+         self.log_dict_with_prefix(metric_dict, "train")
+         self.reset_metrics()
+
+     def validation_step(
+         self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
+     ) -> Dict[str, Any]:
+
+         with torch.inference_mode():
+             curr_val_prefix = f"val{dataloader_idx}" if dataloader_idx > 0 else "val"
+
+             if curr_val_prefix != self.val_prefix:
+                 if self.val_prefix is not None:
+                     self._on_validation_epoch_end()
+                 self.val_prefix = curr_val_prefix
+             _, loss_dict = self.common_step(batch, mode="val")
+
+             self.log_dict_with_prefix(
+                 loss_dict,
+                 self.val_prefix,
+                 batch_size=batch["audio"]["mixture"].shape[0],
+                 prog_bar=True,
+                 add_dataloader_idx=False,
+             )
+
+         return loss_dict
+
+     def on_validation_epoch_end(self) -> None:
+         self._on_validation_epoch_end()
+
+     def _on_validation_epoch_end(self) -> None:
+         metric_dict = self.compute_metrics()
+         self.log_dict_with_prefix(
+             metric_dict, self.val_prefix, prog_bar=True, add_dataloader_idx=False
+         )
+         self.reset_metrics()
+
+     def old_predtest_step(
+         self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
+     ) -> Tuple[BatchedDataDict, OutputType]:
+
+         audio_batch = batch["audio"]["mixture"]
+         track_batch = batch.get("track", ["" for _ in range(len(audio_batch))])
+
+         output_list_of_dicts = [
+             self.fader(audio[None, ...], lambda a: self.test_forward(a, track))
+             for audio, track in zip(audio_batch, track_batch)
+         ]
+
+         output_dict_of_lists = defaultdict(list)
+
+         for output_dict in output_list_of_dicts:
+             for stem, audio in output_dict.items():
+                 output_dict_of_lists[stem].append(audio)
+
+         output = {
+             "audio": {
+                 stem: torch.concat(output_list, dim=0)
+                 for stem, output_list in output_dict_of_lists.items()
+             }
+         }
+
+         return batch, output
+
+     def predtest_step(
+         self, batch: BatchedDataDict, batch_idx: int = -1, dataloader_idx: int = 0
+     ) -> Tuple[BatchedDataDict, OutputType]:
+
+         if getattr(self.model, "bypass_fader", False):
+             batch, output = self.model(batch)
+         else:
+             audio_batch = batch["audio"]["mixture"]
+             output = self.fader(
+                 audio_batch, lambda a: self.test_forward(a, "", batch=batch)
+             )
+
+         return batch, output
+
+     def test_forward(
+         self, audio: torch.Tensor, track: str = "", batch: BatchedDataDict = None
+     ) -> torch.Tensor:
+
+         if self.fader is None:
+             self.attach_fader()
+
+         cond = batch.get("condition", None)
+
+         if cond is not None and cond.shape[0] == 1:
+             cond = cond.repeat(audio.shape[0], 1)
+
+         _, output = self.forward(
+             {
+                 "audio": {"mixture": audio},
+                 "track": track,
+                 "condition": cond,
+             }
+         )
+
+         return output["audio"]
+
+     def on_test_epoch_start(self) -> None:
+         self.attach_fader(force_reattach=True)
+
+     def test_step(
+         self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
+     ) -> Any:
+         curr_test_prefix = f"test{dataloader_idx}"
+
+         if curr_test_prefix != self.test_prefix:
+             if self.test_prefix is not None:
+                 self._on_test_epoch_end()
+             self.test_prefix = curr_test_prefix
+
+         with torch.inference_mode():
+             _, output = self.predtest_step(batch, batch_idx, dataloader_idx)
+             self.update_metrics(batch, output, mode="test")
+
+         return output
+
+     def on_test_epoch_end(self) -> None:
+         self._on_test_epoch_end()
+
+     def _on_test_epoch_end(self) -> None:
+         metric_dict = self.compute_metrics(mode="test")
+         self.log_dict_with_prefix(
+             metric_dict, self.test_prefix, prog_bar=True, add_dataloader_idx=False
+         )
+         self.reset_metrics()
+
+     def predict_step(
+         self,
+         batch: BatchedDataDict,
+         batch_idx: int = 0,
+         dataloader_idx: int = 0,
+         include_track_name: Optional[bool] = None,
+         get_no_vox_combinations: bool = True,
+         get_residual: bool = False,
+         treat_batch_as_channels: bool = False,
+         fs: Optional[int] = None,
+     ) -> Any:
+         assert self.predict_output_path is not None
+
+         batch_size = batch["audio"]["mixture"].shape[0]
+
+         if include_track_name is None:
+             include_track_name = batch_size > 1
+
+         with torch.inference_mode():
+             batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
+             print("Pred test finished...")
+         torch.cuda.empty_cache()
+         metric_dict = {}
+
+         if get_residual:
+             mixture = batch["audio"]["mixture"]
+             extracted = sum([output["audio"][stem] for stem in output["audio"]])
+             residual = mixture - extracted
+             print(extracted.shape, mixture.shape, residual.shape)
+
+             output["audio"]["residual"] = residual
+
+         if get_no_vox_combinations:
+             no_vox_stems = [
+                 stem for stem in output["audio"] if stem not in self._VOX_STEMS
+             ]
+             no_vox_combinations = chain.from_iterable(
+                 combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1)
+             )
+
+             for combination in no_vox_combinations:
+                 combination_ = list(combination)
+                 output["audio"]["+".join(combination_)] = sum(
+                     [output["audio"][stem] for stem in combination_]
+                 )
+
+         if treat_batch_as_channels:
+             for stem in output["audio"]:
+                 output["audio"][stem] = output["audio"][stem].reshape(
+                     1, -1, output["audio"][stem].shape[-1]
+                 )
+             batch_size = 1
+
+         for b in range(batch_size):
+             print("!!", b)
+             for stem in output["audio"]:
+                 print(f"Saving audio for {stem} to {self.predict_output_path}")
+                 track_name = batch["track"][b].split("/")[-1]
+
+                 if batch.get("audio", {}).get(stem, None) is not None:
+                     self.test_metrics[stem].reset()
+                     metrics = self.test_metrics[stem](
+                         batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...]
+                     )
+                     snr = metrics["snr"]
+                     sisnr = metrics["sisnr"]
+                     sdr = metrics["sdr"]
+                     metric_dict[stem] = metrics
+                     print(
+                         track_name,
+                         f"snr={snr:2.2f} dB",
+                         f"sisnr={sisnr:2.2f}",
+                         f"sdr={sdr:2.2f} dB",
+                     )
+                     filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
+                 else:
+                     filename = f"{stem}.wav"
+
+                 if include_track_name:
+                     output_dir = os.path.join(self.predict_output_path, track_name)
+                 else:
+                     output_dir = self.predict_output_path
+
+                 os.makedirs(output_dir, exist_ok=True)
+
+                 if fs is None:
+                     fs = self.fs
+
+                 ta.save(
+                     os.path.join(output_dir, filename),
+                     output["audio"][stem][b, ...].cpu(),
+                     fs,
+                 )
+
+         return metric_dict
+
+     def get_stems(
+         self,
+         batch: BatchedDataDict,
+         batch_idx: int = 0,
+         dataloader_idx: int = 0,
+         include_track_name: Optional[bool] = None,
+         get_no_vox_combinations: bool = True,
+         get_residual: bool = False,
+         treat_batch_as_channels: bool = False,
+         fs: Optional[int] = None,
+     ) -> Any:
+         assert self.predict_output_path is not None
+
+         batch_size = batch["audio"]["mixture"].shape[0]
+
+         if include_track_name is None:
+             include_track_name = batch_size > 1
+
+         with torch.inference_mode():
+             batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
+         torch.cuda.empty_cache()
+         metric_dict = {}
+
+         if get_residual:
+             mixture = batch["audio"]["mixture"]
+             extracted = sum([output["audio"][stem] for stem in output["audio"]])
+             residual = mixture - extracted
+
+             output["audio"]["residual"] = residual
+
+         if get_no_vox_combinations:
+             no_vox_stems = [
+                 stem for stem in output["audio"] if stem not in self._VOX_STEMS
+             ]
+             no_vox_combinations = chain.from_iterable(
+                 combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1)
+             )
+
+             for combination in no_vox_combinations:
+                 combination_ = list(combination)
+                 output["audio"]["+".join(combination_)] = sum(
+                     [output["audio"][stem] for stem in combination_]
+                 )
+
+         if treat_batch_as_channels:
+             for stem in output["audio"]:
+                 output["audio"][stem] = output["audio"][stem].reshape(
+                     1, -1, output["audio"][stem].shape[-1]
+                 )
+             batch_size = 1
+
+         result = {}
+         for b in range(batch_size):
+             for stem in output["audio"]:
+                 track_name = batch["track"][b].split("/")[-1]
+
+                 if batch.get("audio", {}).get(stem, None) is not None:
+                     self.test_metrics[stem].reset()
+                     metrics = self.test_metrics[stem](
+                         batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...]
+                     )
+                     snr = metrics["snr"]
+                     sisnr = metrics["sisnr"]
+                     sdr = metrics["sdr"]
+                     metric_dict[stem] = metrics
+                     print(
+                         track_name,
+                         f"snr={snr:2.2f} dB",
+                         f"sisnr={sisnr:2.2f}",
+                         f"sdr={sdr:2.2f} dB",
+                     )
+                     filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
+                 else:
+                     filename = f"{stem}.wav"
+
+                 if include_track_name:
+                     output_dir = os.path.join(self.predict_output_path, track_name)
+                 else:
+                     output_dir = self.predict_output_path
+
+                 os.makedirs(output_dir, exist_ok=True)
+
+                 if fs is None:
+                     fs = self.fs
+
+                 result[stem] = output["audio"][stem][b, ...].cpu().numpy()
+
+         return result
+
+     def load_state_dict(
+         self, state_dict: Mapping[str, Any], strict: bool = False
+     ) -> Any:
+
+         return super().load_state_dict(state_dict, strict=False)
+
+     def set_predict_output_path(self, path: str) -> None:
+         self.predict_output_path = path
+         os.makedirs(self.predict_output_path, exist_ok=True)
+
+         self.attach_fader()
+
+     def attach_fader(self, force_reattach=False) -> None:
+         if self.fader is None or force_reattach:
+             self.fader = parse_fader_config(self.fader_config)
+             self.fader.to(self.device)
+
+     def log_dict_with_prefix(
+         self,
+         dict_: Dict[str, torch.Tensor],
+         prefix: str,
+         batch_size: Optional[int] = None,
+         **kwargs: Any,
+     ) -> None:
+         self.log_dict(
+             {f"{prefix}/{k}": v for k, v in dict_.items()},
+             batch_size=batch_size,
+             logger=True,
+             sync_dist=True,
+             **kwargs,
+         )
 
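For orientation, a minimal sketch of the nested config dict that LightningSystem and the parse_* helpers above consume. The key paths mirror the parsing code; the concrete names and values below are illustrative assumptions, not project defaults.

    config = {
        "model": {"name": "MyBanditModel", "kwargs": {"fs": 44100}},  # placeholder; must be a class exposed by core.model
        "loss": {"name": "L1Loss", "kwargs": {}},  # resolved against loss, nn.modules.loss, or asteroid.losses
        "metrics": {
            "dev": {"snr": {"name": "SignalNoiseRatio", "kwargs": {}}},   # resolved on torchmetrics or core.metrics
            "test": {"snr": {"name": "SignalNoiseRatio", "kwargs": {}}},
        },
        "inference": {"fader": {"name": "SomeFader", "kwargs": {}}},  # placeholder; looked up in core.utils.audio
        "optimizer": {
            "optimizer": {"name": "Adam", "kwargs": {"lr": 1e-3}},
            "scheduler": {
                "name": "ReduceLROnPlateau",
                "kwargs": {"factor": 0.5},
                "monitor": "val/loss",  # only consulted for ReduceLROnPlateau
            },
        },
    }
    system = LightningSystem(config)
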
mvsepless/models/bandit/core/data/__init__.py CHANGED
@@ -1,2 +1,2 @@
- from .dnr.datamodule import DivideAndRemasterDataModule
- from .musdb.datamodule import MUSDB18DataModule

+ from .dnr.datamodule import DivideAndRemasterDataModule
+ from .musdb.datamodule import MUSDB18DataModule
mvsepless/models/bandit/core/data/_types.py CHANGED
@@ -1,17 +1,17 @@
- from typing import Dict, Sequence, TypedDict
-
- import torch
-
- AudioDict = Dict[str, torch.Tensor]
-
- DataDict = TypedDict("DataDict", {"audio": AudioDict, "track": str})
-
- BatchedDataDict = TypedDict(
-     "BatchedDataDict", {"audio": AudioDict, "track": Sequence[str]}
- )
-
-
- class DataDictWithLanguage(TypedDict):
-     audio: AudioDict
-     track: str
-     language: str

+ from typing import Dict, Sequence, TypedDict
+
+ import torch
+
+ AudioDict = Dict[str, torch.Tensor]
+
+ DataDict = TypedDict("DataDict", {"audio": AudioDict, "track": str})
+
+ BatchedDataDict = TypedDict(
+     "BatchedDataDict", {"audio": AudioDict, "track": Sequence[str]}
+ )
+
+
+ class DataDictWithLanguage(TypedDict):
+     audio: AudioDict
+     track: str
+     language: str
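A quick sketch of the two shapes these aliases describe (tensor sizes are illustrative assumptions):

    import torch

    item = {  # DataDict: a single track
        "audio": {"mixture": torch.zeros(2, 44100), "speech": torch.zeros(2, 44100)},
        "track": "tr/track00001",
    }
    batch = {  # BatchedDataDict: audio gains a batch dim, track becomes a sequence
        "audio": {k: v[None] for k, v in item["audio"].items()},
        "track": [item["track"]],
    }
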
mvsepless/models/bandit/core/data/augmentation.py CHANGED
@@ -1,102 +1,102 @@
- from abc import ABC
- from typing import Any, Dict, Union
-
- import torch
- import torch_audiomentations as tam
- from torch import nn
-
- from ._types import BatchedDataDict, DataDict
-
-
- class BaseAugmentor(nn.Module, ABC):
-     def forward(
-         self, item: Union[DataDict, BatchedDataDict]
-     ) -> Union[DataDict, BatchedDataDict]:
-         raise NotImplementedError
-
-
- class StemAugmentor(BaseAugmentor):
-     def __init__(
-         self,
-         audiomentations: Dict[str, Dict[str, Any]],
-         fix_clipping: bool = True,
-         scaler_margin: float = 0.5,
-         apply_both_default_and_common: bool = False,
-     ) -> None:
-         super().__init__()
-
-         augmentations = {}
-
-         self.has_default = "[default]" in audiomentations
-         self.has_common = "[common]" in audiomentations
-         self.apply_both_default_and_common = apply_both_default_and_common
-
-         for stem in audiomentations:
-             if audiomentations[stem]["name"] == "Compose":
-                 augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
-                     [
-                         getattr(tam, aug["name"])(**aug["kwargs"])
-                         for aug in audiomentations[stem]["kwargs"]["transforms"]
-                     ],
-                     **audiomentations[stem]["kwargs"]["kwargs"],
-                 )
-             else:
-                 augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
-                     **audiomentations[stem]["kwargs"]
-                 )
-
-         self.augmentations = nn.ModuleDict(augmentations)
-         self.fix_clipping = fix_clipping
-         self.scaler_margin = scaler_margin
-
-     def check_and_fix_clipping(
-         self, item: Union[DataDict, BatchedDataDict]
-     ) -> Union[DataDict, BatchedDataDict]:
-         max_abs = []
-
-         for stem in item["audio"]:
-             max_abs.append(item["audio"][stem].abs().max().item())
-
-         if max(max_abs) > 1.0:
-             scaler = 1.0 / (
-                 max(max_abs)
-                 + torch.rand((1,), device=item["audio"]["mixture"].device)
-                 * self.scaler_margin
-             )
-
-             for stem in item["audio"]:
-                 item["audio"][stem] *= scaler
-
-         return item
-
-     def forward(
-         self, item: Union[DataDict, BatchedDataDict]
-     ) -> Union[DataDict, BatchedDataDict]:
-
-         for stem in item["audio"]:
-             if stem == "mixture":
-                 continue
-
-             if self.has_common:
-                 item["audio"][stem] = self.augmentations["[common]"](
-                     item["audio"][stem]
-                 ).samples
-
-             if stem in self.augmentations:
-                 item["audio"][stem] = self.augmentations[stem](
-                     item["audio"][stem]
-                 ).samples
-             elif self.has_default:
-                 if not self.has_common or self.apply_both_default_and_common:
-                     item["audio"][stem] = self.augmentations["[default]"](
-                         item["audio"][stem]
-                     ).samples
-
-         item["audio"]["mixture"] = sum(
-             [item["audio"][stem] for stem in item["audio"] if stem != "mixture"]
-         )  # type: ignore[call-overload, assignment]
-
-         if self.fix_clipping:
-             item = self.check_and_fix_clipping(item)
-
-         return item

+ from abc import ABC
+ from typing import Any, Dict, Union
+
+ import torch
+ import torch_audiomentations as tam
+ from torch import nn
+
+ from ._types import BatchedDataDict, DataDict
+
+
+ class BaseAugmentor(nn.Module, ABC):
+     def forward(
+         self, item: Union[DataDict, BatchedDataDict]
+     ) -> Union[DataDict, BatchedDataDict]:
+         raise NotImplementedError
+
+
+ class StemAugmentor(BaseAugmentor):
+     def __init__(
+         self,
+         audiomentations: Dict[str, Dict[str, Any]],
+         fix_clipping: bool = True,
+         scaler_margin: float = 0.5,
+         apply_both_default_and_common: bool = False,
+     ) -> None:
+         super().__init__()
+
+         augmentations = {}
+
+         self.has_default = "[default]" in audiomentations
+         self.has_common = "[common]" in audiomentations
+         self.apply_both_default_and_common = apply_both_default_and_common
+
+         for stem in audiomentations:
+             if audiomentations[stem]["name"] == "Compose":
+                 augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
+                     [
+                         getattr(tam, aug["name"])(**aug["kwargs"])
+                         for aug in audiomentations[stem]["kwargs"]["transforms"]
+                     ],
+                     **audiomentations[stem]["kwargs"]["kwargs"],
+                 )
+             else:
+                 augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
+                     **audiomentations[stem]["kwargs"]
+                 )
+
+         self.augmentations = nn.ModuleDict(augmentations)
+         self.fix_clipping = fix_clipping
+         self.scaler_margin = scaler_margin
+
+     def check_and_fix_clipping(
+         self, item: Union[DataDict, BatchedDataDict]
+     ) -> Union[DataDict, BatchedDataDict]:
+         max_abs = []
+
+         for stem in item["audio"]:
+             max_abs.append(item["audio"][stem].abs().max().item())
+
+         if max(max_abs) > 1.0:
+             scaler = 1.0 / (
+                 max(max_abs)
+                 + torch.rand((1,), device=item["audio"]["mixture"].device)
+                 * self.scaler_margin
+             )
+
+             for stem in item["audio"]:
+                 item["audio"][stem] *= scaler
+
+         return item
+
+     def forward(
+         self, item: Union[DataDict, BatchedDataDict]
+     ) -> Union[DataDict, BatchedDataDict]:
+
+         for stem in item["audio"]:
+             if stem == "mixture":
+                 continue
+
+             if self.has_common:
+                 item["audio"][stem] = self.augmentations["[common]"](
+                     item["audio"][stem]
+                 ).samples
+
+             if stem in self.augmentations:
+                 item["audio"][stem] = self.augmentations[stem](
+                     item["audio"][stem]
+                 ).samples
+             elif self.has_default:
+                 if not self.has_common or self.apply_both_default_and_common:
+                     item["audio"][stem] = self.augmentations["[default]"](
+                         item["audio"][stem]
+                     ).samples
+
+         item["audio"]["mixture"] = sum(
+             [item["audio"][stem] for stem in item["audio"] if stem != "mixture"]
+         )  # type: ignore[call-overload, assignment]
+
+         if self.fix_clipping:
+             item = self.check_and_fix_clipping(item)
+
+         return item
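A sketch of the `audiomentations` mapping StemAugmentor parses. "[common]" applies to every non-mixture stem and "[default]" to stems without their own entry; names are resolved as attributes of torch_audiomentations, and the transform kwargs below are illustrative assumptions.

    augmentor = StemAugmentor(
        audiomentations={
            "[default]": {
                "name": "Gain",
                "kwargs": {"min_gain_in_db": -6.0, "max_gain_in_db": 6.0, "p": 0.5},
            },
            "speech": {
                "name": "Compose",  # nested form: kwargs.transforms + kwargs.kwargs
                "kwargs": {
                    "transforms": [{"name": "PolarityInversion", "kwargs": {"p": 0.5}}],
                    "kwargs": {},
                },
            },
        }
    )

After the per-stem transforms run, forward() re-sums the mixture from the augmented stems and optionally rescales everything if the result clips.
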
mvsepless/models/bandit/core/data/augmented.py CHANGED
@@ -1,34 +1,34 @@
- import warnings
- from typing import Dict, Optional, Union
-
- import torch
- from torch import nn
- from torch.utils import data
-
-
- class AugmentedDataset(data.Dataset):
-     def __init__(
-         self,
-         dataset: data.Dataset,
-         augmentation: nn.Module = nn.Identity(),
-         target_length: Optional[int] = None,
-     ) -> None:
-         warnings.warn(
-             "This class is no longer used. Attach augmentation to "
-             "the LightningSystem instead.",
-             DeprecationWarning,
-         )
-
-         self.dataset = dataset
-         self.augmentation = augmentation
-
-         self.ds_length: int = len(dataset)  # type: ignore[arg-type]
-         self.length = target_length if target_length is not None else self.ds_length
-
-     def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str, torch.Tensor]]]:
-         item = self.dataset[index % self.ds_length]
-         item = self.augmentation(item)
-         return item
-
-     def __len__(self) -> int:
-         return self.length

+ import warnings
+ from typing import Dict, Optional, Union
+
+ import torch
+ from torch import nn
+ from torch.utils import data
+
+
+ class AugmentedDataset(data.Dataset):
+     def __init__(
+         self,
+         dataset: data.Dataset,
+         augmentation: nn.Module = nn.Identity(),
+         target_length: Optional[int] = None,
+     ) -> None:
+         warnings.warn(
+             "This class is no longer used. Attach augmentation to "
+             "the LightningSystem instead.",
+             DeprecationWarning,
+         )
+
+         self.dataset = dataset
+         self.augmentation = augmentation
+
+         self.ds_length: int = len(dataset)  # type: ignore[arg-type]
+         self.length = target_length if target_length is not None else self.ds_length
+
+     def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str, torch.Tensor]]]:
+         item = self.dataset[index % self.ds_length]
+         item = self.augmentation(item)
+         return item
+
+     def __len__(self) -> int:
+         return self.length
mvsepless/models/bandit/core/data/base.py CHANGED
@@ -1,60 +1,60 @@
- import os
- from abc import ABC, abstractmethod
- from typing import Any, Dict, List, Optional
-
- import numpy as np
- import pedalboard as pb
- import torch
- import torchaudio as ta
- from torch.utils import data
-
- from ._types import AudioDict, DataDict
-
-
- class BaseSourceSeparationDataset(data.Dataset, ABC):
-     def __init__(
-         self,
-         split: str,
-         stems: List[str],
-         files: List[str],
-         data_path: str,
-         fs: int,
-         npy_memmap: bool,
-         recompute_mixture: bool,
-     ):
-         self.split = split
-         self.stems = stems
-         self.stems_no_mixture = [s for s in stems if s != "mixture"]
-         self.files = files
-         self.data_path = data_path
-         self.fs = fs
-         self.npy_memmap = npy_memmap
-         self.recompute_mixture = recompute_mixture
-
-     @abstractmethod
-     def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
-         raise NotImplementedError
-
-     def _get_audio(self, stems, identifier: Dict[str, Any]):
-         audio = {}
-         for stem in stems:
-             audio[stem] = self.get_stem(stem=stem, identifier=identifier)
-
-         return audio
-
-     def get_audio(self, identifier: Dict[str, Any]) -> AudioDict:
-
-         if self.recompute_mixture:
-             audio = self._get_audio(self.stems_no_mixture, identifier=identifier)
-             audio["mixture"] = self.compute_mixture(audio)
-             return audio
-         else:
-             return self._get_audio(self.stems, identifier=identifier)
-
-     @abstractmethod
-     def get_identifier(self, index: int) -> Dict[str, Any]:
-         pass
-
-     def compute_mixture(self, audio: AudioDict) -> torch.Tensor:
-
-         return sum(audio[stem] for stem in audio if stem != "mixture")

+ import os
+ from abc import ABC, abstractmethod
+ from typing import Any, Dict, List, Optional
+
+ import numpy as np
+ import pedalboard as pb
+ import torch
+ import torchaudio as ta
+ from torch.utils import data
+
+ from ._types import AudioDict, DataDict
+
+
+ class BaseSourceSeparationDataset(data.Dataset, ABC):
+     def __init__(
+         self,
+         split: str,
+         stems: List[str],
+         files: List[str],
+         data_path: str,
+         fs: int,
+         npy_memmap: bool,
+         recompute_mixture: bool,
+     ):
+         self.split = split
+         self.stems = stems
+         self.stems_no_mixture = [s for s in stems if s != "mixture"]
+         self.files = files
+         self.data_path = data_path
+         self.fs = fs
+         self.npy_memmap = npy_memmap
+         self.recompute_mixture = recompute_mixture
+
+     @abstractmethod
+     def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
+         raise NotImplementedError
+
+     def _get_audio(self, stems, identifier: Dict[str, Any]):
+         audio = {}
+         for stem in stems:
+             audio[stem] = self.get_stem(stem=stem, identifier=identifier)
+
+         return audio
+
+     def get_audio(self, identifier: Dict[str, Any]) -> AudioDict:
+
+         if self.recompute_mixture:
+             audio = self._get_audio(self.stems_no_mixture, identifier=identifier)
+             audio["mixture"] = self.compute_mixture(audio)
+             return audio
+         else:
+             return self._get_audio(self.stems, identifier=identifier)
+
+     @abstractmethod
+     def get_identifier(self, index: int) -> Dict[str, Any]:
+         pass
+
+     def compute_mixture(self, audio: AudioDict) -> torch.Tensor:
+
+         return sum(audio[stem] for stem in audio if stem != "mixture")
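Only get_stem() and get_identifier() are abstract; get_audio() and compute_mixture() are inherited. A minimal subclass sketch (the one-wav-per-stem file layout is an assumption):

    class MyDataset(BaseSourceSeparationDataset):
        def get_identifier(self, index):
            return {"track": self.files[index]}

        def get_stem(self, *, stem, identifier):
            path = os.path.join(self.data_path, identifier["track"], f"{stem}.wav")
            audio, _ = ta.load(path)  # os and ta are already imported above
            return audio

        def __len__(self):
            return len(self.files)
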
mvsepless/models/bandit/core/data/dnr/datamodule.py CHANGED
@@ -1,68 +1,64 @@
- import os
- from typing import Mapping, Optional
-
- import pytorch_lightning as pl
-
- from .dataset import (
-     DivideAndRemasterDataset,
-     DivideAndRemasterDeterministicChunkDataset,
-     DivideAndRemasterRandomChunkDataset,
-     DivideAndRemasterRandomChunkDatasetWithSpeechReverb,
- )
-
-
- def DivideAndRemasterDataModule(
-     data_root: str = "$DATA_ROOT/DnR/v2",
-     batch_size: int = 2,
-     num_workers: int = 8,
-     train_kwargs: Optional[Mapping] = None,
-     val_kwargs: Optional[Mapping] = None,
-     test_kwargs: Optional[Mapping] = None,
-     datamodule_kwargs: Optional[Mapping] = None,
-     use_speech_reverb: bool = False,
-     # augmentor=None
- ) -> pl.LightningDataModule:
-     if train_kwargs is None:
-         train_kwargs = {}
-
-     if val_kwargs is None:
-         val_kwargs = {}
-
-     if test_kwargs is None:
-         test_kwargs = {}
-
-     if datamodule_kwargs is None:
-         datamodule_kwargs = {}
-
-     if num_workers is None:
-         num_workers = os.cpu_count()
-
-     if num_workers is None:
-         num_workers = 32
-
-     num_workers = min(num_workers, 64)
-
-     if use_speech_reverb:
-         train_cls = DivideAndRemasterRandomChunkDatasetWithSpeechReverb
-     else:
-         train_cls = DivideAndRemasterRandomChunkDataset
-
-     train_dataset = train_cls(data_root, "train", **train_kwargs)
-
-     # if augmentor is not None:
-     #     train_dataset = AugmentedDataset(train_dataset, augmentor)
-
-     datamodule = pl.LightningDataModule.from_datasets(
-         train_dataset=train_dataset,
-         val_dataset=DivideAndRemasterDeterministicChunkDataset(
-             data_root, "val", **val_kwargs
-         ),
-         test_dataset=DivideAndRemasterDataset(data_root, "test", **test_kwargs),
-         batch_size=batch_size,
-         num_workers=num_workers,
-         **datamodule_kwargs
-     )
-
-     datamodule.predict_dataloader = datamodule.test_dataloader  # type: ignore[method-assign]
-
-     return datamodule

+ import os
+ from typing import Mapping, Optional
+
+ import pytorch_lightning as pl
+
+ from .dataset import (
+     DivideAndRemasterDataset,
+     DivideAndRemasterDeterministicChunkDataset,
+     DivideAndRemasterRandomChunkDataset,
+     DivideAndRemasterRandomChunkDatasetWithSpeechReverb,
+ )
+
+
+ def DivideAndRemasterDataModule(
+     data_root: str = "$DATA_ROOT/DnR/v2",
+     batch_size: int = 2,
+     num_workers: int = 8,
+     train_kwargs: Optional[Mapping] = None,
+     val_kwargs: Optional[Mapping] = None,
+     test_kwargs: Optional[Mapping] = None,
+     datamodule_kwargs: Optional[Mapping] = None,
+     use_speech_reverb: bool = False,
+ ) -> pl.LightningDataModule:
+     if train_kwargs is None:
+         train_kwargs = {}
+
+     if val_kwargs is None:
+         val_kwargs = {}
+
+     if test_kwargs is None:
+         test_kwargs = {}
+
+     if datamodule_kwargs is None:
+         datamodule_kwargs = {}
+
+     if num_workers is None:
+         num_workers = os.cpu_count()
+
+     if num_workers is None:
+         num_workers = 32
+
+     num_workers = min(num_workers, 64)
+
+     if use_speech_reverb:
+         train_cls = DivideAndRemasterRandomChunkDatasetWithSpeechReverb
+     else:
+         train_cls = DivideAndRemasterRandomChunkDataset
+
+     train_dataset = train_cls(data_root, "train", **train_kwargs)
+
+     datamodule = pl.LightningDataModule.from_datasets(
+         train_dataset=train_dataset,
+         val_dataset=DivideAndRemasterDeterministicChunkDataset(
+             data_root, "val", **val_kwargs
+         ),
+         test_dataset=DivideAndRemasterDataset(data_root, "test", **test_kwargs),
+         batch_size=batch_size,
+         num_workers=num_workers,
+         **datamodule_kwargs,
+     )
+
+     datamodule.predict_dataloader = datamodule.test_dataloader  # type: ignore[method-assign]
+
+     return datamodule
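Usage sketch (the path and chunk sizes are placeholders): the train split uses the random-chunk dataset, so train_kwargs must supply target_length and chunk_size_second, while the deterministic val split needs chunk_size_second and hop_size_second (see the dataset signatures below). predict_dataloader is aliased to test_dataloader by the function above.

    dm = DivideAndRemasterDataModule(
        data_root="/data/DnR/v2",
        batch_size=2,
        num_workers=8,
        train_kwargs={"target_length": 20000, "chunk_size_second": 6.0},
        val_kwargs={"chunk_size_second": 6.0, "hop_size_second": 3.0},
    )
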
mvsepless/models/bandit/core/data/dnr/dataset.py CHANGED
@@ -1,366 +1,360 @@
- import os
- from abc import ABC
- from typing import Any, Dict, List, Optional
-
- import numpy as np
- import pedalboard as pb
- import torch
- import torchaudio as ta
- from torch.utils import data
-
- from .._types import AudioDict, DataDict
- from ..base import BaseSourceSeparationDataset
-
-
- class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC):
-     ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"]
-     STEM_NAME_MAP = {
-         "mixture": "mix",
-         "speech": "speech",
-         "music": "music",
-         "effects": "sfx",
-     }
-     SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"}
-
-     FULL_TRACK_LENGTH_SECOND = 60
-     FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100
-
-     def __init__(
-         self,
-         split: str,
-         stems: List[str],
-         files: List[str],
-         data_path: str,
-         fs: int = 44100,
-         npy_memmap: bool = True,
-         recompute_mixture: bool = False,
-     ) -> None:
-         super().__init__(
-             split=split,
-             stems=stems,
-             files=files,
-             data_path=data_path,
-             fs=fs,
-             npy_memmap=npy_memmap,
-             recompute_mixture=recompute_mixture,
-         )
-
-     def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
-
-         if stem == "mne":
-             return self.get_stem(stem="music", identifier=identifier) + self.get_stem(
-                 stem="effects", identifier=identifier
-             )
-
-         track = identifier["track"]
-         path = os.path.join(self.data_path, track)
-
-         if self.npy_memmap:
-             audio = np.load(
-                 os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"), mmap_mode="r"
-             )
-         else:
-             # noinspection PyUnresolvedReferences
-             audio, _ = ta.load(os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav"))
-
-         return audio
-
-     def get_identifier(self, index):
-         return dict(track=self.files[index])
-
-     def __getitem__(self, index: int) -> DataDict:
-         identifier = self.get_identifier(index)
-         audio = self.get_audio(identifier)
-
-         return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
-
-
- class DivideAndRemasterDataset(DivideAndRemasterBaseDataset):
-     def __init__(
-         self,
-         data_root: str,
-         split: str,
-         stems: Optional[List[str]] = None,
-         fs: int = 44100,
-         npy_memmap: bool = True,
-     ) -> None:
-
-         if stems is None:
-             stems = self.ALLOWED_STEMS
-         self.stems = stems
-
-         data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
-
-         files = sorted(os.listdir(data_path))
-         files = [
-             f
-             for f in files
-             if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
-         ]
-         # pprint(list(enumerate(files)))
-         if split == "train":
-             assert len(files) == 3406, len(files)
-         elif split == "val":
-             assert len(files) == 487, len(files)
-         elif split == "test":
-             assert len(files) == 973, len(files)
-
-         self.n_tracks = len(files)
-
-         super().__init__(
-             data_path=data_path,
-             split=split,
-             stems=stems,
-             files=files,
-             fs=fs,
-             npy_memmap=npy_memmap,
-         )
-
-     def __len__(self) -> int:
-         return self.n_tracks
-
-
- class DivideAndRemasterRandomChunkDataset(DivideAndRemasterBaseDataset):
-     def __init__(
-         self,
-         data_root: str,
-         split: str,
-         target_length: int,
-         chunk_size_second: float,
-         stems: Optional[List[str]] = None,
-         fs: int = 44100,
-         npy_memmap: bool = True,
-     ) -> None:
-
-         if stems is None:
-             stems = self.ALLOWED_STEMS
-         self.stems = stems
-
-         data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
-
-         files = sorted(os.listdir(data_path))
-         files = [
-             f
-             for f in files
-             if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
-         ]
-
-         if split == "train":
-             assert len(files) == 3406, len(files)
-         elif split == "val":
-             assert len(files) == 487, len(files)
-         elif split == "test":
-             assert len(files) == 973, len(files)
-
-         self.n_tracks = len(files)
-
-         self.target_length = target_length
-         self.chunk_size = int(chunk_size_second * fs)
-
-         super().__init__(
-             data_path=data_path,
-             split=split,
-             stems=stems,
-             files=files,
-             fs=fs,
-             npy_memmap=npy_memmap,
-         )
-
-     def __len__(self) -> int:
-         return self.target_length
-
-     def get_identifier(self, index):
-         return super().get_identifier(index % self.n_tracks)
-
-     def get_stem(
-         self,
-         *,
-         stem: str,
-         identifier: Dict[str, Any],
-         chunk_here: bool = False,
-     ) -> torch.Tensor:
-
-         stem = super().get_stem(stem=stem, identifier=identifier)
-
-         if chunk_here:
-             start = np.random.randint(
-                 0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
-             )
-             end = start + self.chunk_size
-
-             stem = stem[:, start:end]
-
-         return stem
-
-     def __getitem__(self, index: int) -> DataDict:
-         identifier = self.get_identifier(index)
-         # self.index_lock = index
-         audio = self.get_audio(identifier)
-         # self.index_lock = None
-
-         start = np.random.randint(0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size)
-         end = start + self.chunk_size
-
-         audio = {k: v[:, start:end] for k, v in audio.items()}
-
-         return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
-
-
- class DivideAndRemasterDeterministicChunkDataset(DivideAndRemasterBaseDataset):
-     def __init__(
-         self,
-         data_root: str,
-         split: str,
-         chunk_size_second: float,
-         hop_size_second: float,
-         stems: Optional[List[str]] = None,
-         fs: int = 44100,
-         npy_memmap: bool = True,
-     ) -> None:
-
-         if stems is None:
-             stems = self.ALLOWED_STEMS
-         self.stems = stems
-
-         data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
-
-         files = sorted(os.listdir(data_path))
-         files = [
-             f
-             for f in files
-             if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
-         ]
-         # pprint(list(enumerate(files)))
-         if split == "train":
-             assert len(files) == 3406, len(files)
-         elif split == "val":
-             assert len(files) == 487, len(files)
-         elif split == "test":
-             assert len(files) == 973, len(files)
-
-         self.n_tracks = len(files)
-
-         self.chunk_size = int(chunk_size_second * fs)
-         self.hop_size = int(hop_size_second * fs)
-         self.n_chunks_per_track = int(
-             (self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second
-         )
-
-         self.length = self.n_tracks * self.n_chunks_per_track
-
-         super().__init__(
-             data_path=data_path,
-             split=split,
-             stems=stems,
-             files=files,
-             fs=fs,
-             npy_memmap=npy_memmap,
-         )
-
-     def get_identifier(self, index):
-         return super().get_identifier(index % self.n_tracks)
-
-     def __len__(self) -> int:
-         return self.length
-
-     def __getitem__(self, item: int) -> DataDict:
-
-         index = item % self.n_tracks
-         chunk = item // self.n_tracks
-
-         data_ = super().__getitem__(index)
-
-         audio = data_["audio"]
-
-         start = chunk * self.hop_size
-         end = start + self.chunk_size
-
-         for stem in self.stems:
-             data_["audio"][stem] = audio[stem][:, start:end]
-
-         return data_
-
-
- class DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
-     DivideAndRemasterRandomChunkDataset
- ):
-     def __init__(
-         self,
-         data_root: str,
-         split: str,
-         target_length: int,
-         chunk_size_second: float,
-         stems: Optional[List[str]] = None,
-         fs: int = 44100,
-         npy_memmap: bool = True,
-     ) -> None:
-
-         if stems is None:
-             stems = self.ALLOWED_STEMS
-
-         stems_no_mixture = [s for s in stems if s != "mixture"]
-
-         super().__init__(
-             data_root=data_root,
-             split=split,
-             target_length=target_length,
-             chunk_size_second=chunk_size_second,
-             stems=stems_no_mixture,
-             fs=fs,
-             npy_memmap=npy_memmap,
-         )
-
-         self.stems = stems
-         self.stems_no_mixture = stems_no_mixture
-
-     def __getitem__(self, index: int) -> DataDict:
-
-         data_ = super().__getitem__(index)
-
-         dry = data_["audio"]["speech"][:]
-         n_samples = dry.shape[-1]
-
-         wet_level = np.random.rand()
-
-         speech = pb.Reverb(
-             room_size=np.random.rand(),
-             damping=np.random.rand(),
-             wet_level=wet_level,
-             dry_level=(1 - wet_level),
-             width=np.random.rand(),
-         ).process(dry, self.fs, buffer_size=8192 * 4)[..., :n_samples]
-
-         data_["audio"]["speech"] = speech
-
-         data_["audio"]["mixture"] = sum(
-             [data_["audio"][s] for s in self.stems_no_mixture]
-         )
-
-         return data_
-
-     def __len__(self) -> int:
-         return super().__len__()
-
-
- if __name__ == "__main__":
-
-     from pprint import pprint
-     from tqdm.auto import tqdm
-
-     for split_ in ["train", "val", "test"]:
-         ds = DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
-             data_root="$DATA_ROOT/DnR/v2np",
-             split=split_,
-             target_length=100,
-             chunk_size_second=6.0,
-         )
-
-         print(split_, len(ds))
-
-         for track_ in tqdm(ds):  # type: ignore
-             pprint(track_)
-             track_["audio"] = {k: v.shape for k, v in track_["audio"].items()}
-             pprint(track_)
-             # break
-
-         break

+ import os
+ from abc import ABC
+ from typing import Any, Dict, List, Optional
+
+ import numpy as np
+ import pedalboard as pb
+ import torch
+ import torchaudio as ta
+ from torch.utils import data
+
+ from .._types import AudioDict, DataDict
+ from ..base import BaseSourceSeparationDataset
+
+
+ class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC):
+     ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"]
+     STEM_NAME_MAP = {
+         "mixture": "mix",
+         "speech": "speech",
+         "music": "music",
+         "effects": "sfx",
+     }
+     SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"}
+
+     FULL_TRACK_LENGTH_SECOND = 60
+     FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100
+
+     def __init__(
+         self,
+         split: str,
+         stems: List[str],
+         files: List[str],
+         data_path: str,
+         fs: int = 44100,
+         npy_memmap: bool = True,
+         recompute_mixture: bool = False,
+     ) -> None:
+         super().__init__(
+             split=split,
+             stems=stems,
+             files=files,
+             data_path=data_path,
+             fs=fs,
+             npy_memmap=npy_memmap,
+             recompute_mixture=recompute_mixture,
+         )
+
+     def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
+
+         if stem == "mne":
+             return self.get_stem(stem="music", identifier=identifier) + self.get_stem(
+                 stem="effects", identifier=identifier
+             )
+
+         track = identifier["track"]
+         path = os.path.join(self.data_path, track)
+
+         if self.npy_memmap:
+             audio = np.load(
+                 os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"), mmap_mode="r"
+             )
+         else:
+             audio, _ = ta.load(os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav"))
+
+         return audio
+
+     def get_identifier(self, index):
+         return dict(track=self.files[index])
+
+     def __getitem__(self, index: int) -> DataDict:
+         identifier = self.get_identifier(index)
+         audio = self.get_audio(identifier)
+
+         return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
+
+
+ class DivideAndRemasterDataset(DivideAndRemasterBaseDataset):
+     def __init__(
+         self,
+         data_root: str,
+         split: str,
+         stems: Optional[List[str]] = None,
+         fs: int = 44100,
+         npy_memmap: bool = True,
+     ) -> None:
+
+         if stems is None:
+             stems = self.ALLOWED_STEMS
+         self.stems = stems
+
+         data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
+
+         files = sorted(os.listdir(data_path))
+         files = [
+             f
+             for f in files
+             if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
+         ]
+         if split == "train":
+             assert len(files) == 3406, len(files)
+         elif split == "val":
+             assert len(files) == 487, len(files)
+         elif split == "test":
+             assert len(files) == 973, len(files)
+
+         self.n_tracks = len(files)
+
+         super().__init__(
+             data_path=data_path,
+             split=split,
+             stems=stems,
+             files=files,
+             fs=fs,
+             npy_memmap=npy_memmap,
+         )
+
+     def __len__(self) -> int:
+         return self.n_tracks
+
+
+ class DivideAndRemasterRandomChunkDataset(DivideAndRemasterBaseDataset):
+     def __init__(
+         self,
+         data_root: str,
+         split: str,
+         target_length: int,
+         chunk_size_second: float,
+         stems: Optional[List[str]] = None,
+         fs: int = 44100,
+         npy_memmap: bool = True,
+     ) -> None:
+
+         if stems is None:
+             stems = self.ALLOWED_STEMS
+         self.stems = stems
+
+         data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
+
+         files = sorted(os.listdir(data_path))
+         files = [
+             f
+             for f in files
+             if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
+         ]
+
+         if split == "train":
+             assert len(files) == 3406, len(files)
+         elif split == "val":
+             assert len(files) == 487, len(files)
+         elif split == "test":
+             assert len(files) == 973, len(files)
+
+         self.n_tracks = len(files)
+
+         self.target_length = target_length
+         self.chunk_size = int(chunk_size_second * fs)
+
+         super().__init__(
+             data_path=data_path,
+             split=split,
+             stems=stems,
+             files=files,
+             fs=fs,
+             npy_memmap=npy_memmap,
+         )
+
+     def __len__(self) -> int:
+         return self.target_length
+
+     def get_identifier(self, index):
+         return super().get_identifier(index % self.n_tracks)
+
+     def get_stem(
+         self,
+         *,
+         stem: str,
+         identifier: Dict[str, Any],
+         chunk_here: bool = False,
+     ) -> torch.Tensor:
+
+         stem = super().get_stem(stem=stem, identifier=identifier)
+
+         if chunk_here:
+             start = np.random.randint(
+                 0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
+             )
+             end = start + self.chunk_size
+
+             stem = stem[:, start:end]
+
191
+ return stem
192
+
193
+ def __getitem__(self, index: int) -> DataDict:
194
+ identifier = self.get_identifier(index)
195
+ audio = self.get_audio(identifier)
196
+
197
+ start = np.random.randint(0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size)
198
+ end = start + self.chunk_size
199
+
200
+ audio = {k: v[:, start:end] for k, v in audio.items()}
201
+
202
+ return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
203
+
204
+
205
+ class DivideAndRemasterDeterministicChunkDataset(DivideAndRemasterBaseDataset):
206
+ def __init__(
207
+ self,
208
+ data_root: str,
209
+ split: str,
210
+ chunk_size_second: float,
211
+ hop_size_second: float,
212
+ stems: Optional[List[str]] = None,
213
+ fs: int = 44100,
214
+ npy_memmap: bool = True,
215
+ ) -> None:
216
+
217
+ if stems is None:
218
+ stems = self.ALLOWED_STEMS
219
+ self.stems = stems
220
+
221
+ data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
222
+
223
+ files = sorted(os.listdir(data_path))
224
+ files = [
225
+ f
226
+ for f in files
227
+ if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
228
+ ]
229
+ if split == "train":
230
+ assert len(files) == 3406, len(files)
231
+ elif split == "val":
232
+ assert len(files) == 487, len(files)
233
+ elif split == "test":
234
+ assert len(files) == 973, len(files)
235
+
236
+ self.n_tracks = len(files)
237
+
238
+ self.chunk_size = int(chunk_size_second * fs)
239
+ self.hop_size = int(hop_size_second * fs)
240
+ self.n_chunks_per_track = int(
241
+ (self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second
242
+ )
243
+
244
+ self.length = self.n_tracks * self.n_chunks_per_track
245
+
246
+ super().__init__(
247
+ data_path=data_path,
248
+ split=split,
249
+ stems=stems,
250
+ files=files,
251
+ fs=fs,
252
+ npy_memmap=npy_memmap,
253
+ )
254
+
255
+ def get_identifier(self, index):
256
+ return super().get_identifier(index % self.n_tracks)
257
+
258
+ def __len__(self) -> int:
259
+ return self.length
260
+
261
+ def __getitem__(self, item: int) -> DataDict:
262
+
263
+ index = item % self.n_tracks
264
+ chunk = item // self.n_tracks
265
+
266
+ data_ = super().__getitem__(index)
267
+
268
+ audio = data_["audio"]
269
+
270
+ start = chunk * self.hop_size
271
+ end = start + self.chunk_size
272
+
273
+ for stem in self.stems:
274
+ data_["audio"][stem] = audio[stem][:, start:end]
275
+
276
+ return data_
277
+
278
+
279
+ class DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
280
+ DivideAndRemasterRandomChunkDataset
281
+ ):
282
+ def __init__(
283
+ self,
284
+ data_root: str,
285
+ split: str,
286
+ target_length: int,
287
+ chunk_size_second: float,
288
+ stems: Optional[List[str]] = None,
289
+ fs: int = 44100,
290
+ npy_memmap: bool = True,
291
+ ) -> None:
292
+
293
+ if stems is None:
294
+ stems = self.ALLOWED_STEMS
295
+
296
+ stems_no_mixture = [s for s in stems if s != "mixture"]
297
+
298
+ super().__init__(
299
+ data_root=data_root,
300
+ split=split,
301
+ target_length=target_length,
302
+ chunk_size_second=chunk_size_second,
303
+ stems=stems_no_mixture,
304
+ fs=fs,
305
+ npy_memmap=npy_memmap,
306
+ )
307
+
308
+ self.stems = stems
309
+ self.stems_no_mixture = stems_no_mixture
310
+
311
+ def __getitem__(self, index: int) -> DataDict:
312
+
313
+ data_ = super().__getitem__(index)
314
+
315
+ dry = data_["audio"]["speech"][:]
316
+ n_samples = dry.shape[-1]
317
+
318
+ wet_level = np.random.rand()
319
+
320
+ speech = pb.Reverb(
321
+ room_size=np.random.rand(),
322
+ damping=np.random.rand(),
323
+ wet_level=wet_level,
324
+ dry_level=(1 - wet_level),
325
+ width=np.random.rand(),
326
+ ).process(dry, self.fs, buffer_size=8192 * 4)[..., :n_samples]
327
+
328
+ data_["audio"]["speech"] = speech
329
+
330
+ data_["audio"]["mixture"] = sum(
331
+ [data_["audio"][s] for s in self.stems_no_mixture]
332
+ )
333
+
334
+ return data_
335
+
336
+ def __len__(self) -> int:
337
+ return super().__len__()
338
+
339
+
340
+ if __name__ == "__main__":
341
+
342
+ from pprint import pprint
343
+ from tqdm.auto import tqdm
344
+
345
+ for split_ in ["train", "val", "test"]:
346
+ ds = DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
347
+ data_root="$DATA_ROOT/DnR/v2np",
348
+ split=split_,
349
+ target_length=100,
350
+ chunk_size_second=6.0,
351
+ )
352
+
353
+ print(split_, len(ds))
354
+
355
+ for track_ in tqdm(ds): # type: ignore
356
+ pprint(track_)
357
+ track_["audio"] = {k: v.shape for k, v in track_["audio"].items()}
358
+ pprint(track_)
359
+
360
+ break
 
 
 
 
 
 
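For orientation only (not part of the commit): a minimal sketch of driving the chunked DnR dataset above with a PyTorch DataLoader. The data_root path, batch size, and target_length here are placeholder assumptions.

from torch.utils.data import DataLoader

from mvsepless.models.bandit.core.data.dnr.dataset import (
    DivideAndRemasterRandomChunkDataset,
)

# target_length sets the virtual epoch length; chunks are sampled at random.
ds = DivideAndRemasterRandomChunkDataset(
    data_root="/data/DnR/v2np",  # placeholder path to the .npy-preprocessed DnR
    split="train",
    target_length=1000,
    chunk_size_second=6.0,
)
loader = DataLoader(ds, batch_size=2, num_workers=4)
batch = next(iter(loader))
print(batch["audio"]["mixture"].shape)  # (2, channels, 6.0 * 44100)
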
mvsepless/models/bandit/core/data/dnr/preprocess.py CHANGED
@@ -1,51 +1,51 @@
- import glob
- import os
- from typing import Tuple
-
- import numpy as np
- import torchaudio as ta
- from tqdm.contrib.concurrent import process_map
-
-
- def process_one(inputs: Tuple[str, str, int]) -> None:
-     infile, outfile, target_fs = inputs
-
-     dir = os.path.dirname(outfile)
-     os.makedirs(dir, exist_ok=True)
-
-     data, fs = ta.load(infile)
-
-     if fs != target_fs:
-         data = ta.functional.resample(
-             data, fs, target_fs, resampling_method="sinc_interp_kaiser"
-         )
-         fs = target_fs
-
-     data = data.numpy()
-     data = data.astype(np.float32)
-
-     if os.path.exists(outfile):
-         data_ = np.load(outfile)
-         if np.allclose(data, data_):
-             return
-
-     np.save(outfile, data)
-
-
- def preprocess(data_path: str, output_path: str, fs: int) -> None:
-     files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
-     print(files)
-     outfiles = [
-         f.replace(data_path, output_path).replace(".wav", ".npy") for f in files
-     ]
-
-     os.makedirs(output_path, exist_ok=True)
-     inputs = list(zip(files, outfiles, [fs] * len(files)))
-
-     process_map(process_one, inputs, chunksize=32)
-
-
- if __name__ == "__main__":
-     import fire
-
-     fire.Fire()

+ import glob
+ import os
+ from typing import Tuple
+
+ import numpy as np
+ import torchaudio as ta
+ from tqdm.contrib.concurrent import process_map
+
+
+ def process_one(inputs: Tuple[str, str, int]) -> None:
+     infile, outfile, target_fs = inputs
+
+     dir = os.path.dirname(outfile)
+     os.makedirs(dir, exist_ok=True)
+
+     data, fs = ta.load(infile)
+
+     if fs != target_fs:
+         data = ta.functional.resample(
+             data, fs, target_fs, resampling_method="sinc_interp_kaiser"
+         )
+         fs = target_fs
+
+     data = data.numpy()
+     data = data.astype(np.float32)
+
+     if os.path.exists(outfile):
+         data_ = np.load(outfile)
+         if np.allclose(data, data_):
+             return
+
+     np.save(outfile, data)
+
+
+ def preprocess(data_path: str, output_path: str, fs: int) -> None:
+     files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
+     print(files)
+     outfiles = [
+         f.replace(data_path, output_path).replace(".wav", ".npy") for f in files
+     ]
+
+     os.makedirs(output_path, exist_ok=True)
+     inputs = list(zip(files, outfiles, [fs] * len(files)))
+
+     process_map(process_one, inputs, chunksize=32)
+
+
+ if __name__ == "__main__":
+     import fire
+
+     fire.Fire()
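A sketch of invoking this preprocessing step directly as a library call, equivalent to the fire CLI above; both paths are placeholders, not part of the commit.

from mvsepless.models.bandit.core.data.dnr.preprocess import preprocess

# Resamples every WAV under data_path to fs and mirrors it as .npy under output_path.
preprocess(data_path="/data/DnR/v2", output_path="/data/DnR/v2np", fs=44100)
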
mvsepless/models/bandit/core/data/musdb/datamodule.py CHANGED
@@ -1,75 +1,75 @@
- import os.path
- from typing import Mapping, Optional
-
- import pytorch_lightning as pl
-
- from .dataset import (
-     MUSDB18BaseDataset,
-     MUSDB18FullTrackDataset,
-     MUSDB18SadDataset,
-     MUSDB18SadOnTheFlyAugmentedDataset,
- )
-
-
- def MUSDB18DataModule(
-     data_root: str = "$DATA_ROOT/MUSDB18/HQ",
-     target_stem: str = "vocals",
-     batch_size: int = 2,
-     num_workers: int = 8,
-     train_kwargs: Optional[Mapping] = None,
-     val_kwargs: Optional[Mapping] = None,
-     test_kwargs: Optional[Mapping] = None,
-     datamodule_kwargs: Optional[Mapping] = None,
-     use_on_the_fly: bool = True,
-     npy_memmap: bool = True,
- ) -> pl.LightningDataModule:
-     if train_kwargs is None:
-         train_kwargs = {}
-
-     if val_kwargs is None:
-         val_kwargs = {}
-
-     if test_kwargs is None:
-         test_kwargs = {}
-
-     if datamodule_kwargs is None:
-         datamodule_kwargs = {}
-
-     train_dataset: MUSDB18BaseDataset
-
-     if use_on_the_fly:
-         train_dataset = MUSDB18SadOnTheFlyAugmentedDataset(
-             data_root=os.path.join(data_root, "saded-np"),
-             split="train",
-             target_stem=target_stem,
-             **train_kwargs
-         )
-     else:
-         train_dataset = MUSDB18SadDataset(
-             data_root=os.path.join(data_root, "saded-np"),
-             split="train",
-             target_stem=target_stem,
-             **train_kwargs
-         )
-
-     datamodule = pl.LightningDataModule.from_datasets(
-         train_dataset=train_dataset,
-         val_dataset=MUSDB18SadDataset(
-             data_root=os.path.join(data_root, "saded-np"),
-             split="val",
-             target_stem=target_stem,
-             **val_kwargs
-         ),
-         test_dataset=MUSDB18FullTrackDataset(
-             data_root=os.path.join(data_root, "canonical"), split="test", **test_kwargs
-         ),
-         batch_size=batch_size,
-         num_workers=num_workers,
-         **datamodule_kwargs
-     )
-
-     datamodule.predict_dataloader = (  # type: ignore[method-assign]
-         datamodule.test_dataloader
-     )
-
-     return datamodule

+ import os.path
+ from typing import Mapping, Optional
+
+ import pytorch_lightning as pl
+
+ from .dataset import (
+     MUSDB18BaseDataset,
+     MUSDB18FullTrackDataset,
+     MUSDB18SadDataset,
+     MUSDB18SadOnTheFlyAugmentedDataset,
+ )
+
+
+ def MUSDB18DataModule(
+     data_root: str = "$DATA_ROOT/MUSDB18/HQ",
+     target_stem: str = "vocals",
+     batch_size: int = 2,
+     num_workers: int = 8,
+     train_kwargs: Optional[Mapping] = None,
+     val_kwargs: Optional[Mapping] = None,
+     test_kwargs: Optional[Mapping] = None,
+     datamodule_kwargs: Optional[Mapping] = None,
+     use_on_the_fly: bool = True,
+     npy_memmap: bool = True,
+ ) -> pl.LightningDataModule:
+     if train_kwargs is None:
+         train_kwargs = {}
+
+     if val_kwargs is None:
+         val_kwargs = {}
+
+     if test_kwargs is None:
+         test_kwargs = {}
+
+     if datamodule_kwargs is None:
+         datamodule_kwargs = {}
+
+     train_dataset: MUSDB18BaseDataset
+
+     if use_on_the_fly:
+         train_dataset = MUSDB18SadOnTheFlyAugmentedDataset(
+             data_root=os.path.join(data_root, "saded-np"),
+             split="train",
+             target_stem=target_stem,
+             **train_kwargs,
+         )
+     else:
+         train_dataset = MUSDB18SadDataset(
+             data_root=os.path.join(data_root, "saded-np"),
+             split="train",
+             target_stem=target_stem,
+             **train_kwargs,
+         )
+
+     datamodule = pl.LightningDataModule.from_datasets(
+         train_dataset=train_dataset,
+         val_dataset=MUSDB18SadDataset(
+             data_root=os.path.join(data_root, "saded-np"),
+             split="val",
+             target_stem=target_stem,
+             **val_kwargs,
+         ),
+         test_dataset=MUSDB18FullTrackDataset(
+             data_root=os.path.join(data_root, "canonical"), split="test", **test_kwargs
+         ),
+         batch_size=batch_size,
+         num_workers=num_workers,
+         **datamodule_kwargs,
+     )
+
+     datamodule.predict_dataloader = (  # type: ignore[method-assign]
+         datamodule.test_dataloader
+     )
+
+     return datamodule
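A minimal sketch of instantiating this datamodule, assuming MUSDB18-HQ has already been preprocessed into data_root/saded-np and data_root/canonical; the path is a placeholder.

from mvsepless.models.bandit.core.data.musdb.datamodule import MUSDB18DataModule

dm = MUSDB18DataModule(
    data_root="/data/MUSDB18/HQ",  # placeholder path
    target_stem="vocals",
    batch_size=2,
    num_workers=8,
)
train_loader = dm.train_dataloader()
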
mvsepless/models/bandit/core/data/musdb/dataset.py CHANGED
@@ -1,273 +1,241 @@
- import os
- from abc import ABC
- from typing import List, Optional, Tuple
-
- import numpy as np
- import torch
- import torchaudio as ta
- from torch.utils import data
-
- from .._types import AudioDict, DataDict
- from ..base import BaseSourceSeparationDataset
-
-
- class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC):
-
-     ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"]
-
-     def __init__(
-         self,
-         split: str,
-         stems: List[str],
-         files: List[str],
-         data_path: str,
-         fs: int = 44100,
-         npy_memmap=False,
-     ) -> None:
-         super().__init__(
-             split=split,
-             stems=stems,
-             files=files,
-             data_path=data_path,
-             fs=fs,
-             npy_memmap=npy_memmap,
-             recompute_mixture=False,
-         )
-
-     def get_stem(self, *, stem: str, identifier) -> torch.Tensor:
-         track = identifier["track"]
-         path = os.path.join(self.data_path, track)
-         # noinspection PyUnresolvedReferences
-
-         if self.npy_memmap:
-             audio = np.load(os.path.join(path, f"{stem}.wav.npy"), mmap_mode="r")
-         else:
-             audio, _ = ta.load(os.path.join(path, f"{stem}.wav"))
-
-         return audio
-
-     def get_identifier(self, index):
-         return dict(track=self.files[index])
-
-     def __getitem__(self, index: int) -> DataDict:
-         identifier = self.get_identifier(index)
-         audio = self.get_audio(identifier)
-
-         return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
-
-
- class MUSDB18FullTrackDataset(MUSDB18BaseDataset):
-
-     N_TRAIN_TRACKS = 100
-     N_TEST_TRACKS = 50
-     VALIDATION_FILES = [
-         "Actions - One Minute Smile",
-         "Clara Berry And Wooldog - Waltz For My Victims",
-         "Johnny Lokke - Promises & Lies",
-         "Patrick Talbot - A Reason To Leave",
-         "Triviul - Angelsaint",
-         "Alexander Ross - Goodbye Bolero",
-         "Fergessen - Nos Palpitants",
-         "Leaf - Summerghost",
-         "Skelpolu - Human Mistakes",
-         "Young Griffo - Pennies",
-         "ANiMAL - Rockshow",
-         "James May - On The Line",
-         "Meaxic - Take A Step",
-         "Traffic Experiment - Sirens",
-     ]
-
-     def __init__(
-         self, data_root: str, split: str, stems: Optional[List[str]] = None
-     ) -> None:
-
-         if stems is None:
-             stems = self.ALLOWED_STEMS
-         self.stems = stems
-
-         if split == "test":
-             subset = "test"
-         elif split in ["train", "val"]:
-             subset = "train"
-         else:
-             raise NameError
-
-         data_path = os.path.join(data_root, subset)
-
-         files = sorted(os.listdir(data_path))
-         files = [f for f in files if not f.startswith(".")]
-         # pprint(list(enumerate(files)))
-         if subset == "train":
-             assert len(files) == 100, len(files)
-             if split == "train":
-                 files = [f for f in files if f not in self.VALIDATION_FILES]
-                 assert len(files) == 100 - len(self.VALIDATION_FILES)
-             else:
-                 files = [f for f in files if f in self.VALIDATION_FILES]
-                 assert len(files) == len(self.VALIDATION_FILES)
-         else:
-             split = "test"
-             assert len(files) == 50
-
-         self.n_tracks = len(files)
-
-         super().__init__(data_path=data_path, split=split, stems=stems, files=files)
-
-     def __len__(self) -> int:
-         return self.n_tracks
-
-
- class MUSDB18SadDataset(MUSDB18BaseDataset):
-     def __init__(
-         self,
-         data_root: str,
-         split: str,
-         target_stem: str,
-         stems: Optional[List[str]] = None,
-         target_length: Optional[int] = None,
-         npy_memmap=False,
-     ) -> None:
-
-         if stems is None:
-             stems = self.ALLOWED_STEMS
-
-         data_path = os.path.join(data_root, target_stem, split)
-
-         files = sorted(os.listdir(data_path))
-         files = [f for f in files if not f.startswith(".")]
-
-         super().__init__(
-             data_path=data_path,
-             split=split,
-             stems=stems,
-             files=files,
-             npy_memmap=npy_memmap,
-         )
-         self.n_segments = len(files)
-         self.target_stem = target_stem
-         self.target_length = (
-             target_length if target_length is not None else self.n_segments
-         )
-
-     def __len__(self) -> int:
-         return self.target_length
-
-     def __getitem__(self, index: int) -> DataDict:
-
-         index = index % self.n_segments
-
-         return super().__getitem__(index)
-
-     def get_identifier(self, index):
-         return super().get_identifier(index % self.n_segments)
-
-
- class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset):
-     def __init__(
-         self,
-         data_root: str,
-         split: str,
-         target_stem: str,
-         stems: Optional[List[str]] = None,
-         target_length: int = 20000,
-         apply_probability: Optional[float] = None,
-         chunk_size_second: float = 3.0,
-         random_scale_range_db: Tuple[float, float] = (-10, 10),
-         drop_probability: float = 0.1,
-         rescale: bool = True,
-     ) -> None:
-         super().__init__(data_root, split, target_stem, stems)
-
-         if apply_probability is None:
-             apply_probability = (target_length - self.n_segments) / target_length
-
-         self.apply_probability = apply_probability
-         self.drop_probability = drop_probability
-         self.chunk_size_second = chunk_size_second
-         self.random_scale_range_db = random_scale_range_db
-         self.rescale = rescale
-
-         self.chunk_size_sample = int(self.chunk_size_second * self.fs)
-         self.target_length = target_length
-
-     def __len__(self) -> int:
-         return self.target_length
-
-     def __getitem__(self, index: int) -> DataDict:
-
-         index = index % self.n_segments
-
-         # if np.random.rand() > self.apply_probability:
-         #     return super().__getitem__(index)
-
-         audio = {}
-         identifier = self.get_identifier(index)
-
-         # assert self.target_stem in self.stems_no_mixture
-         for stem in self.stems_no_mixture:
-             if stem == self.target_stem:
-                 identifier_ = identifier
-             else:
-                 if np.random.rand() < self.apply_probability:
-                     index_ = np.random.randint(self.n_segments)
-                     identifier_ = self.get_identifier(index_)
-                 else:
-                     identifier_ = identifier
-
-             audio[stem] = self.get_stem(stem=stem, identifier=identifier_)
-
-             # if stem == self.target_stem:
-
-             if self.chunk_size_sample < audio[stem].shape[-1]:
-                 chunk_start = np.random.randint(
-                     audio[stem].shape[-1] - self.chunk_size_sample
-                 )
-             else:
-                 chunk_start = 0
-
-             if np.random.rand() < self.drop_probability:
-                 # db_scale = "-inf"
-                 linear_scale = 0.0
-             else:
-                 db_scale = np.random.uniform(*self.random_scale_range_db)
-                 linear_scale = np.power(10, db_scale / 20)
-                 # db_scale = f"{db_scale:+2.1f}"
-             # print(linear_scale)
-             audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample] = (
-                 linear_scale
-                 * audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample]
-             )
-
-         audio["mixture"] = self.compute_mixture(audio)
-
-         if self.rescale:
-             max_abs_val = max(
-                 [torch.max(torch.abs(audio[stem])) for stem in self.stems]
-             )  # type: ignore[type-var]
-             if max_abs_val > 1:
-                 audio = {k: v / max_abs_val for k, v in audio.items()}
-
-         track = identifier["track"]
-
-         return {"audio": audio, "track": f"{self.split}/{track}"}
-
-
- # if __name__ == "__main__":
- #
- #     from pprint import pprint
- #     from tqdm.auto import tqdm
- #
- #     for split_ in ["train", "val", "test"]:
- #         ds = MUSDB18SadOnTheFlyAugmentedDataset(
- #             data_root="$DATA_ROOT/MUSDB18/HQ/saded",
- #             split=split_,
- #             target_stem="vocals"
- #         )
- #
- #         print(split_, len(ds))
- #
- #         for track_ in tqdm(ds):
- #             track_["audio"] = {
- #                 k: v.shape for k, v in track_["audio"].items()
- #             }
- #             pprint(track_)

+ import os
+ from abc import ABC
+ from typing import List, Optional, Tuple
+
+ import numpy as np
+ import torch
+ import torchaudio as ta
+ from torch.utils import data
+
+ from .._types import AudioDict, DataDict
+ from ..base import BaseSourceSeparationDataset
+
+
+ class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC):
+
+     ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"]
+
+     def __init__(
+         self,
+         split: str,
+         stems: List[str],
+         files: List[str],
+         data_path: str,
+         fs: int = 44100,
+         npy_memmap=False,
+     ) -> None:
+         super().__init__(
+             split=split,
+             stems=stems,
+             files=files,
+             data_path=data_path,
+             fs=fs,
+             npy_memmap=npy_memmap,
+             recompute_mixture=False,
+         )
+
+     def get_stem(self, *, stem: str, identifier) -> torch.Tensor:
+         track = identifier["track"]
+         path = os.path.join(self.data_path, track)
+
+         if self.npy_memmap:
+             audio = np.load(os.path.join(path, f"{stem}.wav.npy"), mmap_mode="r")
+         else:
+             audio, _ = ta.load(os.path.join(path, f"{stem}.wav"))
+
+         return audio
+
+     def get_identifier(self, index):
+         return dict(track=self.files[index])
+
+     def __getitem__(self, index: int) -> DataDict:
+         identifier = self.get_identifier(index)
+         audio = self.get_audio(identifier)
+
+         return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
+
+
+ class MUSDB18FullTrackDataset(MUSDB18BaseDataset):
+
+     N_TRAIN_TRACKS = 100
+     N_TEST_TRACKS = 50
+     VALIDATION_FILES = [
+         "Actions - One Minute Smile",
+         "Clara Berry And Wooldog - Waltz For My Victims",
+         "Johnny Lokke - Promises & Lies",
+         "Patrick Talbot - A Reason To Leave",
+         "Triviul - Angelsaint",
+         "Alexander Ross - Goodbye Bolero",
+         "Fergessen - Nos Palpitants",
+         "Leaf - Summerghost",
+         "Skelpolu - Human Mistakes",
+         "Young Griffo - Pennies",
+         "ANiMAL - Rockshow",
+         "James May - On The Line",
+         "Meaxic - Take A Step",
+         "Traffic Experiment - Sirens",
+     ]
+
+     def __init__(
+         self, data_root: str, split: str, stems: Optional[List[str]] = None
+     ) -> None:
+
+         if stems is None:
+             stems = self.ALLOWED_STEMS
+         self.stems = stems
+
+         if split == "test":
+             subset = "test"
+         elif split in ["train", "val"]:
+             subset = "train"
+         else:
+             raise NameError
+
+         data_path = os.path.join(data_root, subset)
+
+         files = sorted(os.listdir(data_path))
+         files = [f for f in files if not f.startswith(".")]
+         if subset == "train":
+             assert len(files) == 100, len(files)
+             if split == "train":
+                 files = [f for f in files if f not in self.VALIDATION_FILES]
+                 assert len(files) == 100 - len(self.VALIDATION_FILES)
+             else:
+                 files = [f for f in files if f in self.VALIDATION_FILES]
+                 assert len(files) == len(self.VALIDATION_FILES)
+         else:
+             split = "test"
+             assert len(files) == 50
+
+         self.n_tracks = len(files)
+
+         super().__init__(data_path=data_path, split=split, stems=stems, files=files)
+
+     def __len__(self) -> int:
+         return self.n_tracks
+
+
+ class MUSDB18SadDataset(MUSDB18BaseDataset):
+     def __init__(
+         self,
+         data_root: str,
+         split: str,
+         target_stem: str,
+         stems: Optional[List[str]] = None,
+         target_length: Optional[int] = None,
+         npy_memmap=False,
+     ) -> None:
+
+         if stems is None:
+             stems = self.ALLOWED_STEMS
+
+         data_path = os.path.join(data_root, target_stem, split)
+
+         files = sorted(os.listdir(data_path))
+         files = [f for f in files if not f.startswith(".")]
+
+         super().__init__(
+             data_path=data_path,
+             split=split,
+             stems=stems,
+             files=files,
+             npy_memmap=npy_memmap,
+         )
+         self.n_segments = len(files)
+         self.target_stem = target_stem
+         self.target_length = (
+             target_length if target_length is not None else self.n_segments
+         )
+
+     def __len__(self) -> int:
+         return self.target_length
+
+     def __getitem__(self, index: int) -> DataDict:
+
+         index = index % self.n_segments
+
+         return super().__getitem__(index)
+
+     def get_identifier(self, index):
+         return super().get_identifier(index % self.n_segments)
+
+
+ class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset):
+     def __init__(
+         self,
+         data_root: str,
+         split: str,
+         target_stem: str,
+         stems: Optional[List[str]] = None,
+         target_length: int = 20000,
+         apply_probability: Optional[float] = None,
+         chunk_size_second: float = 3.0,
+         random_scale_range_db: Tuple[float, float] = (-10, 10),
+         drop_probability: float = 0.1,
+         rescale: bool = True,
+     ) -> None:
+         super().__init__(data_root, split, target_stem, stems)
+
+         if apply_probability is None:
+             apply_probability = (target_length - self.n_segments) / target_length
+
+         self.apply_probability = apply_probability
+         self.drop_probability = drop_probability
+         self.chunk_size_second = chunk_size_second
+         self.random_scale_range_db = random_scale_range_db
+         self.rescale = rescale
+
+         self.chunk_size_sample = int(self.chunk_size_second * self.fs)
+         self.target_length = target_length
+
+     def __len__(self) -> int:
+         return self.target_length
+
+     def __getitem__(self, index: int) -> DataDict:
+
+         index = index % self.n_segments
+
+         audio = {}
+         identifier = self.get_identifier(index)
+
+         for stem in self.stems_no_mixture:
+             if stem == self.target_stem:
+                 identifier_ = identifier
+             else:
+                 if np.random.rand() < self.apply_probability:
+                     index_ = np.random.randint(self.n_segments)
+                     identifier_ = self.get_identifier(index_)
+                 else:
+                     identifier_ = identifier
+
+             audio[stem] = self.get_stem(stem=stem, identifier=identifier_)
+
+             if self.chunk_size_sample < audio[stem].shape[-1]:
+                 chunk_start = np.random.randint(
+                     audio[stem].shape[-1] - self.chunk_size_sample
+                 )
+             else:
+                 chunk_start = 0
+
+             if np.random.rand() < self.drop_probability:
+                 linear_scale = 0.0
+             else:
+                 db_scale = np.random.uniform(*self.random_scale_range_db)
+                 linear_scale = np.power(10, db_scale / 20)
+             audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample] = (
+                 linear_scale
+                 * audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample]
+             )
+
+         audio["mixture"] = self.compute_mixture(audio)
+
+         if self.rescale:
+             max_abs_val = max(
+                 [torch.max(torch.abs(audio[stem])) for stem in self.stems]
+             )  # type: ignore[type-var]
+             if max_abs_val > 1:
+                 audio = {k: v / max_abs_val for k, v in audio.items()}
+
+         track = identifier["track"]
+
+         return {"audio": audio, "track": f"{self.split}/{track}"}

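A self-contained check of the dB-to-linear-amplitude mapping the on-the-fly augmentation uses above (linear gain = 10^(dB/20); the stem-drop branch corresponds to a gain of 0.0, i.e. minus-infinity dB). The sample values are illustrative only.

import numpy as np

for db in (-10.0, 0.0, 10.0):
    # -10 dB ~= 0.316, 0 dB = 1.0, +10 dB ~= 3.162
    print(db, np.power(10, db / 20))
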
mvsepless/models/bandit/core/data/musdb/preprocess.py CHANGED
@@ -1,226 +1,223 @@
- import glob
- import os
-
- import numpy as np
- import torch
- import torchaudio as ta
- from torch import nn
- from torch.nn import functional as F
- from tqdm.contrib.concurrent import process_map
-
- from .._types import DataDict
- from .dataset import MUSDB18FullTrackDataset
- import pyloudnorm as pyln
-
-
- class SourceActivityDetector(nn.Module):
-     def __init__(
-         self,
-         analysis_stem: str,
-         output_path: str,
-         fs: int = 44100,
-         segment_length_second: float = 6.0,
-         hop_length_second: float = 3.0,
-         n_chunks: int = 10,
-         chunk_epsilon: float = 1e-5,
-         energy_threshold_quantile: float = 0.15,
-         segment_epsilon: float = 1e-3,
-         salient_proportion_threshold: float = 0.5,
-         target_lufs: float = -24,
-     ) -> None:
-         super().__init__()
-
-         self.fs = fs
-         self.segment_length = int(segment_length_second * self.fs)
-         self.hop_length = int(hop_length_second * self.fs)
-         self.n_chunks = n_chunks
-         assert self.segment_length % self.n_chunks == 0
-         self.chunk_size = self.segment_length // self.n_chunks
-         self.chunk_epsilon = chunk_epsilon
-         self.energy_threshold_quantile = energy_threshold_quantile
-         self.segment_epsilon = segment_epsilon
-         self.salient_proportion_threshold = salient_proportion_threshold
-         self.analysis_stem = analysis_stem
-
-         self.meter = pyln.Meter(self.fs)
-         self.target_lufs = target_lufs
-
-         self.output_path = output_path
-
-     def forward(self, data: DataDict) -> None:
-
-         stem_ = self.analysis_stem if (self.analysis_stem != "none") else "mixture"
-
-         x = data["audio"][stem_]
-
-         xnp = x.numpy()
-         loudness = self.meter.integrated_loudness(xnp.T)
-
-         for stem in data["audio"]:
-             s = data["audio"][stem]
-             s = pyln.normalize.loudness(s.numpy().T, loudness, self.target_lufs).T
-             s = torch.as_tensor(s)
-             data["audio"][stem] = s
-
-         if x.ndim == 3:
-             assert x.shape[0] == 1
-             x = x[0]
-
-         n_chan, n_samples = x.shape
-
-         n_segments = (
-             int(np.ceil((n_samples - self.segment_length) / self.hop_length)) + 1
-         )
-
-         segments = torch.zeros((n_segments, n_chan, self.segment_length))
-         for i in range(n_segments):
-             start = i * self.hop_length
-             end = start + self.segment_length
-             end = min(end, n_samples)
-
-             xseg = x[:, start:end]
-
-             if end - start < self.segment_length:
-                 xseg = F.pad(
-                     xseg, pad=(0, self.segment_length - (end - start)), value=torch.nan
-                 )
-
-             segments[i, :, :] = xseg
-
-         chunks = segments.reshape((n_segments, n_chan, self.n_chunks, self.chunk_size))
-
-         if self.analysis_stem != "none":
-             chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3))
-             chunk_energies = torch.nan_to_num(chunk_energies, nan=0)
-             chunk_energies[chunk_energies == 0] = self.chunk_epsilon
-
-             energy_threshold = torch.nanquantile(
-                 chunk_energies, q=self.energy_threshold_quantile
-             )
-
-             if energy_threshold < self.segment_epsilon:
-                 energy_threshold = self.segment_epsilon  # type: ignore[assignment]
-
-             chunks_above_threshold = chunk_energies > energy_threshold
-             n_chunks_above_threshold = torch.mean(
-                 chunks_above_threshold.to(torch.float), dim=-1
-             )
-
-             segment_above_threshold = (
-                 n_chunks_above_threshold > self.salient_proportion_threshold
-             )
-
-             if torch.sum(segment_above_threshold) == 0:
-                 return
-
-         else:
-             segment_above_threshold = torch.ones((n_segments,))
-
-         for i in range(n_segments):
-             if not segment_above_threshold[i]:
-                 continue
-
-             outpath = os.path.join(
-                 self.output_path,
-                 self.analysis_stem,
-                 f"{data['track']} - {self.analysis_stem}{i:03d}",
-             )
-             os.makedirs(outpath, exist_ok=True)
-
-             for stem in data["audio"]:
-                 if stem == self.analysis_stem:
-                     segment = torch.nan_to_num(segments[i, :, :], nan=0)
-                 else:
-                     start = i * self.hop_length
-                     end = start + self.segment_length
-                     end = min(n_samples, end)
-
-                     segment = data["audio"][stem][:, start:end]
-
-                     if end - start < self.segment_length:
-                         segment = F.pad(
-                             segment, (0, self.segment_length - (end - start))
-                         )
-
-                 assert segment.shape[-1] == self.segment_length, segment.shape
-
-                 # ta.save(os.path.join(outpath, f"{stem}.wav"), segment, self.fs)
-
-                 np.save(os.path.join(outpath, f"{stem}.wav"), segment)
-
-
- def preprocess(
-     analysis_stem: str,
-     output_path: str = "/data/MUSDB18/HQ/saded-np",
-     fs: int = 44100,
-     segment_length_second: float = 6.0,
-     hop_length_second: float = 3.0,
-     n_chunks: int = 10,
-     chunk_epsilon: float = 1e-5,
-     energy_threshold_quantile: float = 0.15,
-     segment_epsilon: float = 1e-3,
-     salient_proportion_threshold: float = 0.5,
- ) -> None:
-
-     sad = SourceActivityDetector(
-         analysis_stem=analysis_stem,
-         output_path=output_path,
-         fs=fs,
-         segment_length_second=segment_length_second,
-         hop_length_second=hop_length_second,
-         n_chunks=n_chunks,
-         chunk_epsilon=chunk_epsilon,
-         energy_threshold_quantile=energy_threshold_quantile,
-         segment_epsilon=segment_epsilon,
-         salient_proportion_threshold=salient_proportion_threshold,
-     )
-
-     for split in ["train", "val", "test"]:
-         ds = MUSDB18FullTrackDataset(
-             data_root="/data/MUSDB18/HQ/canonical",
-             split=split,
-         )
-
-         tracks = []
-         for i, track in enumerate(tqdm(ds, total=len(ds))):
-             if i % 32 == 0 and tracks:
-                 process_map(sad, tracks, max_workers=8)
-                 tracks = []
-             tracks.append(track)
-         process_map(sad, tracks, max_workers=8)
-
-
- def loudness_norm_one(inputs):
-     infile, outfile, target_lufs = inputs
-
-     audio, fs = ta.load(infile)
-     audio = audio.mean(dim=0, keepdim=True).numpy().T
-
-     meter = pyln.Meter(fs)
-     loudness = meter.integrated_loudness(audio)
-     audio = pyln.normalize.loudness(audio, loudness, target_lufs)
-
-     os.makedirs(os.path.dirname(outfile), exist_ok=True)
-     np.save(outfile, audio.T)
-
-
- def loudness_norm(
-     data_path: str,
-     # output_path: str,
-     target_lufs=-17.0,
- ):
-     files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
-
-     outfiles = [f.replace(".wav", ".npy").replace("saded", "saded-np") for f in files]
-
-     files = [(f, o, target_lufs) for f, o in zip(files, outfiles)]
-
-     process_map(loudness_norm_one, files, chunksize=2)
-
-
- if __name__ == "__main__":
-
-     from tqdm.auto import tqdm
-     import fire
-
-     fire.Fire()

+ import glob
+ import os
+
+ import numpy as np
+ import torch
+ import torchaudio as ta
+ from torch import nn
+ from torch.nn import functional as F
+ from tqdm.contrib.concurrent import process_map
+
+ from .._types import DataDict
+ from .dataset import MUSDB18FullTrackDataset
+ import pyloudnorm as pyln
+
+
+ class SourceActivityDetector(nn.Module):
+     def __init__(
+         self,
+         analysis_stem: str,
+         output_path: str,
+         fs: int = 44100,
+         segment_length_second: float = 6.0,
+         hop_length_second: float = 3.0,
+         n_chunks: int = 10,
+         chunk_epsilon: float = 1e-5,
+         energy_threshold_quantile: float = 0.15,
+         segment_epsilon: float = 1e-3,
+         salient_proportion_threshold: float = 0.5,
+         target_lufs: float = -24,
+     ) -> None:
+         super().__init__()
+
+         self.fs = fs
+         self.segment_length = int(segment_length_second * self.fs)
+         self.hop_length = int(hop_length_second * self.fs)
+         self.n_chunks = n_chunks
+         assert self.segment_length % self.n_chunks == 0
+         self.chunk_size = self.segment_length // self.n_chunks
+         self.chunk_epsilon = chunk_epsilon
+         self.energy_threshold_quantile = energy_threshold_quantile
+         self.segment_epsilon = segment_epsilon
+         self.salient_proportion_threshold = salient_proportion_threshold
+         self.analysis_stem = analysis_stem
+
+         self.meter = pyln.Meter(self.fs)
+         self.target_lufs = target_lufs
+
+         self.output_path = output_path
+
+     def forward(self, data: DataDict) -> None:
+
+         stem_ = self.analysis_stem if (self.analysis_stem != "none") else "mixture"
+
+         x = data["audio"][stem_]
+
+         xnp = x.numpy()
+         loudness = self.meter.integrated_loudness(xnp.T)
+
+         for stem in data["audio"]:
+             s = data["audio"][stem]
+             s = pyln.normalize.loudness(s.numpy().T, loudness, self.target_lufs).T
+             s = torch.as_tensor(s)
+             data["audio"][stem] = s
+
+         if x.ndim == 3:
+             assert x.shape[0] == 1
+             x = x[0]
+
+         n_chan, n_samples = x.shape
+
+         n_segments = (
+             int(np.ceil((n_samples - self.segment_length) / self.hop_length)) + 1
+         )
+
+         segments = torch.zeros((n_segments, n_chan, self.segment_length))
+         for i in range(n_segments):
+             start = i * self.hop_length
+             end = start + self.segment_length
+             end = min(end, n_samples)
+
+             xseg = x[:, start:end]
+
+             if end - start < self.segment_length:
+                 xseg = F.pad(
+                     xseg, pad=(0, self.segment_length - (end - start)), value=torch.nan
+                 )
+
+             segments[i, :, :] = xseg
+
+         chunks = segments.reshape((n_segments, n_chan, self.n_chunks, self.chunk_size))
+
+         if self.analysis_stem != "none":
+             chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3))
+             chunk_energies = torch.nan_to_num(chunk_energies, nan=0)
+             chunk_energies[chunk_energies == 0] = self.chunk_epsilon
+
+             energy_threshold = torch.nanquantile(
+                 chunk_energies, q=self.energy_threshold_quantile
+             )
+
+             if energy_threshold < self.segment_epsilon:
+                 energy_threshold = self.segment_epsilon  # type: ignore[assignment]
+
+             chunks_above_threshold = chunk_energies > energy_threshold
+             n_chunks_above_threshold = torch.mean(
+                 chunks_above_threshold.to(torch.float), dim=-1
+             )
+
+             segment_above_threshold = (
+                 n_chunks_above_threshold > self.salient_proportion_threshold
+             )
+
+             if torch.sum(segment_above_threshold) == 0:
+                 return
+
+         else:
+             segment_above_threshold = torch.ones((n_segments,))
+
+         for i in range(n_segments):
+             if not segment_above_threshold[i]:
+                 continue
+
+             outpath = os.path.join(
+                 self.output_path,
+                 self.analysis_stem,
+                 f"{data['track']} - {self.analysis_stem}{i:03d}",
+             )
+             os.makedirs(outpath, exist_ok=True)
+
+             for stem in data["audio"]:
+                 if stem == self.analysis_stem:
+                     segment = torch.nan_to_num(segments[i, :, :], nan=0)
+                 else:
+                     start = i * self.hop_length
+                     end = start + self.segment_length
+                     end = min(n_samples, end)
+
+                     segment = data["audio"][stem][:, start:end]
+
+                     if end - start < self.segment_length:
+                         segment = F.pad(
+                             segment, (0, self.segment_length - (end - start))
+                         )
+
+                 assert segment.shape[-1] == self.segment_length, segment.shape
+
+                 np.save(os.path.join(outpath, f"{stem}.wav"), segment)
+
+
+ def preprocess(
+     analysis_stem: str,
+     output_path: str = "/data/MUSDB18/HQ/saded-np",
+     fs: int = 44100,
+     segment_length_second: float = 6.0,
+     hop_length_second: float = 3.0,
+     n_chunks: int = 10,
+     chunk_epsilon: float = 1e-5,
+     energy_threshold_quantile: float = 0.15,
+     segment_epsilon: float = 1e-3,
+     salient_proportion_threshold: float = 0.5,
+ ) -> None:
+
+     sad = SourceActivityDetector(
+         analysis_stem=analysis_stem,
+         output_path=output_path,
+         fs=fs,
+         segment_length_second=segment_length_second,
+         hop_length_second=hop_length_second,
+         n_chunks=n_chunks,
+         chunk_epsilon=chunk_epsilon,
+         energy_threshold_quantile=energy_threshold_quantile,
+         segment_epsilon=segment_epsilon,
+         salient_proportion_threshold=salient_proportion_threshold,
+     )
+
+     for split in ["train", "val", "test"]:
+         ds = MUSDB18FullTrackDataset(
+             data_root="/data/MUSDB18/HQ/canonical",
+             split=split,
+         )
+
+         tracks = []
+         for i, track in enumerate(tqdm(ds, total=len(ds))):
+             if i % 32 == 0 and tracks:
+                 process_map(sad, tracks, max_workers=8)
+                 tracks = []
+             tracks.append(track)
+         process_map(sad, tracks, max_workers=8)
+
+
+ def loudness_norm_one(inputs):
+     infile, outfile, target_lufs = inputs
+
+     audio, fs = ta.load(infile)
+     audio = audio.mean(dim=0, keepdim=True).numpy().T
+
+     meter = pyln.Meter(fs)
+     loudness = meter.integrated_loudness(audio)
+     audio = pyln.normalize.loudness(audio, loudness, target_lufs)
+
+     os.makedirs(os.path.dirname(outfile), exist_ok=True)
+     np.save(outfile, audio.T)
+
+
+ def loudness_norm(
+     data_path: str,
+     target_lufs=-17.0,
+ ):
+     files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
+
+     outfiles = [f.replace(".wav", ".npy").replace("saded", "saded-np") for f in files]
+
+     files = [(f, o, target_lufs) for f, o in zip(files, outfiles)]
+
+     process_map(loudness_norm_one, files, chunksize=2)
+
+
+ if __name__ == "__main__":
+
+     from tqdm.auto import tqdm
+     import fire
+
+     fire.Fire()

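A sketch of running the loudness-normalization pass above as a library call; the path is a placeholder and is expected to contain the "saded" WAV segments produced by the SAD step (the output .npy files are written under the mirrored "saded-np" tree).

from mvsepless.models.bandit.core.data.musdb.preprocess import loudness_norm

loudness_norm(data_path="/data/MUSDB18/HQ/saded", target_lufs=-17.0)
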
mvsepless/models/bandit/core/data/musdb/validation.yaml CHANGED
@@ -1,15 +1,15 @@
- validation:
- - 'Actions - One Minute Smile'
- - 'Clara Berry And Wooldog - Waltz For My Victims'
- - 'Johnny Lokke - Promises & Lies'
- - 'Patrick Talbot - A Reason To Leave'
- - 'Triviul - Angelsaint'
- - 'Alexander Ross - Goodbye Bolero'
- - 'Fergessen - Nos Palpitants'
- - 'Leaf - Summerghost'
- - 'Skelpolu - Human Mistakes'
- - 'Young Griffo - Pennies'
- - 'ANiMAL - Rockshow'
- - 'James May - On The Line'
- - 'Meaxic - Take A Step'
  - 'Traffic Experiment - Sirens'

+ validation:
+ - 'Actions - One Minute Smile'
+ - 'Clara Berry And Wooldog - Waltz For My Victims'
+ - 'Johnny Lokke - Promises & Lies'
+ - 'Patrick Talbot - A Reason To Leave'
+ - 'Triviul - Angelsaint'
+ - 'Alexander Ross - Goodbye Bolero'
+ - 'Fergessen - Nos Palpitants'
+ - 'Leaf - Summerghost'
+ - 'Skelpolu - Human Mistakes'
+ - 'Young Griffo - Pennies'
+ - 'ANiMAL - Rockshow'
+ - 'James May - On The Line'
+ - 'Meaxic - Take A Step'
  - 'Traffic Experiment - Sirens'
mvsepless/models/bandit/core/loss/__init__.py CHANGED
@@ -1,8 +1,8 @@
- from ._multistem import MultiStemWrapperFromConfig
- from ._timefreq import (
-     ReImL1Loss,
-     ReImL2Loss,
-     TimeFreqL1Loss,
-     TimeFreqL2Loss,
-     TimeFreqSignalNoisePNormRatioLoss,
- )

+ from ._multistem import MultiStemWrapperFromConfig
+ from ._timefreq import (
+     ReImL1Loss,
+     ReImL2Loss,
+     TimeFreqL1Loss,
+     TimeFreqL2Loss,
+     TimeFreqSignalNoisePNormRatioLoss,
+ )
mvsepless/models/bandit/core/loss/_complex.py CHANGED
@@ -1,27 +1,27 @@
- from typing import Any
-
- import torch
- from torch import nn
- from torch.nn.modules import loss as _loss
- from torch.nn.modules.loss import _Loss
-
-
- class ReImLossWrapper(_Loss):
-     def __init__(self, module: _Loss) -> None:
-         super().__init__()
-         self.module = module
-
-     def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-         return self.module(torch.view_as_real(preds), torch.view_as_real(target))
-
-
- class ReImL1Loss(ReImLossWrapper):
-     def __init__(self, **kwargs: Any) -> None:
-         l1_loss = _loss.L1Loss(**kwargs)
-         super().__init__(module=(l1_loss))
-
-
- class ReImL2Loss(ReImLossWrapper):
-     def __init__(self, **kwargs: Any) -> None:
-         l2_loss = _loss.MSELoss(**kwargs)
-         super().__init__(module=(l2_loss))

+ from typing import Any
+
+ import torch
+ from torch import nn
+ from torch.nn.modules import loss as _loss
+ from torch.nn.modules.loss import _Loss
+
+
+ class ReImLossWrapper(_Loss):
+     def __init__(self, module: _Loss) -> None:
+         super().__init__()
+         self.module = module
+
+     def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         return self.module(torch.view_as_real(preds), torch.view_as_real(target))
+
+
+ class ReImL1Loss(ReImLossWrapper):
+     def __init__(self, **kwargs: Any) -> None:
+         l1_loss = _loss.L1Loss(**kwargs)
+         super().__init__(module=(l1_loss))
+
+
+ class ReImL2Loss(ReImLossWrapper):
+     def __init__(self, **kwargs: Any) -> None:
+         l2_loss = _loss.MSELoss(**kwargs)
+         super().__init__(module=(l2_loss))
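A self-contained sketch of the mechanism these wrappers rely on: torch.view_as_real maps a complex tensor to a real tensor with a trailing size-2 axis, so an ordinary L1 loss compares real and imaginary parts jointly. Shapes below are illustrative only.

import torch
from torch import nn

preds = torch.randn(2, 257, 100, dtype=torch.complex64)
target = torch.randn(2, 257, 100, dtype=torch.complex64)

# This is the quantity ReImL1Loss computes internally.
value = nn.L1Loss()(torch.view_as_real(preds), torch.view_as_real(target))
print(value)
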
mvsepless/models/bandit/core/loss/_multistem.py CHANGED
@@ -1,43 +1,43 @@
- from typing import Any, Dict
-
- import torch
- from asteroid import losses as asteroid_losses
- from torch import nn
- from torch.nn.modules.loss import _Loss
-
- from . import snr
-
-
- def parse_loss(name: str, kwargs: Dict[str, Any]) -> _Loss:
-
-     for module in [nn.modules.loss, snr, asteroid_losses, asteroid_losses.sdr]:
-         if name in module.__dict__:
-             return module.__dict__[name](**kwargs)
-
-     raise NameError
-
-
- class MultiStemWrapper(_Loss):
-     def __init__(self, module: _Loss, modality: str = "audio") -> None:
-         super().__init__()
-         self.loss = module
-         self.modality = modality
-
-     def forward(
-         self,
-         preds: Dict[str, Dict[str, torch.Tensor]],
-         target: Dict[str, Dict[str, torch.Tensor]],
-     ) -> torch.Tensor:
-         loss = {
-             stem: self.loss(preds[self.modality][stem], target[self.modality][stem])
-             for stem in preds[self.modality]
-             if stem in target[self.modality]
-         }
-
-         return sum(list(loss.values()))
-
-
- class MultiStemWrapperFromConfig(MultiStemWrapper):
-     def __init__(self, name: str, kwargs: Any, modality: str = "audio") -> None:
-         loss = parse_loss(name, kwargs)
-         super().__init__(module=loss, modality=modality)

+ from typing import Any, Dict
+
+ import torch
+ from asteroid import losses as asteroid_losses
+ from torch import nn
+ from torch.nn.modules.loss import _Loss
+
+ from . import snr
+
+
+ def parse_loss(name: str, kwargs: Dict[str, Any]) -> _Loss:
+
+     for module in [nn.modules.loss, snr, asteroid_losses, asteroid_losses.sdr]:
+         if name in module.__dict__:
+             return module.__dict__[name](**kwargs)
+
+     raise NameError
+
+
+ class MultiStemWrapper(_Loss):
+     def __init__(self, module: _Loss, modality: str = "audio") -> None:
+         super().__init__()
+         self.loss = module
+         self.modality = modality
+
+     def forward(
+         self,
+         preds: Dict[str, Dict[str, torch.Tensor]],
+         target: Dict[str, Dict[str, torch.Tensor]],
+     ) -> torch.Tensor:
+         loss = {
+             stem: self.loss(preds[self.modality][stem], target[self.modality][stem])
+             for stem in preds[self.modality]
+             if stem in target[self.modality]
+         }
+
+         return sum(list(loss.values()))
+
+
+ class MultiStemWrapperFromConfig(MultiStemWrapper):
+     def __init__(self, name: str, kwargs: Any, modality: str = "audio") -> None:
+         loss = parse_loss(name, kwargs)
+         super().__init__(module=loss, modality=modality)
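A sketch of the nested dict-of-stems contract MultiStemWrapper expects; stem names and tensor shapes here are illustrative assumptions. The wrapper applies the base loss per stem and sums the results.

import torch
from torch import nn

preds = {"audio": {"speech": torch.randn(2, 2, 44100), "music": torch.randn(2, 2, 44100)}}
target = {"audio": {"speech": torch.randn(2, 2, 44100), "music": torch.randn(2, 2, 44100)}}

base = nn.L1Loss()
# Equivalent to MultiStemWrapper(nn.L1Loss(), modality="audio")(preds, target).
total = sum(base(preds["audio"][s], target["audio"][s]) for s in preds["audio"])
print(total)
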
mvsepless/models/bandit/core/loss/_timefreq.py CHANGED
@@ -1,95 +1,94 @@
- from typing import Any, Dict, Optional
-
- import torch
- from torch import nn
- from torch.nn.modules.loss import _Loss
-
- from ._multistem import MultiStemWrapper
- from ._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper
- from .snr import SignalNoisePNormRatio
-
-
- class TimeFreqWrapper(_Loss):
-     def __init__(
-         self,
-         time_module: _Loss,
-         freq_module: Optional[_Loss] = None,
-         time_weight: float = 1.0,
-         freq_weight: float = 1.0,
-         multistem: bool = True,
-     ) -> None:
-         super().__init__()
-
-         if freq_module is None:
-             freq_module = time_module
-
-         if multistem:
-             time_module = MultiStemWrapper(time_module, modality="audio")
-             freq_module = MultiStemWrapper(freq_module, modality="spectrogram")
-
-         self.time_module = time_module
-         self.freq_module = freq_module
-
-         self.time_weight = time_weight
-         self.freq_weight = freq_weight
-
-     # TODO: add better type hints
-     def forward(self, preds: Any, target: Any) -> torch.Tensor:
-
-         return self.time_weight * self.time_module(
-             preds, target
-         ) + self.freq_weight * self.freq_module(preds, target)
-
-
- class TimeFreqL1Loss(TimeFreqWrapper):
-     def __init__(
-         self,
-         time_weight: float = 1.0,
-         freq_weight: float = 1.0,
-         tkwargs: Optional[Dict[str, Any]] = None,
-         fkwargs: Optional[Dict[str, Any]] = None,
-         multistem: bool = True,
-     ) -> None:
-         if tkwargs is None:
-             tkwargs = {}
-         if fkwargs is None:
-             fkwargs = {}
-         time_module = nn.L1Loss(**tkwargs)
-         freq_module = ReImL1Loss(**fkwargs)
-         super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
-
-
- class TimeFreqL2Loss(TimeFreqWrapper):
-     def __init__(
-         self,
-         time_weight: float = 1.0,
-         freq_weight: float = 1.0,
-         tkwargs: Optional[Dict[str, Any]] = None,
-         fkwargs: Optional[Dict[str, Any]] = None,
-         multistem: bool = True,
-     ) -> None:
-         if tkwargs is None:
-             tkwargs = {}
-         if fkwargs is None:
-             fkwargs = {}
-         time_module = nn.MSELoss(**tkwargs)
-         freq_module = ReImL2Loss(**fkwargs)
-         super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
-
-
- class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper):
-     def __init__(
-         self,
-         time_weight: float = 1.0,
-         freq_weight: float = 1.0,
-         tkwargs: Optional[Dict[str, Any]] = None,
-         fkwargs: Optional[Dict[str, Any]] = None,
-         multistem: bool = True,
-     ) -> None:
-         if tkwargs is None:
-             tkwargs = {}
-         if fkwargs is None:
-             fkwargs = {}
-         time_module = SignalNoisePNormRatio(**tkwargs)
-         freq_module = SignalNoisePNormRatio(**fkwargs)
-         super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)

+ from typing import Any, Dict, Optional
+
+ import torch
+ from torch import nn
+ from torch.nn.modules.loss import _Loss
+
+ from ._multistem import MultiStemWrapper
+ from ._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper
+ from .snr import SignalNoisePNormRatio
+
+
+ class TimeFreqWrapper(_Loss):
+     def __init__(
+         self,
+         time_module: _Loss,
+         freq_module: Optional[_Loss] = None,
+         time_weight: float = 1.0,
+         freq_weight: float = 1.0,
+         multistem: bool = True,
+     ) -> None:
+         super().__init__()
+
+         if freq_module is None:
+             freq_module = time_module
+
+         if multistem:
+             time_module = MultiStemWrapper(time_module, modality="audio")
+             freq_module = MultiStemWrapper(freq_module, modality="spectrogram")
+
+         self.time_module = time_module
+         self.freq_module = freq_module
+
+         self.time_weight = time_weight
+         self.freq_weight = freq_weight
+
+     def forward(self, preds: Any, target: Any) -> torch.Tensor:
+
+         return self.time_weight * self.time_module(
+             preds, target
+         ) + self.freq_weight * self.freq_module(preds, target)
+
+
+ class TimeFreqL1Loss(TimeFreqWrapper):
+     def __init__(
+         self,
+         time_weight: float = 1.0,
+         freq_weight: float = 1.0,
+         tkwargs: Optional[Dict[str, Any]] = None,
+         fkwargs: Optional[Dict[str, Any]] = None,
+         multistem: bool = True,
+     ) -> None:
+         if tkwargs is None:
+             tkwargs = {}
+         if fkwargs is None:
+             fkwargs = {}
+         time_module = nn.L1Loss(**tkwargs)
+         freq_module = ReImL1Loss(**fkwargs)
+         super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
+
+
+ class TimeFreqL2Loss(TimeFreqWrapper):
+     def __init__(
+         self,
+         time_weight: float = 1.0,
+         freq_weight: float = 1.0,
+         tkwargs: Optional[Dict[str, Any]] = None,
+         fkwargs: Optional[Dict[str, Any]] = None,
+         multistem: bool = True,
+     ) -> None:
+         if tkwargs is None:
+             tkwargs = {}
+         if fkwargs is None:
+             fkwargs = {}
+         time_module = nn.MSELoss(**tkwargs)
+         freq_module = ReImL2Loss(**fkwargs)
+         super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
+
+
+ class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper):
+     def __init__(
+         self,
+         time_weight: float = 1.0,
+         freq_weight: float = 1.0,
+         tkwargs: Optional[Dict[str, Any]] = None,
+         fkwargs: Optional[Dict[str, Any]] = None,
+         multistem: bool = True,
+     ) -> None:
+         if tkwargs is None:
+             tkwargs = {}
+         if fkwargs is None:
+             fkwargs = {}
+         time_module = SignalNoisePNormRatio(**tkwargs)
+         freq_module = SignalNoisePNormRatio(**fkwargs)
+         super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)

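The wrapper above reduces to a weighted sum of a time-domain and a frequency-domain term. A trivial numeric sketch with made-up component losses:

# Placeholder per-domain losses; in practice these come from the two modules above.
time_loss, freq_loss = 0.30, 0.12
time_weight, freq_weight = 1.0, 1.0
print(time_weight * time_loss + freq_weight * freq_loss)  # 0.42
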
mvsepless/models/bandit/core/loss/snr.py CHANGED
@@ -1,139 +1,131 @@
- import torch
- from torch.nn.modules.loss import _Loss
- from torch.nn import functional as F
-
-
- class SignalNoisePNormRatio(_Loss):
-     def __init__(
-         self,
-         p: float = 1.0,
-         scale_invariant: bool = False,
-         zero_mean: bool = False,
-         take_log: bool = True,
-         reduction: str = "mean",
-         EPS: float = 1e-3,
-     ) -> None:
-         assert reduction != "sum", NotImplementedError
-         super().__init__(reduction=reduction)
-         assert not zero_mean
-
-         self.p = p
-
-         self.EPS = EPS
-         self.take_log = take_log
-
-         self.scale_invariant = scale_invariant
-
-     def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-
-         target_ = target
-         if self.scale_invariant:
-             ndim = target.ndim
-             dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True)
-             s_target_energy = torch.sum(
-                 target * torch.conj(target), dim=-1, keepdim=True
-             )
-
-             if ndim > 2:
-                 dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True)
-                 s_target_energy = torch.sum(
-                     s_target_energy, dim=list(range(1, ndim)), keepdim=True
-                 )
-
-             target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8)
-             target = target_ * target_scaler
-
-         if torch.is_complex(est_target):
-             est_target = torch.view_as_real(est_target)
-             target = torch.view_as_real(target)
-
-         batch_size = est_target.shape[0]
-         est_target = est_target.reshape(batch_size, -1)
-         target = target.reshape(batch_size, -1)
-         # target_ = target_.reshape(batch_size, -1)
-
-         if self.p == 1:
-             e_error = torch.abs(est_target - target).mean(dim=-1)
-             e_target = torch.abs(target).mean(dim=-1)
-         elif self.p == 2:
-             e_error = torch.square(est_target - target).mean(dim=-1)
-             e_target = torch.square(target).mean(dim=-1)
-         else:
-             raise NotImplementedError
-
-         if self.take_log:
-             loss = 10 * (
-                 torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS)
-             )
-         else:
-             loss = (e_error + self.EPS) / (e_target + self.EPS)
-
-         if self.reduction == "mean":
-             loss = loss.mean()
-         elif self.reduction == "sum":
-             loss = loss.sum()
-
-         return loss
-
-
- class MultichannelSingleSrcNegSDR(_Loss):
-     def __init__(
-         self,
-         sdr_type: str,
-         p: float = 2.0,
-         zero_mean: bool = True,
-         take_log: bool = True,
-         reduction: str = "mean",
-         EPS: float = 1e-8,
-     ) -> None:
-         assert reduction != "sum", NotImplementedError
-         super().__init__(reduction=reduction)
-
-         assert sdr_type in ["snr", "sisdr", "sdsdr"]
-         self.sdr_type = sdr_type
-         self.zero_mean = zero_mean
-         self.take_log = take_log
-         self.EPS = 1e-8
-
-         self.p = p
-
-     def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-         if target.size() != est_target.size() or target.ndim != 3:
-             raise TypeError(
-                 f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead"
-             )
-         # Step 1. Zero-mean norm
-         if self.zero_mean:
-             mean_source = torch.mean(target, dim=[1, 2], keepdim=True)
-             mean_estimate = torch.mean(est_target, dim=[1, 2], keepdim=True)
-             target = target - mean_source
-             est_target = est_target - mean_estimate
-         # Step 2. Pair-wise SI-SDR.
-         if self.sdr_type in ["sisdr", "sdsdr"]:
-             # [batch, 1]
-             dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True)
-             # [batch, 1]
-             s_target_energy = torch.sum(target**2, dim=[1, 2], keepdim=True) + self.EPS
-             # [batch, time]
-             scaled_target = dot * target / s_target_energy
-         else:
-             # [batch, time]
-             scaled_target = target
-         if self.sdr_type in ["sdsdr", "snr"]:
-             e_noise = est_target - target
-         else:
-             e_noise = est_target - scaled_target
-         # [batch]
-
-         if self.p == 2.0:
-             losses = torch.sum(scaled_target**2, dim=[1, 2]) / (
-                 torch.sum(e_noise**2, dim=[1, 2]) + self.EPS
-             )
-         else:
-             losses = torch.norm(scaled_target, p=self.p, dim=[1, 2]) / (
-                 torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS
-             )
-         if self.take_log:
-             losses = 10 * torch.log10(losses + self.EPS)
-         losses = losses.mean() if self.reduction == "mean" else losses
-         return -losses

+ import torch
+ from torch.nn.modules.loss import _Loss
+ from torch.nn import functional as F
+
+
+ class SignalNoisePNormRatio(_Loss):
+     def __init__(
+         self,
+         p: float = 1.0,
+         scale_invariant: bool = False,
+         zero_mean: bool = False,
+         take_log: bool = True,
+         reduction: str = "mean",
+         EPS: float = 1e-3,
+     ) -> None:
+         assert reduction != "sum", NotImplementedError
+         super().__init__(reduction=reduction)
+         assert not zero_mean
+
+         self.p = p
+
+         self.EPS = EPS
+         self.take_log = take_log
+
+         self.scale_invariant = scale_invariant
+
+     def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+
+         target_ = target
+         if self.scale_invariant:
+             ndim = target.ndim
32
+ dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True)
33
+ s_target_energy = torch.sum(
34
+ target * torch.conj(target), dim=-1, keepdim=True
35
+ )
36
+
37
+ if ndim > 2:
38
+ dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True)
39
+ s_target_energy = torch.sum(
40
+ s_target_energy, dim=list(range(1, ndim)), keepdim=True
41
+ )
42
+
43
+ target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8)
44
+ target = target_ * target_scaler
45
+
46
+ if torch.is_complex(est_target):
47
+ est_target = torch.view_as_real(est_target)
48
+ target = torch.view_as_real(target)
49
+
50
+ batch_size = est_target.shape[0]
51
+ est_target = est_target.reshape(batch_size, -1)
52
+ target = target.reshape(batch_size, -1)
53
+
54
+ if self.p == 1:
55
+ e_error = torch.abs(est_target - target).mean(dim=-1)
56
+ e_target = torch.abs(target).mean(dim=-1)
57
+ elif self.p == 2:
58
+ e_error = torch.square(est_target - target).mean(dim=-1)
59
+ e_target = torch.square(target).mean(dim=-1)
60
+ else:
61
+ raise NotImplementedError
62
+
63
+ if self.take_log:
64
+ loss = 10 * (
65
+ torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS)
66
+ )
67
+ else:
68
+ loss = (e_error + self.EPS) / (e_target + self.EPS)
69
+
70
+ if self.reduction == "mean":
71
+ loss = loss.mean()
72
+ elif self.reduction == "sum":
73
+ loss = loss.sum()
74
+
75
+ return loss
76
+
77
+
78
+ class MultichannelSingleSrcNegSDR(_Loss):
79
+ def __init__(
80
+ self,
81
+ sdr_type: str,
82
+ p: float = 2.0,
83
+ zero_mean: bool = True,
84
+ take_log: bool = True,
85
+ reduction: str = "mean",
86
+ EPS: float = 1e-8,
87
+ ) -> None:
88
+ assert reduction != "sum", NotImplementedError
89
+ super().__init__(reduction=reduction)
90
+
91
+ assert sdr_type in ["snr", "sisdr", "sdsdr"]
92
+ self.sdr_type = sdr_type
93
+ self.zero_mean = zero_mean
94
+ self.take_log = take_log
95
+ self.EPS = EPS
96
+
97
+ self.p = p
98
+
99
+ def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
100
+ if target.size() != est_target.size() or target.ndim != 3:
101
+ raise TypeError(
102
+ f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead"
103
+ )
104
+ if self.zero_mean:
105
+ mean_source = torch.mean(target, dim=[1, 2], keepdim=True)
106
+ mean_estimate = torch.mean(est_target, dim=[1, 2], keepdim=True)
107
+ target = target - mean_source
108
+ est_target = est_target - mean_estimate
109
+ if self.sdr_type in ["sisdr", "sdsdr"]:
110
+ dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True)
111
+ s_target_energy = torch.sum(target**2, dim=[1, 2], keepdim=True) + self.EPS
112
+ scaled_target = dot * target / s_target_energy
113
+ else:
114
+ scaled_target = target
115
+ if self.sdr_type in ["sdsdr", "snr"]:
116
+ e_noise = est_target - target
117
+ else:
118
+ e_noise = est_target - scaled_target
119
+
120
+ if self.p == 2.0:
121
+ losses = torch.sum(scaled_target**2, dim=[1, 2]) / (
122
+ torch.sum(e_noise**2, dim=[1, 2]) + self.EPS
123
+ )
124
+ else:
125
+ losses = torch.norm(scaled_target, p=self.p, dim=[1, 2]) / (
126
+ torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS
127
+ )
128
+ if self.take_log:
129
+ losses = 10 * torch.log10(losses + self.EPS)
130
+ losses = losses.mean() if self.reduction == "mean" else losses
131
+ return -losses
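A quick usage sketch for the p-norm ratio loss above, assuming the class is importable from this module; with p=1, take_log=True and 10% additive noise, the logged error-to-target ratio lands near -10 dB.

import torch

loss_fn = SignalNoisePNormRatio(p=1.0, take_log=True)
target = torch.randn(4, 2, 44100)                    # (batch, channel, time)
estimate = target + 0.1 * torch.randn_like(target)   # 10% additive noise
print(loss_fn(estimate, target))                     # scalar, roughly -10 dB here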
 
mvsepless/models/bandit/core/metrics/__init__.py CHANGED
@@ -1,9 +1,7 @@
-from .snr import (
-    ChunkMedianScaleInvariantSignalDistortionRatio,
-    ChunkMedianScaleInvariantSignalNoiseRatio,
-    ChunkMedianSignalDistortionRatio,
-    ChunkMedianSignalNoiseRatio,
-    SafeSignalDistortionRatio,
-)
-
-# from .mushra import EstimatedMushraScore

+from .snr import (
+    ChunkMedianScaleInvariantSignalDistortionRatio,
+    ChunkMedianScaleInvariantSignalNoiseRatio,
+    ChunkMedianSignalDistortionRatio,
+    ChunkMedianSignalNoiseRatio,
+    SafeSignalDistortionRatio,
+)
 
 
mvsepless/models/bandit/core/metrics/_squim.py CHANGED
@@ -1,443 +1,350 @@
1
- from dataclasses import dataclass
2
-
3
- from torchaudio._internal import load_state_dict_from_url
4
-
5
- import math
6
- from typing import List, Optional, Tuple
7
-
8
- import torch
9
- import torch.nn as nn
10
- import torch.nn.functional as F
11
-
12
-
13
- def transform_wb_pesq_range(x: float) -> float:
14
- """The metric defined by ITU-T P.862 is often called 'PESQ score', which is defined
15
- for narrow-band signals and has a value range of [-0.5, 4.5] exactly. Here, we use the metric
16
- defined by ITU-T P.862.2, commonly known as 'wide-band PESQ' and will be referred to as "PESQ score".
17
-
18
- Args:
19
- x (float): Narrow-band PESQ score.
20
-
21
- Returns:
22
- (float): Wide-band PESQ score.
23
- """
24
- return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224))
25
-
26
-
27
- PESQRange: Tuple[float, float] = (
28
- 1.0, # P.862.2 uses a different input filter than P.862, and the lower bound of
29
- # the raw score is not -0.5 anymore. It's hard to figure out the true lower bound.
30
- # We are using 1.0 as a reasonable approximation.
31
- transform_wb_pesq_range(4.5),
32
- )
33
-
34
-
35
- class RangeSigmoid(nn.Module):
36
- def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
37
- super(RangeSigmoid, self).__init__()
38
- assert isinstance(val_range, tuple) and len(val_range) == 2
39
- self.val_range: Tuple[float, float] = val_range
40
- self.sigmoid: nn.modules.Module = nn.Sigmoid()
41
-
42
- def forward(self, x: torch.Tensor) -> torch.Tensor:
43
- out = (
44
- self.sigmoid(x) * (self.val_range[1] - self.val_range[0])
45
- + self.val_range[0]
46
- )
47
- return out
48
-
49
-
50
- class Encoder(nn.Module):
51
- """Encoder module that transform 1D waveform to 2D representations.
52
-
53
- Args:
54
- feat_dim (int, optional): The feature dimension after Encoder module. (Default: 512)
55
- win_len (int, optional): kernel size in the Conv1D layer. (Default: 32)
56
- """
57
-
58
- def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
59
- super(Encoder, self).__init__()
60
-
61
- self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False)
62
-
63
- def forward(self, x: torch.Tensor) -> torch.Tensor:
64
- """Apply waveforms to convolutional layer and ReLU layer.
65
-
66
- Args:
67
- x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
68
-
69
- Returns:
70
- (torch.Tensor): Feature Tensor with dimensions `(batch, channel, frame)`.
71
- """
72
- out = x.unsqueeze(dim=1)
73
- out = F.relu(self.conv1d(out))
74
- return out
75
-
76
-
77
- class SingleRNN(nn.Module):
78
- def __init__(
79
- self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0
80
- ) -> None:
81
- super(SingleRNN, self).__init__()
82
-
83
- self.rnn_type = rnn_type
84
- self.input_size = input_size
85
- self.hidden_size = hidden_size
86
-
87
- self.rnn: nn.modules.Module = getattr(nn, rnn_type)(
88
- input_size,
89
- hidden_size,
90
- 1,
91
- dropout=dropout,
92
- batch_first=True,
93
- bidirectional=True,
94
- )
95
-
96
- self.proj = nn.Linear(hidden_size * 2, input_size)
97
-
98
- def forward(self, x: torch.Tensor) -> torch.Tensor:
99
- # input shape: batch, seq, dim
100
- out, _ = self.rnn(x)
101
- out = self.proj(out)
102
- return out
103
-
104
-
105
- class DPRNN(nn.Module):
106
- """*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`.
107
-
108
- Args:
109
- feat_dim (int, optional): The feature dimension after Encoder module. (Default: 64)
110
- hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128)
111
- num_blocks (int, optional): Number of DPRNN layers. (Default: 6)
112
- rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM")
113
- d_model (int, optional): The number of expected features in the input. (Default: 256)
114
- chunk_size (int, optional): Chunk size of input for DPRNN. (Default: 100)
115
- chunk_stride (int, optional): Stride of chunk input for DPRNN. (Default: 50)
116
- """
117
-
118
- def __init__(
119
- self,
120
- feat_dim: int = 64,
121
- hidden_dim: int = 128,
122
- num_blocks: int = 6,
123
- rnn_type: str = "LSTM",
124
- d_model: int = 256,
125
- chunk_size: int = 100,
126
- chunk_stride: int = 50,
127
- ) -> None:
128
- super(DPRNN, self).__init__()
129
-
130
- self.num_blocks = num_blocks
131
-
132
- self.row_rnn = nn.ModuleList([])
133
- self.col_rnn = nn.ModuleList([])
134
- self.row_norm = nn.ModuleList([])
135
- self.col_norm = nn.ModuleList([])
136
- for _ in range(num_blocks):
137
- self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
138
- self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
139
- self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
140
- self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
141
- self.conv = nn.Sequential(
142
- nn.Conv2d(feat_dim, d_model, 1),
143
- nn.PReLU(),
144
- )
145
- self.chunk_size = chunk_size
146
- self.chunk_stride = chunk_stride
147
-
148
- def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
149
- # input shape: (B, N, T)
150
- seq_len = x.shape[-1]
151
-
152
- rest = (
153
- self.chunk_size
154
- - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
155
- )
156
- out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])
157
-
158
- return out, rest
159
-
160
- def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
161
- out, rest = self.pad_chunk(x)
162
- batch_size, feat_dim, seq_len = out.shape
163
-
164
- segments1 = (
165
- out[:, :, : -self.chunk_stride]
166
- .contiguous()
167
- .view(batch_size, feat_dim, -1, self.chunk_size)
168
- )
169
- segments2 = (
170
- out[:, :, self.chunk_stride :]
171
- .contiguous()
172
- .view(batch_size, feat_dim, -1, self.chunk_size)
173
- )
174
- out = torch.cat([segments1, segments2], dim=3)
175
- out = (
176
- out.view(batch_size, feat_dim, -1, self.chunk_size)
177
- .transpose(2, 3)
178
- .contiguous()
179
- )
180
-
181
- return out, rest
182
-
183
- def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
184
- batch_size, dim, _, _ = x.shape
185
- out = (
186
- x.transpose(2, 3)
187
- .contiguous()
188
- .view(batch_size, dim, -1, self.chunk_size * 2)
189
- )
190
- out1 = (
191
- out[:, :, :, : self.chunk_size]
192
- .contiguous()
193
- .view(batch_size, dim, -1)[:, :, self.chunk_stride :]
194
- )
195
- out2 = (
196
- out[:, :, :, self.chunk_size :]
197
- .contiguous()
198
- .view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
199
- )
200
- out = out1 + out2
201
- if rest > 0:
202
- out = out[:, :, :-rest]
203
- out = out.contiguous()
204
- return out
205
-
206
- def forward(self, x: torch.Tensor) -> torch.Tensor:
207
- x, rest = self.chunking(x)
208
- batch_size, _, dim1, dim2 = x.shape
209
- out = x
210
- for row_rnn, row_norm, col_rnn, col_norm in zip(
211
- self.row_rnn, self.row_norm, self.col_rnn, self.col_norm
212
- ):
213
- row_in = (
214
- out.permute(0, 3, 2, 1)
215
- .contiguous()
216
- .view(batch_size * dim2, dim1, -1)
217
- .contiguous()
218
- )
219
- row_out = row_rnn(row_in)
220
- row_out = (
221
- row_out.view(batch_size, dim2, dim1, -1)
222
- .permute(0, 3, 2, 1)
223
- .contiguous()
224
- )
225
- row_out = row_norm(row_out)
226
- out = out + row_out
227
-
228
- col_in = (
229
- out.permute(0, 2, 3, 1)
230
- .contiguous()
231
- .view(batch_size * dim1, dim2, -1)
232
- .contiguous()
233
- )
234
- col_out = col_rnn(col_in)
235
- col_out = (
236
- col_out.view(batch_size, dim1, dim2, -1)
237
- .permute(0, 3, 1, 2)
238
- .contiguous()
239
- )
240
- col_out = col_norm(col_out)
241
- out = out + col_out
242
- out = self.conv(out)
243
- out = self.merging(out, rest)
244
- out = out.transpose(1, 2).contiguous()
245
- return out
246
-
247
-
248
- class AutoPool(nn.Module):
249
- def __init__(self, pool_dim: int = 1) -> None:
250
- super(AutoPool, self).__init__()
251
- self.pool_dim: int = pool_dim
252
- self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
253
- self.register_parameter("alpha", nn.Parameter(torch.ones(1)))
254
-
255
- def forward(self, x: torch.Tensor) -> torch.Tensor:
256
- weight = self.softmax(torch.mul(x, self.alpha))
257
- out = torch.sum(torch.mul(x, weight), dim=self.pool_dim)
258
- return out
259
-
260
-
261
- class SquimObjective(nn.Module):
262
- """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
263
- for speech enhancement (e.g., STOI, PESQ, and SI-SDR).
264
-
265
- Args:
266
- encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation.
267
- dprnn (torch.nn.Module): DPRNN module to model sequential feature.
268
- branches (torch.nn.ModuleList): Transformer branches in which each branch estimates one objective metric score.
269
- """
270
-
271
- def __init__(
272
- self,
273
- encoder: nn.Module,
274
- dprnn: nn.Module,
275
- branches: nn.ModuleList,
276
- ):
277
- super(SquimObjective, self).__init__()
278
- self.encoder = encoder
279
- self.dprnn = dprnn
280
- self.branches = branches
281
-
282
- def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
283
- """
284
- Args:
285
- x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
286
-
287
- Returns:
288
- List(torch.Tensor): List of score Tensors, each with dimension `(batch,)`.
289
- """
290
- if x.ndim != 2:
291
- raise ValueError(
292
- f"The input must be a 2D Tensor. Found dimension {x.ndim}."
293
- )
294
- x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
295
- out = self.encoder(x)
296
- out = self.dprnn(out)
297
- scores = []
298
- for branch in self.branches:
299
- scores.append(branch(out).squeeze(dim=1))
300
- return scores
301
-
302
-
303
- def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
304
- """Create branch module after DPRNN model for predicting metric score.
305
-
306
- Args:
307
- d_model (int): The number of expected features in the input.
308
- nhead (int): Number of heads in the multi-head attention model.
309
- metric (str): The metric name to predict.
310
-
311
- Returns:
312
- (nn.Module): Returned module to predict corresponding metric score.
313
- """
314
- layer1 = nn.TransformerEncoderLayer(
315
- d_model, nhead, d_model * 4, dropout=0.0, batch_first=True
316
- )
317
- layer2 = AutoPool()
318
- if metric == "stoi":
319
- layer3 = nn.Sequential(
320
- nn.Linear(d_model, d_model),
321
- nn.PReLU(),
322
- nn.Linear(d_model, 1),
323
- RangeSigmoid(),
324
- )
325
- elif metric == "pesq":
326
- layer3 = nn.Sequential(
327
- nn.Linear(d_model, d_model),
328
- nn.PReLU(),
329
- nn.Linear(d_model, 1),
330
- RangeSigmoid(val_range=PESQRange),
331
- )
332
- else:
333
- layer3: nn.modules.Module = nn.Sequential(
334
- nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1)
335
- )
336
- return nn.Sequential(layer1, layer2, layer3)
337
-
338
-
339
- def squim_objective_model(
340
- feat_dim: int,
341
- win_len: int,
342
- d_model: int,
343
- nhead: int,
344
- hidden_dim: int,
345
- num_blocks: int,
346
- rnn_type: str,
347
- chunk_size: int,
348
- chunk_stride: Optional[int] = None,
349
- ) -> SquimObjective:
350
- """Build a custome :class:`torchaudio.prototype.models.SquimObjective` model.
351
-
352
- Args:
353
- feat_dim (int, optional): The feature dimension after Encoder module.
354
- win_len (int): Kernel size in the Encoder module.
355
- d_model (int): The number of expected features in the input.
356
- nhead (int): Number of heads in the multi-head attention model.
357
- hidden_dim (int): Hidden dimension in the RNN layer of DPRNN.
358
- num_blocks (int): Number of DPRNN layers.
359
- rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"].
360
- chunk_size (int): Chunk size of input for DPRNN.
361
- chunk_stride (int or None, optional): Stride of chunk input for DPRNN.
362
- """
363
- if chunk_stride is None:
364
- chunk_stride = chunk_size // 2
365
- encoder = Encoder(feat_dim, win_len)
366
- dprnn = DPRNN(
367
- feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride
368
- )
369
- branches = nn.ModuleList(
370
- [
371
- _create_branch(d_model, nhead, "stoi"),
372
- _create_branch(d_model, nhead, "pesq"),
373
- _create_branch(d_model, nhead, "sisdr"),
374
- ]
375
- )
376
- return SquimObjective(encoder, dprnn, branches)
377
-
378
-
379
- def squim_objective_base() -> SquimObjective:
380
- """Build :class:`torchaudio.prototype.models.SquimObjective` model with default arguments."""
381
- return squim_objective_model(
382
- feat_dim=256,
383
- win_len=64,
384
- d_model=256,
385
- nhead=4,
386
- hidden_dim=256,
387
- num_blocks=2,
388
- rnn_type="LSTM",
389
- chunk_size=71,
390
- )
391
-
392
-
393
- @dataclass
394
- class SquimObjectiveBundle:
395
-
396
- _path: str
397
- _sample_rate: float
398
-
399
- def _get_state_dict(self, dl_kwargs):
400
- url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
401
- dl_kwargs = {} if dl_kwargs is None else dl_kwargs
402
- state_dict = load_state_dict_from_url(url, **dl_kwargs)
403
- return state_dict
404
-
405
- def get_model(self, *, dl_kwargs=None) -> SquimObjective:
406
- """Construct the SquimObjective model, and load the pretrained weight.
407
-
408
- The weight file is downloaded from the internet and cached with
409
- :func:`torch.hub.load_state_dict_from_url`
410
-
411
- Args:
412
- dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
413
-
414
- Returns:
415
- Variation of :py:class:`~torchaudio.models.SquimObjective`.
416
- """
417
- model = squim_objective_base()
418
- model.load_state_dict(self._get_state_dict(dl_kwargs))
419
- model.eval()
420
- return model
421
-
422
- @property
423
- def sample_rate(self):
424
- """Sample rate of the audio that the model is trained on.
425
-
426
- :type: float
427
- """
428
- return self._sample_rate
429
-
430
-
431
- SQUIM_OBJECTIVE = SquimObjectiveBundle(
432
- "squim_objective_dns2020.pth",
433
- _sample_rate=16000,
434
- )
435
- SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
436
- :cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.
437
-
438
- The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
439
- The weights are under `Creative Commons Attribution 4.0 International License
440
- <https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.
441
-
442
- Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
443
- """
 
1
+ from dataclasses import dataclass
2
+
3
+ from torchaudio._internal import load_state_dict_from_url
4
+
5
+ import math
6
+ from typing import List, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ def transform_wb_pesq_range(x: float) -> float:
14
+ return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224))
15
+
16
+
17
+ PESQRange: Tuple[float, float] = (
18
+ 1.0,
19
+ transform_wb_pesq_range(4.5),
20
+ )
21
+
22
+
23
+ class RangeSigmoid(nn.Module):
24
+ def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
25
+ super(RangeSigmoid, self).__init__()
26
+ assert isinstance(val_range, tuple) and len(val_range) == 2
27
+ self.val_range: Tuple[float, float] = val_range
28
+ self.sigmoid: nn.modules.Module = nn.Sigmoid()
29
+
30
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
31
+ out = (
32
+ self.sigmoid(x) * (self.val_range[1] - self.val_range[0])
33
+ + self.val_range[0]
34
+ )
35
+ return out
36
+
37
+
38
+ class Encoder(nn.Module):
39
+
40
+ def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
41
+ super(Encoder, self).__init__()
42
+
43
+ self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False)
44
+
45
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
46
+ out = x.unsqueeze(dim=1)
47
+ out = F.relu(self.conv1d(out))
48
+ return out
49
+
50
+
51
+ class SingleRNN(nn.Module):
52
+ def __init__(
53
+ self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0
54
+ ) -> None:
55
+ super(SingleRNN, self).__init__()
56
+
57
+ self.rnn_type = rnn_type
58
+ self.input_size = input_size
59
+ self.hidden_size = hidden_size
60
+
61
+ self.rnn: nn.modules.Module = getattr(nn, rnn_type)(
62
+ input_size,
63
+ hidden_size,
64
+ 1,
65
+ dropout=dropout,
66
+ batch_first=True,
67
+ bidirectional=True,
68
+ )
69
+
70
+ self.proj = nn.Linear(hidden_size * 2, input_size)
71
+
72
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
73
+ out, _ = self.rnn(x)
74
+ out = self.proj(out)
75
+ return out
76
+
77
+
78
+ class DPRNN(nn.Module):
79
+
80
+ def __init__(
81
+ self,
82
+ feat_dim: int = 64,
83
+ hidden_dim: int = 128,
84
+ num_blocks: int = 6,
85
+ rnn_type: str = "LSTM",
86
+ d_model: int = 256,
87
+ chunk_size: int = 100,
88
+ chunk_stride: int = 50,
89
+ ) -> None:
90
+ super(DPRNN, self).__init__()
91
+
92
+ self.num_blocks = num_blocks
93
+
94
+ self.row_rnn = nn.ModuleList([])
95
+ self.col_rnn = nn.ModuleList([])
96
+ self.row_norm = nn.ModuleList([])
97
+ self.col_norm = nn.ModuleList([])
98
+ for _ in range(num_blocks):
99
+ self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
100
+ self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
101
+ self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
102
+ self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
103
+ self.conv = nn.Sequential(
104
+ nn.Conv2d(feat_dim, d_model, 1),
105
+ nn.PReLU(),
106
+ )
107
+ self.chunk_size = chunk_size
108
+ self.chunk_stride = chunk_stride
109
+
110
+ def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
111
+ seq_len = x.shape[-1]
112
+
113
+ rest = (
114
+ self.chunk_size
115
+ - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
116
+ )
117
+ out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])
118
+
119
+ return out, rest
120
+
121
+ def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
122
+ out, rest = self.pad_chunk(x)
123
+ batch_size, feat_dim, seq_len = out.shape
124
+
125
+ segments1 = (
126
+ out[:, :, : -self.chunk_stride]
127
+ .contiguous()
128
+ .view(batch_size, feat_dim, -1, self.chunk_size)
129
+ )
130
+ segments2 = (
131
+ out[:, :, self.chunk_stride :]
132
+ .contiguous()
133
+ .view(batch_size, feat_dim, -1, self.chunk_size)
134
+ )
135
+ out = torch.cat([segments1, segments2], dim=3)
136
+ out = (
137
+ out.view(batch_size, feat_dim, -1, self.chunk_size)
138
+ .transpose(2, 3)
139
+ .contiguous()
140
+ )
141
+
142
+ return out, rest
143
+
144
+ def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
145
+ batch_size, dim, _, _ = x.shape
146
+ out = (
147
+ x.transpose(2, 3)
148
+ .contiguous()
149
+ .view(batch_size, dim, -1, self.chunk_size * 2)
150
+ )
151
+ out1 = (
152
+ out[:, :, :, : self.chunk_size]
153
+ .contiguous()
154
+ .view(batch_size, dim, -1)[:, :, self.chunk_stride :]
155
+ )
156
+ out2 = (
157
+ out[:, :, :, self.chunk_size :]
158
+ .contiguous()
159
+ .view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
160
+ )
161
+ out = out1 + out2
162
+ if rest > 0:
163
+ out = out[:, :, :-rest]
164
+ out = out.contiguous()
165
+ return out
166
+
167
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
168
+ x, rest = self.chunking(x)
169
+ batch_size, _, dim1, dim2 = x.shape
170
+ out = x
171
+ for row_rnn, row_norm, col_rnn, col_norm in zip(
172
+ self.row_rnn, self.row_norm, self.col_rnn, self.col_norm
173
+ ):
174
+ row_in = (
175
+ out.permute(0, 3, 2, 1)
176
+ .contiguous()
177
+ .view(batch_size * dim2, dim1, -1)
178
+ .contiguous()
179
+ )
180
+ row_out = row_rnn(row_in)
181
+ row_out = (
182
+ row_out.view(batch_size, dim2, dim1, -1)
183
+ .permute(0, 3, 2, 1)
184
+ .contiguous()
185
+ )
186
+ row_out = row_norm(row_out)
187
+ out = out + row_out
188
+
189
+ col_in = (
190
+ out.permute(0, 2, 3, 1)
191
+ .contiguous()
192
+ .view(batch_size * dim1, dim2, -1)
193
+ .contiguous()
194
+ )
195
+ col_out = col_rnn(col_in)
196
+ col_out = (
197
+ col_out.view(batch_size, dim1, dim2, -1)
198
+ .permute(0, 3, 1, 2)
199
+ .contiguous()
200
+ )
201
+ col_out = col_norm(col_out)
202
+ out = out + col_out
203
+ out = self.conv(out)
204
+ out = self.merging(out, rest)
205
+ out = out.transpose(1, 2).contiguous()
206
+ return out
207
+
208
+
209
+ class AutoPool(nn.Module):
210
+ def __init__(self, pool_dim: int = 1) -> None:
211
+ super(AutoPool, self).__init__()
212
+ self.pool_dim: int = pool_dim
213
+ self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
214
+ self.register_parameter("alpha", nn.Parameter(torch.ones(1)))
215
+
216
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
217
+ weight = self.softmax(torch.mul(x, self.alpha))
218
+ out = torch.sum(torch.mul(x, weight), dim=self.pool_dim)
219
+ return out
220
+
221
+
222
+ class SquimObjective(nn.Module):
223
+
224
+ def __init__(
225
+ self,
226
+ encoder: nn.Module,
227
+ dprnn: nn.Module,
228
+ branches: nn.ModuleList,
229
+ ):
230
+ super(SquimObjective, self).__init__()
231
+ self.encoder = encoder
232
+ self.dprnn = dprnn
233
+ self.branches = branches
234
+
235
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
236
+ if x.ndim != 2:
237
+ raise ValueError(
238
+ f"The input must be a 2D Tensor. Found dimension {x.ndim}."
239
+ )
240
+ x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
241
+ out = self.encoder(x)
242
+ out = self.dprnn(out)
243
+ scores = []
244
+ for branch in self.branches:
245
+ scores.append(branch(out).squeeze(dim=1))
246
+ return scores
247
+
248
+
249
+ def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
250
+ layer1 = nn.TransformerEncoderLayer(
251
+ d_model, nhead, d_model * 4, dropout=0.0, batch_first=True
252
+ )
253
+ layer2 = AutoPool()
254
+ if metric == "stoi":
255
+ layer3 = nn.Sequential(
256
+ nn.Linear(d_model, d_model),
257
+ nn.PReLU(),
258
+ nn.Linear(d_model, 1),
259
+ RangeSigmoid(),
260
+ )
261
+ elif metric == "pesq":
262
+ layer3 = nn.Sequential(
263
+ nn.Linear(d_model, d_model),
264
+ nn.PReLU(),
265
+ nn.Linear(d_model, 1),
266
+ RangeSigmoid(val_range=PESQRange),
267
+ )
268
+ else:
269
+ layer3: nn.modules.Module = nn.Sequential(
270
+ nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1)
271
+ )
272
+ return nn.Sequential(layer1, layer2, layer3)
273
+
274
+
275
+ def squim_objective_model(
276
+ feat_dim: int,
277
+ win_len: int,
278
+ d_model: int,
279
+ nhead: int,
280
+ hidden_dim: int,
281
+ num_blocks: int,
282
+ rnn_type: str,
283
+ chunk_size: int,
284
+ chunk_stride: Optional[int] = None,
285
+ ) -> SquimObjective:
286
+ if chunk_stride is None:
287
+ chunk_stride = chunk_size // 2
288
+ encoder = Encoder(feat_dim, win_len)
289
+ dprnn = DPRNN(
290
+ feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride
291
+ )
292
+ branches = nn.ModuleList(
293
+ [
294
+ _create_branch(d_model, nhead, "stoi"),
295
+ _create_branch(d_model, nhead, "pesq"),
296
+ _create_branch(d_model, nhead, "sisdr"),
297
+ ]
298
+ )
299
+ return SquimObjective(encoder, dprnn, branches)
300
+
301
+
302
+ def squim_objective_base() -> SquimObjective:
303
+ return squim_objective_model(
304
+ feat_dim=256,
305
+ win_len=64,
306
+ d_model=256,
307
+ nhead=4,
308
+ hidden_dim=256,
309
+ num_blocks=2,
310
+ rnn_type="LSTM",
311
+ chunk_size=71,
312
+ )
313
+
314
+
315
+ @dataclass
316
+ class SquimObjectiveBundle:
317
+
318
+ _path: str
319
+ _sample_rate: float
320
+
321
+ def _get_state_dict(self, dl_kwargs):
322
+ url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
323
+ dl_kwargs = {} if dl_kwargs is None else dl_kwargs
324
+ state_dict = load_state_dict_from_url(url, **dl_kwargs)
325
+ return state_dict
326
+
327
+ def get_model(self, *, dl_kwargs=None) -> SquimObjective:
328
+ model = squim_objective_base()
329
+ model.load_state_dict(self._get_state_dict(dl_kwargs))
330
+ model.eval()
331
+ return model
332
+
333
+ @property
334
+ def sample_rate(self):
335
+ return self._sample_rate
336
+
337
+
338
+ SQUIM_OBJECTIVE = SquimObjectiveBundle(
339
+ "squim_objective_dns2020.pth",
340
+ _sample_rate=16000,
341
+ )
342
+ SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
343
+ :cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.
344
+
345
+ The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
346
+ The weights are under `Creative Commons Attribution 4.0 International License
347
+ <https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.
348
+
349
+ Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
350
+ """
 
mvsepless/models/bandit/core/metrics/snr.py CHANGED
@@ -1,127 +1,124 @@
1
- from typing import Any, Callable
2
-
3
- import numpy as np
4
- import torch
5
- import torchmetrics as tm
6
- from torch._C import _LinAlgError
7
- from torchmetrics import functional as tmF
8
-
9
-
10
- class SafeSignalDistortionRatio(tm.SignalDistortionRatio):
11
- def __init__(self, **kwargs) -> None:
12
- super().__init__(**kwargs)
13
-
14
- def update(self, *args, **kwargs) -> Any:
15
- try:
16
- super().update(*args, **kwargs)
17
- except:
18
- pass
19
-
20
- def compute(self) -> Any:
21
- if self.total == 0:
22
- return torch.tensor(torch.nan)
23
- return super().compute()
24
-
25
-
26
- class BaseChunkMedianSignalRatio(tm.Metric):
27
- def __init__(
28
- self,
29
- func: Callable,
30
- window_size: int,
31
- hop_size: int = None,
32
- zero_mean: bool = False,
33
- ) -> None:
34
- super().__init__()
35
-
36
- # self.zero_mean = zero_mean
37
- self.func = func
38
- self.window_size = window_size
39
- if hop_size is None:
40
- hop_size = window_size
41
- self.hop_size = hop_size
42
-
43
- self.add_state("sum_snr", default=torch.tensor(0.0), dist_reduce_fx="sum")
44
- self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
45
-
46
- def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
47
-
48
- n_samples = target.shape[-1]
49
-
50
- n_chunks = int(np.ceil((n_samples - self.window_size) / self.hop_size) + 1)
51
-
52
- snr_chunk = []
53
-
54
- for i in range(n_chunks):
55
- start = i * self.hop_size
56
-
57
- if n_samples - start < self.window_size:
58
- continue
59
-
60
- end = start + self.window_size
61
-
62
- try:
63
- chunk_snr = self.func(preds[..., start:end], target[..., start:end])
64
-
65
- # print(preds.shape, chunk_snr.shape)
66
-
67
- if torch.all(torch.isfinite(chunk_snr)):
68
- snr_chunk.append(chunk_snr)
69
- except _LinAlgError:
70
- pass
71
-
72
- snr_chunk = torch.stack(snr_chunk, dim=-1)
73
- snr_batch, _ = torch.nanmedian(snr_chunk, dim=-1)
74
-
75
- self.sum_snr += snr_batch.sum()
76
- self.total += snr_batch.numel()
77
-
78
- def compute(self) -> Any:
79
- return self.sum_snr / self.total
80
-
81
-
82
- class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio):
83
- def __init__(
84
- self, window_size: int, hop_size: int = None, zero_mean: bool = False
85
- ) -> None:
86
- super().__init__(
87
- func=tmF.signal_noise_ratio,
88
- window_size=window_size,
89
- hop_size=hop_size,
90
- zero_mean=zero_mean,
91
- )
92
-
93
-
94
- class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio):
95
- def __init__(
96
- self, window_size: int, hop_size: int = None, zero_mean: bool = False
97
- ) -> None:
98
- super().__init__(
99
- func=tmF.scale_invariant_signal_noise_ratio,
100
- window_size=window_size,
101
- hop_size=hop_size,
102
- zero_mean=zero_mean,
103
- )
104
-
105
-
106
- class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio):
107
- def __init__(
108
- self, window_size: int, hop_size: int = None, zero_mean: bool = False
109
- ) -> None:
110
- super().__init__(
111
- func=tmF.signal_distortion_ratio,
112
- window_size=window_size,
113
- hop_size=hop_size,
114
- zero_mean=zero_mean,
115
- )
116
-
117
-
118
- class ChunkMedianScaleInvariantSignalDistortionRatio(BaseChunkMedianSignalRatio):
119
- def __init__(
120
- self, window_size: int, hop_size: int = None, zero_mean: bool = False
121
- ) -> None:
122
- super().__init__(
123
- func=tmF.scale_invariant_signal_distortion_ratio,
124
- window_size=window_size,
125
- hop_size=hop_size,
126
- zero_mean=zero_mean,
127
- )
 
1
+ from typing import Any, Callable
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torchmetrics as tm
6
+ from torch._C import _LinAlgError
7
+ from torchmetrics import functional as tmF
8
+
9
+
10
+ class SafeSignalDistortionRatio(tm.SignalDistortionRatio):
11
+ def __init__(self, **kwargs) -> None:
12
+ super().__init__(**kwargs)
13
+
14
+ def update(self, *args, **kwargs) -> Any:
15
+ try:
16
+ super().update(*args, **kwargs)
17
+ except Exception:
18
+ pass
19
+
20
+ def compute(self) -> Any:
21
+ if self.total == 0:
22
+ return torch.tensor(torch.nan)
23
+ return super().compute()
24
+
25
+
26
+ class BaseChunkMedianSignalRatio(tm.Metric):
27
+ def __init__(
28
+ self,
29
+ func: Callable,
30
+ window_size: int,
31
+ hop_size: int = None,
32
+ zero_mean: bool = False,
33
+ ) -> None:
34
+ super().__init__()
35
+
36
+ self.func = func
37
+ self.window_size = window_size
38
+ if hop_size is None:
39
+ hop_size = window_size
40
+ self.hop_size = hop_size
41
+
42
+ self.add_state("sum_snr", default=torch.tensor(0.0), dist_reduce_fx="sum")
43
+ self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
44
+
45
+ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
46
+
47
+ n_samples = target.shape[-1]
48
+
49
+ n_chunks = int(np.ceil((n_samples - self.window_size) / self.hop_size) + 1)
50
+
51
+ snr_chunk = []
52
+
53
+ for i in range(n_chunks):
54
+ start = i * self.hop_size
55
+
56
+ if n_samples - start < self.window_size:
57
+ continue
58
+
59
+ end = start + self.window_size
60
+
61
+ try:
62
+ chunk_snr = self.func(preds[..., start:end], target[..., start:end])
63
+
64
+ if torch.all(torch.isfinite(chunk_snr)):
65
+ snr_chunk.append(chunk_snr)
66
+ except _LinAlgError:
67
+ pass
68
+
69
+ snr_chunk = torch.stack(snr_chunk, dim=-1)
70
+ snr_batch, _ = torch.nanmedian(snr_chunk, dim=-1)
71
+
72
+ self.sum_snr += snr_batch.sum()
73
+ self.total += snr_batch.numel()
74
+
75
+ def compute(self) -> Any:
76
+ return self.sum_snr / self.total
77
+
78
+
79
+ class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio):
80
+ def __init__(
81
+ self, window_size: int, hop_size: int = None, zero_mean: bool = False
82
+ ) -> None:
83
+ super().__init__(
84
+ func=tmF.signal_noise_ratio,
85
+ window_size=window_size,
86
+ hop_size=hop_size,
87
+ zero_mean=zero_mean,
88
+ )
89
+
90
+
91
+ class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio):
92
+ def __init__(
93
+ self, window_size: int, hop_size: int = None, zero_mean: bool = False
94
+ ) -> None:
95
+ super().__init__(
96
+ func=tmF.scale_invariant_signal_noise_ratio,
97
+ window_size=window_size,
98
+ hop_size=hop_size,
99
+ zero_mean=zero_mean,
100
+ )
101
+
102
+
103
+ class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio):
104
+ def __init__(
105
+ self, window_size: int, hop_size: int = None, zero_mean: bool = False
106
+ ) -> None:
107
+ super().__init__(
108
+ func=tmF.signal_distortion_ratio,
109
+ window_size=window_size,
110
+ hop_size=hop_size,
111
+ zero_mean=zero_mean,
112
+ )
113
+
114
+
115
+ class ChunkMedianScaleInvariantSignalDistortionRatio(BaseChunkMedianSignalRatio):
116
+ def __init__(
117
+ self, window_size: int, hop_size: int = None, zero_mean: bool = False
118
+ ) -> None:
119
+ super().__init__(
120
+ func=tmF.scale_invariant_signal_distortion_ratio,
121
+ window_size=window_size,
122
+ hop_size=hop_size,
123
+ zero_mean=zero_mean,
124
+ )
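Usage follows the standard torchmetrics update/compute cycle. A sketch assuming the class is importable and using an illustrative 1-second window at 44.1 kHz:

import torch

metric = ChunkMedianSignalNoiseRatio(window_size=44100)
for _ in range(4):                                   # e.g. one update per batch
    target = torch.randn(2, 10 * 44100)              # (batch, time)
    preds = target + 0.05 * torch.randn_like(target)
    metric.update(preds, target)                     # median SNR over 1 s chunks
print(metric.compute())                              # mean of per-item medians, in dB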
 
mvsepless/models/bandit/core/model/__init__.py CHANGED
@@ -1,3 +1,3 @@
-from .bsrnn.wrapper import (
-    MultiMaskMultiSourceBandSplitRNNSimple,
-)

+from .bsrnn.wrapper import (
+    MultiMaskMultiSourceBandSplitRNNSimple,
+)
mvsepless/models/bandit/core/model/_spectral.py CHANGED
@@ -1,54 +1,54 @@
-from typing import Dict, Optional
-
-import torch
-import torchaudio as ta
-from torch import nn
-
-
-class _SpectralComponent(nn.Module):
-    def __init__(
-        self,
-        n_fft: int = 2048,
-        win_length: Optional[int] = 2048,
-        hop_length: int = 512,
-        window_fn: str = "hann_window",
-        wkwargs: Optional[Dict] = None,
-        power: Optional[int] = None,
-        center: bool = True,
-        normalized: bool = True,
-        pad_mode: str = "constant",
-        onesided: bool = True,
-        **kwargs,
-    ) -> None:
-        super().__init__()
-
-        assert power is None
-
-        window_fn = torch.__dict__[window_fn]
-
-        self.stft = ta.transforms.Spectrogram(
-            n_fft=n_fft,
-            win_length=win_length,
-            hop_length=hop_length,
-            pad_mode=pad_mode,
-            pad=0,
-            window_fn=window_fn,
-            wkwargs=wkwargs,
-            power=power,
-            normalized=normalized,
-            center=center,
-            onesided=onesided,
-        )
-
-        self.istft = ta.transforms.InverseSpectrogram(
-            n_fft=n_fft,
-            win_length=win_length,
-            hop_length=hop_length,
-            pad_mode=pad_mode,
-            pad=0,
-            window_fn=window_fn,
-            wkwargs=wkwargs,
-            normalized=normalized,
-            center=center,
-            onesided=onesided,
-        )

+from typing import Dict, Optional
+
+import torch
+import torchaudio as ta
+from torch import nn
+
+
+class _SpectralComponent(nn.Module):
+    def __init__(
+        self,
+        n_fft: int = 2048,
+        win_length: Optional[int] = 2048,
+        hop_length: int = 512,
+        window_fn: str = "hann_window",
+        wkwargs: Optional[Dict] = None,
+        power: Optional[int] = None,
+        center: bool = True,
+        normalized: bool = True,
+        pad_mode: str = "constant",
+        onesided: bool = True,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+
+        assert power is None
+
+        window_fn = torch.__dict__[window_fn]
+
+        self.stft = ta.transforms.Spectrogram(
+            n_fft=n_fft,
+            win_length=win_length,
+            hop_length=hop_length,
+            pad_mode=pad_mode,
+            pad=0,
+            window_fn=window_fn,
+            wkwargs=wkwargs,
+            power=power,
+            normalized=normalized,
+            center=center,
+            onesided=onesided,
+        )
+
+        self.istft = ta.transforms.InverseSpectrogram(
+            n_fft=n_fft,
+            win_length=win_length,
+            hop_length=hop_length,
+            pad_mode=pad_mode,
+            pad=0,
+            window_fn=window_fn,
+            wkwargs=wkwargs,
+            normalized=normalized,
+            center=center,
+            onesided=onesided,
+        )
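A minimal round-trip sketch for the component above: with power=None the forward transform returns a complex spectrogram, and the inverse reconstructs the waveform to within numerical error (shapes below are illustrative).

import torch

spec = _SpectralComponent(n_fft=2048, hop_length=512)
x = torch.randn(1, 2, 44100)                  # (batch, channel, time)
X = spec.stft(x)                              # complex (batch, channel, freq, frame)
y = spec.istft(X, length=x.shape[-1])         # trim padding back to the input length
print(torch.allclose(x, y, atol=1e-4))        # True: near-perfect reconstruction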
mvsepless/models/bandit/core/model/bsrnn/__init__.py CHANGED
@@ -1,23 +1,23 @@
-from abc import ABC
-from typing import Iterable, Mapping, Union
-
-from torch import nn
-
-from .bandsplit import BandSplitModule
-from .tfmodel import (
-    SeqBandModellingModule,
-    TransformerTimeFreqModule,
-)
-
-
-class BandsplitCoreBase(nn.Module, ABC):
-    band_split: nn.Module
-    tf_model: nn.Module
-    mask_estim: Union[nn.Module, Mapping[str, nn.Module], Iterable[nn.Module]]
-
-    def __init__(self) -> None:
-        super().__init__()
-
-    @staticmethod
-    def mask(x, m):
-        return x * m

+from abc import ABC
+from typing import Iterable, Mapping, Union
+
+from torch import nn
+
+from .bandsplit import BandSplitModule
+from .tfmodel import (
+    SeqBandModellingModule,
+    TransformerTimeFreqModule,
+)
+
+
+class BandsplitCoreBase(nn.Module, ABC):
+    band_split: nn.Module
+    tf_model: nn.Module
+    mask_estim: Union[nn.Module, Mapping[str, nn.Module], Iterable[nn.Module]]
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    @staticmethod
+    def mask(x, m):
+        return x * m
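The mask op above is a plain elementwise product; with a complex mask it rotates phase as well as scaling magnitude. A tiny sketch with placeholder tensors:

import torch

x = torch.randn(2, 1, 1025, 100, dtype=torch.cfloat)   # mixture spectrogram
m = torch.randn_like(x)                                 # estimated complex mask
s = BandsplitCoreBase.mask(x, m)                        # masked source spectrogram
print(s.shape, s.dtype)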
mvsepless/models/bandit/core/model/bsrnn/bandsplit.py CHANGED
@@ -1,135 +1,119 @@
1
- from typing import List, Tuple
2
-
3
- import torch
4
- from torch import nn
5
-
6
- from .utils import (
7
- band_widths_from_specs,
8
- check_no_gap,
9
- check_no_overlap,
10
- check_nonzero_bandwidth,
11
- )
12
-
13
-
14
- class NormFC(nn.Module):
15
- def __init__(
16
- self,
17
- emb_dim: int,
18
- bandwidth: int,
19
- in_channel: int,
20
- normalize_channel_independently: bool = False,
21
- treat_channel_as_feature: bool = True,
22
- ) -> None:
23
- super().__init__()
24
-
25
- self.treat_channel_as_feature = treat_channel_as_feature
26
-
27
- if normalize_channel_independently:
28
- raise NotImplementedError
29
-
30
- reim = 2
31
-
32
- self.norm = nn.LayerNorm(in_channel * bandwidth * reim)
33
-
34
- fc_in = bandwidth * reim
35
-
36
- if treat_channel_as_feature:
37
- fc_in *= in_channel
38
- else:
39
- assert emb_dim % in_channel == 0
40
- emb_dim = emb_dim // in_channel
41
-
42
- self.fc = nn.Linear(fc_in, emb_dim)
43
-
44
- def forward(self, xb):
45
- # xb = (batch, n_time, in_chan, reim * band_width)
46
-
47
- batch, n_time, in_chan, ribw = xb.shape
48
- xb = self.norm(xb.reshape(batch, n_time, in_chan * ribw))
49
- # (batch, n_time, in_chan * reim * band_width)
50
-
51
- if not self.treat_channel_as_feature:
52
- xb = xb.reshape(batch, n_time, in_chan, ribw)
53
- # (batch, n_time, in_chan, reim * band_width)
54
-
55
- zb = self.fc(xb)
56
- # (batch, n_time, emb_dim)
57
- # OR
58
- # (batch, n_time, in_chan, emb_dim_per_chan)
59
-
60
- if not self.treat_channel_as_feature:
61
- batch, n_time, in_chan, emb_dim_per_chan = zb.shape
62
- # (batch, n_time, in_chan, emb_dim_per_chan)
63
- zb = zb.reshape((batch, n_time, in_chan * emb_dim_per_chan))
64
-
65
- return zb # (batch, n_time, emb_dim)
66
-
67
-
68
- class BandSplitModule(nn.Module):
69
- def __init__(
70
- self,
71
- band_specs: List[Tuple[float, float]],
72
- emb_dim: int,
73
- in_channel: int,
74
- require_no_overlap: bool = False,
75
- require_no_gap: bool = True,
76
- normalize_channel_independently: bool = False,
77
- treat_channel_as_feature: bool = True,
78
- ) -> None:
79
- super().__init__()
80
-
81
- check_nonzero_bandwidth(band_specs)
82
-
83
- if require_no_gap:
84
- check_no_gap(band_specs)
85
-
86
- if require_no_overlap:
87
- check_no_overlap(band_specs)
88
-
89
- self.band_specs = band_specs
90
- # list of [fstart, fend) in index.
91
- # Note that fend is exclusive.
92
- self.band_widths = band_widths_from_specs(band_specs)
93
- self.n_bands = len(band_specs)
94
- self.emb_dim = emb_dim
95
-
96
- self.norm_fc_modules = nn.ModuleList(
97
- [ # type: ignore
98
- (
99
- NormFC(
100
- emb_dim=emb_dim,
101
- bandwidth=bw,
102
- in_channel=in_channel,
103
- normalize_channel_independently=normalize_channel_independently,
104
- treat_channel_as_feature=treat_channel_as_feature,
105
- )
106
- )
107
- for bw in self.band_widths
108
- ]
109
- )
110
-
111
- def forward(self, x: torch.Tensor):
112
- # x = complex spectrogram (batch, in_chan, n_freq, n_time)
113
-
114
- batch, in_chan, _, n_time = x.shape
115
-
116
- z = torch.zeros(
117
- size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
118
- )
119
-
120
- xr = torch.view_as_real(x) # batch, in_chan, n_freq, n_time, 2
121
- xr = torch.permute(xr, (0, 3, 1, 4, 2)) # batch, n_time, in_chan, 2, n_freq
122
- batch, n_time, in_chan, reim, band_width = xr.shape
123
- for i, nfm in enumerate(self.norm_fc_modules):
124
- # print(f"bandsplit/band{i:02d}")
125
- fstart, fend = self.band_specs[i]
126
- xb = xr[..., fstart:fend]
127
- # (batch, n_time, in_chan, reim, band_width)
128
- xb = torch.reshape(xb, (batch, n_time, in_chan, -1))
129
- # (batch, n_time, in_chan, reim * band_width)
130
- # z.append(nfm(xb)) # (batch, n_time, emb_dim)
131
- z[:, i, :, :] = nfm(xb.contiguous())
132
-
133
- # z = torch.stack(z, dim=1)
134
-
135
- return z
 
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from .utils import (
7
+ band_widths_from_specs,
8
+ check_no_gap,
9
+ check_no_overlap,
10
+ check_nonzero_bandwidth,
11
+ )
12
+
13
+
14
+ class NormFC(nn.Module):
15
+ def __init__(
16
+ self,
17
+ emb_dim: int,
18
+ bandwidth: int,
19
+ in_channel: int,
20
+ normalize_channel_independently: bool = False,
21
+ treat_channel_as_feature: bool = True,
22
+ ) -> None:
23
+ super().__init__()
24
+
25
+ self.treat_channel_as_feature = treat_channel_as_feature
26
+
27
+ if normalize_channel_independently:
28
+ raise NotImplementedError
29
+
30
+ reim = 2
31
+
32
+ self.norm = nn.LayerNorm(in_channel * bandwidth * reim)
33
+
34
+ fc_in = bandwidth * reim
35
+
36
+ if treat_channel_as_feature:
37
+ fc_in *= in_channel
38
+ else:
39
+ assert emb_dim % in_channel == 0
40
+ emb_dim = emb_dim // in_channel
41
+
42
+ self.fc = nn.Linear(fc_in, emb_dim)
43
+
44
+ def forward(self, xb):
45
+
46
+ batch, n_time, in_chan, ribw = xb.shape
47
+ xb = self.norm(xb.reshape(batch, n_time, in_chan * ribw))
48
+
49
+ if not self.treat_channel_as_feature:
50
+ xb = xb.reshape(batch, n_time, in_chan, ribw)
51
+
52
+ zb = self.fc(xb)
53
+
54
+ if not self.treat_channel_as_feature:
55
+ batch, n_time, in_chan, emb_dim_per_chan = zb.shape
56
+ zb = zb.reshape((batch, n_time, in_chan * emb_dim_per_chan))
57
+
58
+ return zb
59
+
60
+
61
+ class BandSplitModule(nn.Module):
62
+ def __init__(
63
+ self,
64
+ band_specs: List[Tuple[float, float]],
65
+ emb_dim: int,
66
+ in_channel: int,
67
+ require_no_overlap: bool = False,
68
+ require_no_gap: bool = True,
69
+ normalize_channel_independently: bool = False,
70
+ treat_channel_as_feature: bool = True,
71
+ ) -> None:
72
+ super().__init__()
73
+
74
+ check_nonzero_bandwidth(band_specs)
75
+
76
+ if require_no_gap:
77
+ check_no_gap(band_specs)
78
+
79
+ if require_no_overlap:
80
+ check_no_overlap(band_specs)
81
+
82
+ self.band_specs = band_specs
83
+ self.band_widths = band_widths_from_specs(band_specs)
84
+ self.n_bands = len(band_specs)
85
+ self.emb_dim = emb_dim
86
+
87
+ self.norm_fc_modules = nn.ModuleList(
88
+ [ # type: ignore
89
+ (
90
+ NormFC(
91
+ emb_dim=emb_dim,
92
+ bandwidth=bw,
93
+ in_channel=in_channel,
94
+ normalize_channel_independently=normalize_channel_independently,
95
+ treat_channel_as_feature=treat_channel_as_feature,
96
+ )
97
+ )
98
+ for bw in self.band_widths
99
+ ]
100
+ )
101
+
102
+ def forward(self, x: torch.Tensor):
103
+
104
+ batch, in_chan, _, n_time = x.shape
105
+
106
+ z = torch.zeros(
107
+ size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
108
+ )
109
+
110
+ xr = torch.view_as_real(x)
111
+ xr = torch.permute(xr, (0, 3, 1, 4, 2))
112
+ batch, n_time, in_chan, reim, band_width = xr.shape
113
+ for i, nfm in enumerate(self.norm_fc_modules):
114
+ fstart, fend = self.band_specs[i]
115
+ xb = xr[..., fstart:fend]
116
+ xb = torch.reshape(xb, (batch, n_time, in_chan, -1))
117
+ z[:, i, :, :] = nfm(xb.contiguous())
118
+
119
+ return z
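A hypothetical shape walkthrough for BandSplitModule (the band edges below are illustrative): three contiguous bands covering the 1025 bins of a 2048-point STFT, with stereo input.

import torch

bands = [(0, 256), (256, 512), (512, 1025)]             # [fstart, fend) bin ranges
module = BandSplitModule(band_specs=bands, emb_dim=128, in_channel=2)
x = torch.randn(4, 2, 1025, 100, dtype=torch.cfloat)    # (batch, chan, freq, time)
z = module(x)
print(z.shape)                                          # torch.Size([4, 3, 100, 128])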
 
mvsepless/models/bandit/core/model/bsrnn/core.py CHANGED
@@ -1,651 +1,619 @@
1
- from typing import Dict, List, Optional, Tuple
2
-
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
-
7
- from . import BandsplitCoreBase
8
- from .bandsplit import BandSplitModule
9
- from .maskestim import (
10
- MaskEstimationModule,
11
- OverlappingMaskEstimationModule,
12
- )
13
- from .tfmodel import (
14
- ConvolutionalTimeFreqModule,
15
- SeqBandModellingModule,
16
- TransformerTimeFreqModule,
17
- )
18
-
19
-
20
- class MultiMaskBandSplitCoreBase(BandsplitCoreBase):
21
- def __init__(self) -> None:
22
- super().__init__()
23
-
24
- def forward(self, x, cond=None, compute_residual: bool = True):
25
- # x = complex spectrogram (batch, in_chan, n_freq, n_time)
26
- # print(x.shape)
27
- batch, in_chan, n_freq, n_time = x.shape
28
- x = torch.reshape(x, (-1, 1, n_freq, n_time))
29
-
30
- z = self.band_split(x) # (batch, emb_dim, n_band, n_time)
31
-
32
- # if torch.any(torch.isnan(z)):
33
- # raise ValueError("z nan")
34
-
35
- # print(z)
36
- q = self.tf_model(z) # (batch, emb_dim, n_band, n_time)
37
- # print(q)
38
-
39
- # if torch.any(torch.isnan(q)):
40
- # raise ValueError("q nan")
41
-
42
- out = {}
43
-
44
- for stem, mem in self.mask_estim.items():
45
- m = mem(q, cond=cond)
46
-
47
- # if torch.any(torch.isnan(m)):
48
- # raise ValueError("m nan", stem)
49
-
50
- s = self.mask(x, m)
51
- s = torch.reshape(s, (batch, in_chan, n_freq, n_time))
52
- out[stem] = s
53
-
54
- return {"spectrogram": out}
55
-
56
- def instantiate_mask_estim(
57
- self,
58
- in_channel: int,
59
- stems: List[str],
60
- band_specs: List[Tuple[float, float]],
61
- emb_dim: int,
62
- mlp_dim: int,
63
- cond_dim: int,
64
- hidden_activation: str,
65
- hidden_activation_kwargs: Optional[Dict] = None,
66
- complex_mask: bool = True,
67
- overlapping_band: bool = False,
68
- freq_weights: Optional[List[torch.Tensor]] = None,
69
- n_freq: Optional[int] = None,
70
- use_freq_weights: bool = True,
71
- mult_add_mask: bool = False,
72
- ):
73
- if hidden_activation_kwargs is None:
74
- hidden_activation_kwargs = {}
75
-
76
- if "mne:+" in stems:
77
- stems = [s for s in stems if s != "mne:+"]
78
-
79
- if overlapping_band:
80
- assert freq_weights is not None
81
- assert n_freq is not None
82
-
83
- if mult_add_mask:
84
-
85
- self.mask_estim = nn.ModuleDict(
86
- {
87
- stem: MultAddMaskEstimationModule(
88
- band_specs=band_specs,
89
- freq_weights=freq_weights,
90
- n_freq=n_freq,
91
- emb_dim=emb_dim,
92
- mlp_dim=mlp_dim,
93
- in_channel=in_channel,
94
- hidden_activation=hidden_activation,
95
- hidden_activation_kwargs=hidden_activation_kwargs,
96
- complex_mask=complex_mask,
97
-                             use_freq_weights=use_freq_weights,
-                         )
-                         for stem in stems
-                     }
-                 )
-             else:
-                 self.mask_estim = nn.ModuleDict(
-                     {
-                         stem: OverlappingMaskEstimationModule(
-                             band_specs=band_specs,
-                             freq_weights=freq_weights,
-                             n_freq=n_freq,
-                             emb_dim=emb_dim,
-                             mlp_dim=mlp_dim,
-                             in_channel=in_channel,
-                             hidden_activation=hidden_activation,
-                             hidden_activation_kwargs=hidden_activation_kwargs,
-                             complex_mask=complex_mask,
-                             use_freq_weights=use_freq_weights,
-                         )
-                         for stem in stems
-                     }
-                 )
-         else:
-             self.mask_estim = nn.ModuleDict(
-                 {
-                     stem: MaskEstimationModule(
-                         band_specs=band_specs,
-                         emb_dim=emb_dim,
-                         mlp_dim=mlp_dim,
-                         cond_dim=cond_dim,
-                         in_channel=in_channel,
-                         hidden_activation=hidden_activation,
-                         hidden_activation_kwargs=hidden_activation_kwargs,
-                         complex_mask=complex_mask,
-                     )
-                     for stem in stems
-                 }
-             )
-
-     def instantiate_bandsplit(
-         self,
-         in_channel: int,
-         band_specs: List[Tuple[float, float]],
-         require_no_overlap: bool = False,
-         require_no_gap: bool = True,
-         normalize_channel_independently: bool = False,
-         treat_channel_as_feature: bool = True,
-         emb_dim: int = 128,
-     ):
-         self.band_split = BandSplitModule(
-             in_channel=in_channel,
-             band_specs=band_specs,
-             require_no_overlap=require_no_overlap,
-             require_no_gap=require_no_gap,
-             normalize_channel_independently=normalize_channel_independently,
-             treat_channel_as_feature=treat_channel_as_feature,
-             emb_dim=emb_dim,
-         )
-
-
- class SingleMaskBandsplitCoreBase(BandsplitCoreBase):
-     def __init__(self, **kwargs) -> None:
-         super().__init__()
-
-     def forward(self, x):
-         # x = complex spectrogram (batch, in_chan, n_freq, n_time)
-         z = self.band_split(x)  # (batch, emb_dim, n_band, n_time)
-         q = self.tf_model(z)  # (batch, emb_dim, n_band, n_time)
-         m = self.mask_estim(q)  # (batch, in_chan, n_freq, n_time)
-
-         s = self.mask(x, m)
-
-         return s
-
-
- class SingleMaskBandsplitCoreRNN(
-     SingleMaskBandsplitCoreBase,
- ):
-     def __init__(
-         self,
-         in_channel: int,
-         band_specs: List[Tuple[float, float]],
-         require_no_overlap: bool = False,
-         require_no_gap: bool = True,
-         normalize_channel_independently: bool = False,
-         treat_channel_as_feature: bool = True,
-         n_sqm_modules: int = 12,
-         emb_dim: int = 128,
-         rnn_dim: int = 256,
-         bidirectional: bool = True,
-         rnn_type: str = "LSTM",
-         mlp_dim: int = 512,
-         hidden_activation: str = "Tanh",
-         hidden_activation_kwargs: Optional[Dict] = None,
-         complex_mask: bool = True,
-     ) -> None:
-         super().__init__()
-         self.band_split = BandSplitModule(
-             in_channel=in_channel,
-             band_specs=band_specs,
-             require_no_overlap=require_no_overlap,
-             require_no_gap=require_no_gap,
-             normalize_channel_independently=normalize_channel_independently,
-             treat_channel_as_feature=treat_channel_as_feature,
-             emb_dim=emb_dim,
-         )
-         self.tf_model = SeqBandModellingModule(
-             n_modules=n_sqm_modules,
-             emb_dim=emb_dim,
-             rnn_dim=rnn_dim,
-             bidirectional=bidirectional,
-             rnn_type=rnn_type,
-         )
-         self.mask_estim = MaskEstimationModule(
-             in_channel=in_channel,
-             band_specs=band_specs,
-             emb_dim=emb_dim,
-             mlp_dim=mlp_dim,
-             hidden_activation=hidden_activation,
-             hidden_activation_kwargs=hidden_activation_kwargs,
-             complex_mask=complex_mask,
-         )
-
-
- class SingleMaskBandsplitCoreTransformer(
-     SingleMaskBandsplitCoreBase,
- ):
-     def __init__(
-         self,
-         in_channel: int,
-         band_specs: List[Tuple[float, float]],
-         require_no_overlap: bool = False,
-         require_no_gap: bool = True,
-         normalize_channel_independently: bool = False,
-         treat_channel_as_feature: bool = True,
-         n_sqm_modules: int = 12,
-         emb_dim: int = 128,
-         rnn_dim: int = 256,
-         bidirectional: bool = True,
-         tf_dropout: float = 0.0,
-         mlp_dim: int = 512,
-         hidden_activation: str = "Tanh",
-         hidden_activation_kwargs: Optional[Dict] = None,
-         complex_mask: bool = True,
-     ) -> None:
-         super().__init__()
-         self.band_split = BandSplitModule(
-             in_channel=in_channel,
-             band_specs=band_specs,
-             require_no_overlap=require_no_overlap,
-             require_no_gap=require_no_gap,
-             normalize_channel_independently=normalize_channel_independently,
-             treat_channel_as_feature=treat_channel_as_feature,
-             emb_dim=emb_dim,
-         )
-         self.tf_model = TransformerTimeFreqModule(
-             n_modules=n_sqm_modules,
-             emb_dim=emb_dim,
-             rnn_dim=rnn_dim,
-             bidirectional=bidirectional,
-             dropout=tf_dropout,
-         )
-         self.mask_estim = MaskEstimationModule(
-             in_channel=in_channel,
-             band_specs=band_specs,
-             emb_dim=emb_dim,
-             mlp_dim=mlp_dim,
-             hidden_activation=hidden_activation,
-             hidden_activation_kwargs=hidden_activation_kwargs,
-             complex_mask=complex_mask,
-         )
-
-
- class MultiSourceMultiMaskBandSplitCoreRNN(MultiMaskBandSplitCoreBase):
-     def __init__(
-         self,
-         in_channel: int,
-         stems: List[str],
-         band_specs: List[Tuple[float, float]],
-         require_no_overlap: bool = False,
-         require_no_gap: bool = True,
-         normalize_channel_independently: bool = False,
-         treat_channel_as_feature: bool = True,
-         n_sqm_modules: int = 12,
-         emb_dim: int = 128,
-         rnn_dim: int = 256,
-         bidirectional: bool = True,
-         rnn_type: str = "LSTM",
-         mlp_dim: int = 512,
-         cond_dim: int = 0,
-         hidden_activation: str = "Tanh",
-         hidden_activation_kwargs: Optional[Dict] = None,
-         complex_mask: bool = True,
-         overlapping_band: bool = False,
-         freq_weights: Optional[List[torch.Tensor]] = None,
-         n_freq: Optional[int] = None,
-         use_freq_weights: bool = True,
-         mult_add_mask: bool = False,
-     ) -> None:
-
-         super().__init__()
-         self.instantiate_bandsplit(
-             in_channel=in_channel,
-             band_specs=band_specs,
-             require_no_overlap=require_no_overlap,
-             require_no_gap=require_no_gap,
-             normalize_channel_independently=normalize_channel_independently,
-             treat_channel_as_feature=treat_channel_as_feature,
-             emb_dim=emb_dim,
-         )
-
-         self.tf_model = SeqBandModellingModule(
-             n_modules=n_sqm_modules,
-             emb_dim=emb_dim,
-             rnn_dim=rnn_dim,
-             bidirectional=bidirectional,
-             rnn_type=rnn_type,
-         )
-
-         self.mult_add_mask = mult_add_mask
-
-         self.instantiate_mask_estim(
-             in_channel=in_channel,
-             stems=stems,
-             band_specs=band_specs,
-             emb_dim=emb_dim,
-             mlp_dim=mlp_dim,
-             cond_dim=cond_dim,
-             hidden_activation=hidden_activation,
-             hidden_activation_kwargs=hidden_activation_kwargs,
-             complex_mask=complex_mask,
-             overlapping_band=overlapping_band,
-             freq_weights=freq_weights,
-             n_freq=n_freq,
-             use_freq_weights=use_freq_weights,
-             mult_add_mask=mult_add_mask,
-         )
-
-     @staticmethod
-     def _mult_add_mask(x, m):
-
-         assert m.ndim == 5
-
-         mm = m[..., 0]
-         am = m[..., 1]
-
-         # print(mm.shape, am.shape, x.shape, m.shape)
-
-         return x * mm + am
-
-     def mask(self, x, m):
-         if self.mult_add_mask:
-
-             return self._mult_add_mask(x, m)
-         else:
-             return super().mask(x, m)
-
-
- class MultiSourceMultiMaskBandSplitCoreTransformer(
-     MultiMaskBandSplitCoreBase,
- ):
-     def __init__(
-         self,
-         in_channel: int,
-         stems: List[str],
-         band_specs: List[Tuple[float, float]],
-         require_no_overlap: bool = False,
-         require_no_gap: bool = True,
-         normalize_channel_independently: bool = False,
-         treat_channel_as_feature: bool = True,
-         n_sqm_modules: int = 12,
-         emb_dim: int = 128,
-         rnn_dim: int = 256,
-         bidirectional: bool = True,
-         tf_dropout: float = 0.0,
-         mlp_dim: int = 512,
-         hidden_activation: str = "Tanh",
-         hidden_activation_kwargs: Optional[Dict] = None,
-         complex_mask: bool = True,
-         overlapping_band: bool = False,
-         freq_weights: Optional[List[torch.Tensor]] = None,
-         n_freq: Optional[int] = None,
-         use_freq_weights: bool = True,
-         rnn_type: str = "LSTM",
-         cond_dim: int = 0,
-         mult_add_mask: bool = False,
-     ) -> None:
-         super().__init__()
-         self.instantiate_bandsplit(
-             in_channel=in_channel,
-             band_specs=band_specs,
-             require_no_overlap=require_no_overlap,
-             require_no_gap=require_no_gap,
-             normalize_channel_independently=normalize_channel_independently,
-             treat_channel_as_feature=treat_channel_as_feature,
-             emb_dim=emb_dim,
-         )
-         self.tf_model = TransformerTimeFreqModule(
-             n_modules=n_sqm_modules,
-             emb_dim=emb_dim,
-             rnn_dim=rnn_dim,
-             bidirectional=bidirectional,
-             dropout=tf_dropout,
-         )
-
-         self.instantiate_mask_estim(
-             in_channel=in_channel,
-             stems=stems,
-             band_specs=band_specs,
-             emb_dim=emb_dim,
-             mlp_dim=mlp_dim,
-             cond_dim=cond_dim,
-             hidden_activation=hidden_activation,
-             hidden_activation_kwargs=hidden_activation_kwargs,
-             complex_mask=complex_mask,
-             overlapping_band=overlapping_band,
-             freq_weights=freq_weights,
-             n_freq=n_freq,
-             use_freq_weights=use_freq_weights,
-             mult_add_mask=mult_add_mask,
-         )
-
-
- class MultiSourceMultiMaskBandSplitCoreConv(
-     MultiMaskBandSplitCoreBase,
- ):
-     def __init__(
-         self,
-         in_channel: int,
-         stems: List[str],
-         band_specs: List[Tuple[float, float]],
-         require_no_overlap: bool = False,
-         require_no_gap: bool = True,
-         normalize_channel_independently: bool = False,
-         treat_channel_as_feature: bool = True,
-         n_sqm_modules: int = 12,
-         emb_dim: int = 128,
-         rnn_dim: int = 256,
-         bidirectional: bool = True,
-         tf_dropout: float = 0.0,
-         mlp_dim: int = 512,
-         hidden_activation: str = "Tanh",
-         hidden_activation_kwargs: Optional[Dict] = None,
-         complex_mask: bool = True,
-         overlapping_band: bool = False,
-         freq_weights: Optional[List[torch.Tensor]] = None,
-         n_freq: Optional[int] = None,
-         use_freq_weights: bool = True,
-         rnn_type: str = "LSTM",
-         cond_dim: int = 0,
-         mult_add_mask: bool = False,
-     ) -> None:
-         super().__init__()
-         self.instantiate_bandsplit(
-             in_channel=in_channel,
-             band_specs=band_specs,
-             require_no_overlap=require_no_overlap,
-             require_no_gap=require_no_gap,
-             normalize_channel_independently=normalize_channel_independently,
-             treat_channel_as_feature=treat_channel_as_feature,
-             emb_dim=emb_dim,
-         )
-         self.tf_model = ConvolutionalTimeFreqModule(
-             n_modules=n_sqm_modules,
-             emb_dim=emb_dim,
-             rnn_dim=rnn_dim,
-             bidirectional=bidirectional,
-             dropout=tf_dropout,
-         )
-
-         self.instantiate_mask_estim(
-             in_channel=in_channel,
-             stems=stems,
-             band_specs=band_specs,
-             emb_dim=emb_dim,
-             mlp_dim=mlp_dim,
-             cond_dim=cond_dim,
-             hidden_activation=hidden_activation,
-             hidden_activation_kwargs=hidden_activation_kwargs,
-             complex_mask=complex_mask,
-             overlapping_band=overlapping_band,
-             freq_weights=freq_weights,
-             n_freq=n_freq,
-             use_freq_weights=use_freq_weights,
-             mult_add_mask=mult_add_mask,
-         )
-
-
- class PatchingMaskBandsplitCoreBase(MultiMaskBandSplitCoreBase):
-     def __init__(self) -> None:
-         super().__init__()
-
-     def mask(self, x, m):
-         # x.shape = (batch, n_channel, n_freq, n_time)
-         # m.shape = (kernel_freq, kernel_time, batch, n_channel, n_freq, n_time)
-
-         _, n_channel, kernel_freq, kernel_time, n_freq, n_time = m.shape
-         padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2)
-
-         xf = F.unfold(
-             x,
-             kernel_size=(kernel_freq, kernel_time),
-             padding=padding,
-             stride=(1, 1),
-         )
-
-         xf = xf.view(
-             -1,
-             n_channel,
-             kernel_freq,
-             kernel_time,
-             n_freq,
-             n_time,
-         )
-
-         sf = xf * m
-
-         sf = sf.view(
-             -1,
-             n_channel * kernel_freq * kernel_time,
-             n_freq * n_time,
-         )
-
-         s = F.fold(
-             sf,
-             output_size=(n_freq, n_time),
-             kernel_size=(kernel_freq, kernel_time),
-             padding=padding,
-             stride=(1, 1),
-         ).view(
-             -1,
-             n_channel,
-             n_freq,
-             n_time,
-         )
-
-         return s
-
-     def old_mask(self, x, m):
-         # x.shape = (batch, n_channel, n_freq, n_time)
-         # m.shape = (kernel_freq, kernel_time, batch, n_channel, n_freq, n_time)
-
-         s = torch.zeros_like(x)
-
-         _, n_channel, n_freq, n_time = x.shape
-         kernel_freq, kernel_time, _, _, _, _ = m.shape
-
-         # print(x.shape, m.shape)
-
-         kernel_freq_half = (kernel_freq - 1) // 2
-         kernel_time_half = (kernel_time - 1) // 2
-
-         for ifreq in range(kernel_freq):
-             for itime in range(kernel_time):
-                 df, dt = kernel_freq_half - ifreq, kernel_time_half - itime
-                 x = x.roll(shifts=(df, dt), dims=(2, 3))
-
-                 # if df > 0:
-                 #     x[:, :, :df, :] = 0
-                 # elif df < 0:
-                 #     x[:, :, df:, :] = 0
-
-                 # if dt > 0:
-                 #     x[:, :, :, :dt] = 0
-                 # elif dt < 0:
-                 #     x[:, :, :, dt:] = 0
-
-                 fslice = slice(max(0, df), min(n_freq, n_freq + df))
-                 tslice = slice(max(0, dt), min(n_time, n_time + dt))
-
-                 s[:, :, fslice, tslice] += (
-                     x[:, :, fslice, tslice] * m[ifreq, itime, :, :, fslice, tslice]
-                 )
-
-         return s
-
-
- class MultiSourceMultiPatchingMaskBandSplitCoreRNN(PatchingMaskBandsplitCoreBase):
-     def __init__(
-         self,
-         in_channel: int,
-         stems: List[str],
-         band_specs: List[Tuple[float, float]],
-         mask_kernel_freq: int,
-         mask_kernel_time: int,
-         conv_kernel_freq: int,
-         conv_kernel_time: int,
-         kernel_norm_mlp_version: int,
-         require_no_overlap: bool = False,
-         require_no_gap: bool = True,
-         normalize_channel_independently: bool = False,
-         treat_channel_as_feature: bool = True,
-         n_sqm_modules: int = 12,
-         emb_dim: int = 128,
-         rnn_dim: int = 256,
-         bidirectional: bool = True,
-         rnn_type: str = "LSTM",
-         mlp_dim: int = 512,
-         hidden_activation: str = "Tanh",
-         hidden_activation_kwargs: Optional[Dict] = None,
-         complex_mask: bool = True,
-         overlapping_band: bool = False,
-         freq_weights: Optional[List[torch.Tensor]] = None,
-         n_freq: Optional[int] = None,
-     ) -> None:
-
-         super().__init__()
-         self.band_split = BandSplitModule(
-             in_channel=in_channel,
-             band_specs=band_specs,
-             require_no_overlap=require_no_overlap,
-             require_no_gap=require_no_gap,
-             normalize_channel_independently=normalize_channel_independently,
-             treat_channel_as_feature=treat_channel_as_feature,
-             emb_dim=emb_dim,
-         )
-
-         self.tf_model = SeqBandModellingModule(
-             n_modules=n_sqm_modules,
-             emb_dim=emb_dim,
-             rnn_dim=rnn_dim,
-             bidirectional=bidirectional,
-             rnn_type=rnn_type,
-         )
-
-         if hidden_activation_kwargs is None:
-             hidden_activation_kwargs = {}
-
-         if overlapping_band:
-             assert freq_weights is not None
-             assert n_freq is not None
-             self.mask_estim = nn.ModuleDict(
-                 {
-                     stem: PatchingMaskEstimationModule(
-                         band_specs=band_specs,
-                         freq_weights=freq_weights,
-                         n_freq=n_freq,
-                         emb_dim=emb_dim,
-                         mlp_dim=mlp_dim,
-                         in_channel=in_channel,
-                         hidden_activation=hidden_activation,
-                         hidden_activation_kwargs=hidden_activation_kwargs,
-                         complex_mask=complex_mask,
-                         mask_kernel_freq=mask_kernel_freq,
-                         mask_kernel_time=mask_kernel_time,
-                         conv_kernel_freq=conv_kernel_freq,
-                         conv_kernel_time=conv_kernel_time,
-                         kernel_norm_mlp_version=kernel_norm_mlp_version,
-                     )
-                     for stem in stems
-                 }
-             )
-         else:
-             raise NotImplementedError
 
+ from typing import Dict, List, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from . import BandsplitCoreBase
+ from .bandsplit import BandSplitModule
+ from .maskestim import (
+     MaskEstimationModule,
+     OverlappingMaskEstimationModule,
+ )
+ from .tfmodel import (
+     ConvolutionalTimeFreqModule,
+     SeqBandModellingModule,
+     TransformerTimeFreqModule,
+ )
+
+
+ class MultiMaskBandSplitCoreBase(BandsplitCoreBase):
+     def __init__(self) -> None:
+         super().__init__()
+
+     def forward(self, x, cond=None, compute_residual: bool = True):
+         batch, in_chan, n_freq, n_time = x.shape
+         x = torch.reshape(x, (-1, 1, n_freq, n_time))
+
+         z = self.band_split(x)
+
+         q = self.tf_model(z)
+
+         out = {}
+
+         for stem, mem in self.mask_estim.items():
+             m = mem(q, cond=cond)
+
+             s = self.mask(x, m)
+             s = torch.reshape(s, (batch, in_chan, n_freq, n_time))
+             out[stem] = s
+
+         return {"spectrogram": out}
+
+     def instantiate_mask_estim(
+         self,
+         in_channel: int,
+         stems: List[str],
+         band_specs: List[Tuple[float, float]],
+         emb_dim: int,
+         mlp_dim: int,
+         cond_dim: int,
+         hidden_activation: str,
+         hidden_activation_kwargs: Optional[Dict] = None,
+         complex_mask: bool = True,
+         overlapping_band: bool = False,
+         freq_weights: Optional[List[torch.Tensor]] = None,
+         n_freq: Optional[int] = None,
+         use_freq_weights: bool = True,
+         mult_add_mask: bool = False,
+     ):
+         if hidden_activation_kwargs is None:
+             hidden_activation_kwargs = {}
+
+         if "mne:+" in stems:
+             stems = [s for s in stems if s != "mne:+"]
+
+         if overlapping_band:
+             assert freq_weights is not None
+             assert n_freq is not None
+
+             if mult_add_mask:
+
+                 self.mask_estim = nn.ModuleDict(
+                     {
+                         stem: MultAddMaskEstimationModule(
+                             band_specs=band_specs,
+                             freq_weights=freq_weights,
+                             n_freq=n_freq,
+                             emb_dim=emb_dim,
+                             mlp_dim=mlp_dim,
+                             in_channel=in_channel,
+                             hidden_activation=hidden_activation,
+                             hidden_activation_kwargs=hidden_activation_kwargs,
+                             complex_mask=complex_mask,
+                             use_freq_weights=use_freq_weights,
+                         )
+                         for stem in stems
+                     }
+                 )
+             else:
+                 self.mask_estim = nn.ModuleDict(
+                     {
+                         stem: OverlappingMaskEstimationModule(
+                             band_specs=band_specs,
+                             freq_weights=freq_weights,
+                             n_freq=n_freq,
+                             emb_dim=emb_dim,
+                             mlp_dim=mlp_dim,
+                             in_channel=in_channel,
+                             hidden_activation=hidden_activation,
+                             hidden_activation_kwargs=hidden_activation_kwargs,
+                             complex_mask=complex_mask,
+                             use_freq_weights=use_freq_weights,
+                         )
+                         for stem in stems
+                     }
+                 )
+         else:
+             self.mask_estim = nn.ModuleDict(
+                 {
+                     stem: MaskEstimationModule(
+                         band_specs=band_specs,
+                         emb_dim=emb_dim,
+                         mlp_dim=mlp_dim,
+                         cond_dim=cond_dim,
+                         in_channel=in_channel,
+                         hidden_activation=hidden_activation,
+                         hidden_activation_kwargs=hidden_activation_kwargs,
+                         complex_mask=complex_mask,
+                     )
+                     for stem in stems
+                 }
+             )
+
+     def instantiate_bandsplit(
+         self,
+         in_channel: int,
+         band_specs: List[Tuple[float, float]],
+         require_no_overlap: bool = False,
+         require_no_gap: bool = True,
+         normalize_channel_independently: bool = False,
+         treat_channel_as_feature: bool = True,
+         emb_dim: int = 128,
+     ):
+         self.band_split = BandSplitModule(
+             in_channel=in_channel,
+             band_specs=band_specs,
+             require_no_overlap=require_no_overlap,
+             require_no_gap=require_no_gap,
+             normalize_channel_independently=normalize_channel_independently,
+             treat_channel_as_feature=treat_channel_as_feature,
+             emb_dim=emb_dim,
+         )
+
+
+ class SingleMaskBandsplitCoreBase(BandsplitCoreBase):
+     def __init__(self, **kwargs) -> None:
+         super().__init__()
+
+     def forward(self, x):
+         z = self.band_split(x)
+         q = self.tf_model(z)
+         m = self.mask_estim(q)
+
+         s = self.mask(x, m)
+
+         return s
+
+
+ class SingleMaskBandsplitCoreRNN(
+     SingleMaskBandsplitCoreBase,
+ ):
+     def __init__(
+         self,
+         in_channel: int,
+         band_specs: List[Tuple[float, float]],
+         require_no_overlap: bool = False,
+         require_no_gap: bool = True,
+         normalize_channel_independently: bool = False,
+         treat_channel_as_feature: bool = True,
+         n_sqm_modules: int = 12,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         bidirectional: bool = True,
+         rnn_type: str = "LSTM",
+         mlp_dim: int = 512,
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Optional[Dict] = None,
+         complex_mask: bool = True,
+     ) -> None:
+         super().__init__()
+         self.band_split = BandSplitModule(
+             in_channel=in_channel,
+             band_specs=band_specs,
+             require_no_overlap=require_no_overlap,
+             require_no_gap=require_no_gap,
+             normalize_channel_independently=normalize_channel_independently,
+             treat_channel_as_feature=treat_channel_as_feature,
+             emb_dim=emb_dim,
+         )
+         self.tf_model = SeqBandModellingModule(
+             n_modules=n_sqm_modules,
+             emb_dim=emb_dim,
+             rnn_dim=rnn_dim,
+             bidirectional=bidirectional,
+             rnn_type=rnn_type,
+         )
+         self.mask_estim = MaskEstimationModule(
+             in_channel=in_channel,
+             band_specs=band_specs,
+             emb_dim=emb_dim,
+             mlp_dim=mlp_dim,
+             hidden_activation=hidden_activation,
+             hidden_activation_kwargs=hidden_activation_kwargs,
+             complex_mask=complex_mask,
+         )
+
+
+ class SingleMaskBandsplitCoreTransformer(
+     SingleMaskBandsplitCoreBase,
+ ):
+     def __init__(
+         self,
+         in_channel: int,
+         band_specs: List[Tuple[float, float]],
+         require_no_overlap: bool = False,
+         require_no_gap: bool = True,
+         normalize_channel_independently: bool = False,
+         treat_channel_as_feature: bool = True,
+         n_sqm_modules: int = 12,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         bidirectional: bool = True,
+         tf_dropout: float = 0.0,
+         mlp_dim: int = 512,
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Optional[Dict] = None,
+         complex_mask: bool = True,
+     ) -> None:
+         super().__init__()
+         self.band_split = BandSplitModule(
+             in_channel=in_channel,
+             band_specs=band_specs,
+             require_no_overlap=require_no_overlap,
+             require_no_gap=require_no_gap,
+             normalize_channel_independently=normalize_channel_independently,
+             treat_channel_as_feature=treat_channel_as_feature,
+             emb_dim=emb_dim,
+         )
+         self.tf_model = TransformerTimeFreqModule(
+             n_modules=n_sqm_modules,
+             emb_dim=emb_dim,
+             rnn_dim=rnn_dim,
+             bidirectional=bidirectional,
+             dropout=tf_dropout,
+         )
+         self.mask_estim = MaskEstimationModule(
+             in_channel=in_channel,
+             band_specs=band_specs,
+             emb_dim=emb_dim,
+             mlp_dim=mlp_dim,
+             hidden_activation=hidden_activation,
+             hidden_activation_kwargs=hidden_activation_kwargs,
+             complex_mask=complex_mask,
+         )
+
+
+ class MultiSourceMultiMaskBandSplitCoreRNN(MultiMaskBandSplitCoreBase):
+     def __init__(
+         self,
+         in_channel: int,
+         stems: List[str],
+         band_specs: List[Tuple[float, float]],
+         require_no_overlap: bool = False,
+         require_no_gap: bool = True,
+         normalize_channel_independently: bool = False,
+         treat_channel_as_feature: bool = True,
+         n_sqm_modules: int = 12,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         bidirectional: bool = True,
+         rnn_type: str = "LSTM",
+         mlp_dim: int = 512,
+         cond_dim: int = 0,
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Optional[Dict] = None,
+         complex_mask: bool = True,
+         overlapping_band: bool = False,
+         freq_weights: Optional[List[torch.Tensor]] = None,
+         n_freq: Optional[int] = None,
+         use_freq_weights: bool = True,
+         mult_add_mask: bool = False,
+     ) -> None:
+
+         super().__init__()
+         self.instantiate_bandsplit(
+             in_channel=in_channel,
+             band_specs=band_specs,
+             require_no_overlap=require_no_overlap,
+             require_no_gap=require_no_gap,
+             normalize_channel_independently=normalize_channel_independently,
+             treat_channel_as_feature=treat_channel_as_feature,
+             emb_dim=emb_dim,
+         )
+
+         self.tf_model = SeqBandModellingModule(
+             n_modules=n_sqm_modules,
+             emb_dim=emb_dim,
+             rnn_dim=rnn_dim,
+             bidirectional=bidirectional,
+             rnn_type=rnn_type,
+         )
+
+         self.mult_add_mask = mult_add_mask
+
+         self.instantiate_mask_estim(
+             in_channel=in_channel,
+             stems=stems,
+             band_specs=band_specs,
+             emb_dim=emb_dim,
+             mlp_dim=mlp_dim,
+             cond_dim=cond_dim,
+             hidden_activation=hidden_activation,
+             hidden_activation_kwargs=hidden_activation_kwargs,
+             complex_mask=complex_mask,
+             overlapping_band=overlapping_band,
+             freq_weights=freq_weights,
+             n_freq=n_freq,
+             use_freq_weights=use_freq_weights,
+             mult_add_mask=mult_add_mask,
+         )
+
+     @staticmethod
+     def _mult_add_mask(x, m):
+
+         assert m.ndim == 5
+
+         mm = m[..., 0]
+         am = m[..., 1]
+
+         return x * mm + am
+
+     def mask(self, x, m):
+         if self.mult_add_mask:
+
+             return self._mult_add_mask(x, m)
+         else:
+             return super().mask(x, m)
+
+
+ class MultiSourceMultiMaskBandSplitCoreTransformer(
+     MultiMaskBandSplitCoreBase,
+ ):
+     def __init__(
+         self,
+         in_channel: int,
+         stems: List[str],
+         band_specs: List[Tuple[float, float]],
+         require_no_overlap: bool = False,
+         require_no_gap: bool = True,
+         normalize_channel_independently: bool = False,
+         treat_channel_as_feature: bool = True,
+         n_sqm_modules: int = 12,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         bidirectional: bool = True,
+         tf_dropout: float = 0.0,
+         mlp_dim: int = 512,
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Optional[Dict] = None,
+         complex_mask: bool = True,
+         overlapping_band: bool = False,
+         freq_weights: Optional[List[torch.Tensor]] = None,
+         n_freq: Optional[int] = None,
+         use_freq_weights: bool = True,
+         rnn_type: str = "LSTM",
+         cond_dim: int = 0,
+         mult_add_mask: bool = False,
+     ) -> None:
+         super().__init__()
+         self.instantiate_bandsplit(
+             in_channel=in_channel,
+             band_specs=band_specs,
+             require_no_overlap=require_no_overlap,
+             require_no_gap=require_no_gap,
+             normalize_channel_independently=normalize_channel_independently,
+             treat_channel_as_feature=treat_channel_as_feature,
+             emb_dim=emb_dim,
+         )
+         self.tf_model = TransformerTimeFreqModule(
+             n_modules=n_sqm_modules,
+             emb_dim=emb_dim,
+             rnn_dim=rnn_dim,
+             bidirectional=bidirectional,
+             dropout=tf_dropout,
+         )
+
+         self.instantiate_mask_estim(
+             in_channel=in_channel,
+             stems=stems,
+             band_specs=band_specs,
+             emb_dim=emb_dim,
+             mlp_dim=mlp_dim,
+             cond_dim=cond_dim,
+             hidden_activation=hidden_activation,
+             hidden_activation_kwargs=hidden_activation_kwargs,
+             complex_mask=complex_mask,
+             overlapping_band=overlapping_band,
+             freq_weights=freq_weights,
+             n_freq=n_freq,
+             use_freq_weights=use_freq_weights,
+             mult_add_mask=mult_add_mask,
+         )
+
+
+ class MultiSourceMultiMaskBandSplitCoreConv(
+     MultiMaskBandSplitCoreBase,
+ ):
+     def __init__(
+         self,
+         in_channel: int,
+         stems: List[str],
+         band_specs: List[Tuple[float, float]],
+         require_no_overlap: bool = False,
+         require_no_gap: bool = True,
+         normalize_channel_independently: bool = False,
+         treat_channel_as_feature: bool = True,
+         n_sqm_modules: int = 12,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         bidirectional: bool = True,
+         tf_dropout: float = 0.0,
+         mlp_dim: int = 512,
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Optional[Dict] = None,
+         complex_mask: bool = True,
+         overlapping_band: bool = False,
+         freq_weights: Optional[List[torch.Tensor]] = None,
+         n_freq: Optional[int] = None,
+         use_freq_weights: bool = True,
+         rnn_type: str = "LSTM",
+         cond_dim: int = 0,
+         mult_add_mask: bool = False,
+     ) -> None:
+         super().__init__()
+         self.instantiate_bandsplit(
+             in_channel=in_channel,
+             band_specs=band_specs,
+             require_no_overlap=require_no_overlap,
+             require_no_gap=require_no_gap,
+             normalize_channel_independently=normalize_channel_independently,
+             treat_channel_as_feature=treat_channel_as_feature,
+             emb_dim=emb_dim,
+         )
+         self.tf_model = ConvolutionalTimeFreqModule(
+             n_modules=n_sqm_modules,
+             emb_dim=emb_dim,
+             rnn_dim=rnn_dim,
+             bidirectional=bidirectional,
+             dropout=tf_dropout,
+         )
+
+         self.instantiate_mask_estim(
+             in_channel=in_channel,
+             stems=stems,
+             band_specs=band_specs,
+             emb_dim=emb_dim,
+             mlp_dim=mlp_dim,
+             cond_dim=cond_dim,
+             hidden_activation=hidden_activation,
+             hidden_activation_kwargs=hidden_activation_kwargs,
+             complex_mask=complex_mask,
+             overlapping_band=overlapping_band,
+             freq_weights=freq_weights,
+             n_freq=n_freq,
+             use_freq_weights=use_freq_weights,
+             mult_add_mask=mult_add_mask,
+         )
+
+
+ class PatchingMaskBandsplitCoreBase(MultiMaskBandSplitCoreBase):
+     def __init__(self) -> None:
+         super().__init__()
+
+     def mask(self, x, m):
+
+         _, n_channel, kernel_freq, kernel_time, n_freq, n_time = m.shape
+         padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2)
+
+         xf = F.unfold(
+             x,
+             kernel_size=(kernel_freq, kernel_time),
+             padding=padding,
+             stride=(1, 1),
+         )
+
+         xf = xf.view(
+             -1,
+             n_channel,
+             kernel_freq,
+             kernel_time,
+             n_freq,
+             n_time,
+         )
+
+         sf = xf * m
+
+         sf = sf.view(
+             -1,
+             n_channel * kernel_freq * kernel_time,
+             n_freq * n_time,
+         )
+
+         s = F.fold(
+             sf,
+             output_size=(n_freq, n_time),
+             kernel_size=(kernel_freq, kernel_time),
+             padding=padding,
+             stride=(1, 1),
+         ).view(
+             -1,
+             n_channel,
+             n_freq,
+             n_time,
+         )
+
+         return s
+
+     def old_mask(self, x, m):
+
+         s = torch.zeros_like(x)
+
+         _, n_channel, n_freq, n_time = x.shape
+         kernel_freq, kernel_time, _, _, _, _ = m.shape
+
+         kernel_freq_half = (kernel_freq - 1) // 2
+         kernel_time_half = (kernel_time - 1) // 2
+
+         for ifreq in range(kernel_freq):
+             for itime in range(kernel_time):
+                 df, dt = kernel_freq_half - ifreq, kernel_time_half - itime
+                 x = x.roll(shifts=(df, dt), dims=(2, 3))
+
+                 fslice = slice(max(0, df), min(n_freq, n_freq + df))
+                 tslice = slice(max(0, dt), min(n_time, n_time + dt))
+
+                 s[:, :, fslice, tslice] += (
+                     x[:, :, fslice, tslice] * m[ifreq, itime, :, :, fslice, tslice]
+                 )
+
+         return s
+
+
+ class MultiSourceMultiPatchingMaskBandSplitCoreRNN(PatchingMaskBandsplitCoreBase):
+     def __init__(
+         self,
+         in_channel: int,
+         stems: List[str],
+         band_specs: List[Tuple[float, float]],
+         mask_kernel_freq: int,
+         mask_kernel_time: int,
+         conv_kernel_freq: int,
+         conv_kernel_time: int,
+         kernel_norm_mlp_version: int,
+         require_no_overlap: bool = False,
+         require_no_gap: bool = True,
+         normalize_channel_independently: bool = False,
+         treat_channel_as_feature: bool = True,
+         n_sqm_modules: int = 12,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         bidirectional: bool = True,
+         rnn_type: str = "LSTM",
+         mlp_dim: int = 512,
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Optional[Dict] = None,
+         complex_mask: bool = True,
+         overlapping_band: bool = False,
+         freq_weights: Optional[List[torch.Tensor]] = None,
+         n_freq: Optional[int] = None,
+     ) -> None:
+
+         super().__init__()
+         self.band_split = BandSplitModule(
+             in_channel=in_channel,
+             band_specs=band_specs,
+             require_no_overlap=require_no_overlap,
+             require_no_gap=require_no_gap,
+             normalize_channel_independently=normalize_channel_independently,
+             treat_channel_as_feature=treat_channel_as_feature,
+             emb_dim=emb_dim,
+         )
+
+         self.tf_model = SeqBandModellingModule(
+             n_modules=n_sqm_modules,
+             emb_dim=emb_dim,
+             rnn_dim=rnn_dim,
+             bidirectional=bidirectional,
+             rnn_type=rnn_type,
+         )
+
+         if hidden_activation_kwargs is None:
+             hidden_activation_kwargs = {}
+
+         if overlapping_band:
+             assert freq_weights is not None
+             assert n_freq is not None
+             self.mask_estim = nn.ModuleDict(
+                 {
+                     stem: PatchingMaskEstimationModule(
+                         band_specs=band_specs,
+                         freq_weights=freq_weights,
+                         n_freq=n_freq,
+                         emb_dim=emb_dim,
+                         mlp_dim=mlp_dim,
+                         in_channel=in_channel,
+                         hidden_activation=hidden_activation,
+                         hidden_activation_kwargs=hidden_activation_kwargs,
+                         complex_mask=complex_mask,
+                         mask_kernel_freq=mask_kernel_freq,
+                         mask_kernel_time=mask_kernel_time,
+                         conv_kernel_freq=conv_kernel_freq,
+                         conv_kernel_time=conv_kernel_time,
+                         kernel_norm_mlp_version=kernel_norm_mlp_version,
+                     )
+                     for stem in stems
+                 }
+             )
+         else:
+             raise NotImplementedError
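
The two masking styles in this file reduce to a handful of tensor operations: the mult-add variant computes x * mm + am, while the patching variant applies a small per-bin kernel via F.unfold/F.fold. A minimal, self-contained sketch follows; the shapes are illustrative only (not taken from any repo config), and real tensors stand in for the complex spectrograms the model actually masks:

    import torch
    import torch.nn.functional as F

    # Hypothetical shapes, chosen only for illustration.
    batch, n_channel, n_freq, n_time = 2, 1, 8, 10
    kernel_freq, kernel_time = 3, 3

    x = torch.randn(batch, n_channel, n_freq, n_time)

    # Mult-add masking: m[..., 0] multiplies, m[..., 1] adds.
    m = torch.randn(batch, n_channel, n_freq, n_time, 2)
    s_mult_add = x * m[..., 0] + m[..., 1]

    # Patching mask: one (kernel_freq x kernel_time) mask per T-F bin.
    # unfold gathers each bin's neighborhood; fold sums the products back.
    pm = torch.randn(batch, n_channel, kernel_freq, kernel_time, n_freq, n_time)
    padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2)
    xf = F.unfold(x, kernel_size=(kernel_freq, kernel_time), padding=padding)
    xf = xf.view(batch, n_channel, kernel_freq, kernel_time, n_freq, n_time)
    sf = (xf * pm).view(batch, n_channel * kernel_freq * kernel_time, n_freq * n_time)
    s = F.fold(sf, output_size=(n_freq, n_time),
               kernel_size=(kernel_freq, kernel_time), padding=padding)
    s = s.view(batch, n_channel, n_freq, n_time)
    print(s_mult_add.shape, s.shape)  # torch.Size([2, 1, 8, 10]) twice

The unfold/fold pair performs essentially the same shift-and-multiply-accumulate that old_mask spells out with an explicit double loop, but in two batched calls (the two differ only in boundary handling).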
 
mvsepless/models/bandit/core/model/bsrnn/maskestim.py CHANGED
@@ -1,351 +1,327 @@
- import warnings
- from typing import Dict, List, Optional, Tuple, Type
-
- import torch
- from torch import nn
- from torch.nn.modules import activation
-
- from .utils import (
-     band_widths_from_specs,
-     check_no_gap,
-     check_no_overlap,
-     check_nonzero_bandwidth,
- )
-
-
- class BaseNormMLP(nn.Module):
-     def __init__(
-         self,
-         emb_dim: int,
-         mlp_dim: int,
-         bandwidth: int,
-         in_channel: Optional[int],
-         hidden_activation: str = "Tanh",
-         hidden_activation_kwargs=None,
-         complex_mask: bool = True,
-     ):
-
-         super().__init__()
-         if hidden_activation_kwargs is None:
-             hidden_activation_kwargs = {}
-         self.hidden_activation_kwargs = hidden_activation_kwargs
-         self.norm = nn.LayerNorm(emb_dim)
-         self.hidden = torch.jit.script(
-             nn.Sequential(
-                 nn.Linear(in_features=emb_dim, out_features=mlp_dim),
-                 activation.__dict__[hidden_activation](**self.hidden_activation_kwargs),
-             )
-         )
-
-         self.bandwidth = bandwidth
-         self.in_channel = in_channel
-
-         self.complex_mask = complex_mask
-         self.reim = 2 if complex_mask else 1
-         self.glu_mult = 2
-
-
- class NormMLP(BaseNormMLP):
-     def __init__(
-         self,
-         emb_dim: int,
-         mlp_dim: int,
-         bandwidth: int,
-         in_channel: Optional[int],
-         hidden_activation: str = "Tanh",
-         hidden_activation_kwargs=None,
-         complex_mask: bool = True,
-     ) -> None:
-         super().__init__(
-             emb_dim=emb_dim,
-             mlp_dim=mlp_dim,
-             bandwidth=bandwidth,
-             in_channel=in_channel,
-             hidden_activation=hidden_activation,
-             hidden_activation_kwargs=hidden_activation_kwargs,
-             complex_mask=complex_mask,
-         )
-
-         self.output = torch.jit.script(
-             nn.Sequential(
-                 nn.Linear(
-                     in_features=mlp_dim,
-                     out_features=bandwidth * in_channel * self.reim * 2,
-                 ),
-                 nn.GLU(dim=-1),
-             )
-         )
-
-     def reshape_output(self, mb):
-         # print(mb.shape)
-         batch, n_time, _ = mb.shape
-         if self.complex_mask:
-             mb = mb.reshape(
-                 batch, n_time, self.in_channel, self.bandwidth, self.reim
-             ).contiguous()
-             # print(mb.shape)
-             mb = torch.view_as_complex(mb)  # (batch, n_time, in_channel, bandwidth)
-         else:
-             mb = mb.reshape(batch, n_time, self.in_channel, self.bandwidth)
-
-         mb = torch.permute(mb, (0, 2, 3, 1))  # (batch, in_channel, bandwidth, n_time)
-
-         return mb
-
-     def forward(self, qb):
-         # qb = (batch, n_time, emb_dim)
-
-         # if torch.any(torch.isnan(qb)):
-         #     raise ValueError("qb0")
-
-         qb = self.norm(qb)  # (batch, n_time, emb_dim)
-
-         # if torch.any(torch.isnan(qb)):
-         #     raise ValueError("qb1")
-
-         qb = self.hidden(qb)  # (batch, n_time, mlp_dim)
-         # if torch.any(torch.isnan(qb)):
-         #     raise ValueError("qb2")
-         mb = self.output(qb)  # (batch, n_time, bandwidth * in_channel * reim)
-         # if torch.any(torch.isnan(qb)):
-         #     raise ValueError("mb")
-         mb = self.reshape_output(mb)  # (batch, in_channel, bandwidth, n_time)
-
-         return mb
-
-
- class MultAddNormMLP(NormMLP):
-     def __init__(
-         self,
-         emb_dim: int,
-         mlp_dim: int,
-         bandwidth: int,
-         in_channel: "int | None",
-         hidden_activation: str = "Tanh",
-         hidden_activation_kwargs=None,
-         complex_mask: bool = True,
-     ) -> None:
-         super().__init__(
-             emb_dim,
-             mlp_dim,
-             bandwidth,
-             in_channel,
-             hidden_activation,
-             hidden_activation_kwargs,
-             complex_mask,
-         )
-
-         self.output2 = torch.jit.script(
-             nn.Sequential(
-                 nn.Linear(
-                     in_features=mlp_dim,
-                     out_features=bandwidth * in_channel * self.reim * 2,
-                 ),
-                 nn.GLU(dim=-1),
-             )
-         )
-
-     def forward(self, qb):
-
-         qb = self.norm(qb)  # (batch, n_time, emb_dim)
-         qb = self.hidden(qb)  # (batch, n_time, mlp_dim)
-         mmb = self.output(qb)  # (batch, n_time, bandwidth * in_channel * reim)
-         mmb = self.reshape_output(mmb)  # (batch, in_channel, bandwidth, n_time)
-         amb = self.output2(qb)  # (batch, n_time, bandwidth * in_channel * reim)
-         amb = self.reshape_output(amb)  # (batch, in_channel, bandwidth, n_time)
-
-         return mmb, amb
-
-
- class MaskEstimationModuleSuperBase(nn.Module):
-     pass
-
-
- class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
-     def __init__(
-         self,
-         band_specs: List[Tuple[float, float]],
-         emb_dim: int,
-         mlp_dim: int,
-         in_channel: Optional[int],
-         hidden_activation: str = "Tanh",
-         hidden_activation_kwargs: Dict = None,
-         complex_mask: bool = True,
-         norm_mlp_cls: Type[nn.Module] = NormMLP,
-         norm_mlp_kwargs: Dict = None,
-     ) -> None:
-         super().__init__()
-
-         self.band_widths = band_widths_from_specs(band_specs)
-         self.n_bands = len(band_specs)
-
-         if hidden_activation_kwargs is None:
-             hidden_activation_kwargs = {}
-
-         if norm_mlp_kwargs is None:
-             norm_mlp_kwargs = {}
-
-         self.norm_mlp = nn.ModuleList(
-             [
-                 (
-                     norm_mlp_cls(
-                         bandwidth=self.band_widths[b],
-                         emb_dim=emb_dim,
-                         mlp_dim=mlp_dim,
-                         in_channel=in_channel,
-                         hidden_activation=hidden_activation,
-                         hidden_activation_kwargs=hidden_activation_kwargs,
-                         complex_mask=complex_mask,
-                         **norm_mlp_kwargs,
-                     )
-                 )
-                 for b in range(self.n_bands)
-             ]
-         )
-
-     def compute_masks(self, q):
-         batch, n_bands, n_time, emb_dim = q.shape
-
-         masks = []
-
-         for b, nmlp in enumerate(self.norm_mlp):
-             # print(f"maskestim/{b:02d}")
-             qb = q[:, b, :, :]
-             mb = nmlp(qb)
-             masks.append(mb)
-
-         return masks
-
-
- class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
-     def __init__(
-         self,
-         in_channel: int,
-         band_specs: List[Tuple[float, float]],
-         freq_weights: List[torch.Tensor],
-         n_freq: int,
-         emb_dim: int,
-         mlp_dim: int,
-         cond_dim: int = 0,
-         hidden_activation: str = "Tanh",
-         hidden_activation_kwargs: Dict = None,
-         complex_mask: bool = True,
-         norm_mlp_cls: Type[nn.Module] = NormMLP,
-         norm_mlp_kwargs: Dict = None,
-         use_freq_weights: bool = True,
-     ) -> None:
-         check_nonzero_bandwidth(band_specs)
-         check_no_gap(band_specs)
-
-         # if cond_dim > 0:
-         #     raise NotImplementedError
-
-         super().__init__(
-             band_specs=band_specs,
-             emb_dim=emb_dim + cond_dim,
-             mlp_dim=mlp_dim,
-             in_channel=in_channel,
-             hidden_activation=hidden_activation,
-             hidden_activation_kwargs=hidden_activation_kwargs,
-             complex_mask=complex_mask,
-             norm_mlp_cls=norm_mlp_cls,
-             norm_mlp_kwargs=norm_mlp_kwargs,
-         )
-
-         self.n_freq = n_freq
-         self.band_specs = band_specs
-         self.in_channel = in_channel
-
-         if freq_weights is not None:
-             for i, fw in enumerate(freq_weights):
-                 self.register_buffer(f"freq_weights/{i}", fw)
-
-             self.use_freq_weights = use_freq_weights
-         else:
-             self.use_freq_weights = False
-
-         self.cond_dim = cond_dim
-
-     def forward(self, q, cond=None):
-         # q = (batch, n_bands, n_time, emb_dim)
-
-         batch, n_bands, n_time, emb_dim = q.shape
-
-         if cond is not None:
-             print(cond)
-             if cond.ndim == 2:
-                 cond = cond[:, None, None, :].expand(-1, n_bands, n_time, -1)
-             elif cond.ndim == 3:
-                 assert cond.shape[1] == n_time
-             else:
-                 raise ValueError(f"Invalid cond shape: {cond.shape}")
-
-             q = torch.cat([q, cond], dim=-1)
-         elif self.cond_dim > 0:
-             cond = torch.ones(
-                 (batch, n_bands, n_time, self.cond_dim),
-                 device=q.device,
-                 dtype=q.dtype,
-             )
-             q = torch.cat([q, cond], dim=-1)
-         else:
-             pass
-
-         mask_list = self.compute_masks(
-             q
-         )  # [n_bands * (batch, in_channel, bandwidth, n_time)]
-
-         masks = torch.zeros(
-             (batch, self.in_channel, self.n_freq, n_time),
-             device=q.device,
-             dtype=mask_list[0].dtype,
-         )
-
-         for im, mask in enumerate(mask_list):
-             fstart, fend = self.band_specs[im]
-             if self.use_freq_weights:
-                 fw = self.get_buffer(f"freq_weights/{im}")[:, None]
-                 mask = mask * fw
-             masks[:, :, fstart:fend, :] += mask
-
-         return masks
-
-
- class MaskEstimationModule(OverlappingMaskEstimationModule):
-     def __init__(
-         self,
-         band_specs: List[Tuple[float, float]],
-         emb_dim: int,
-         mlp_dim: int,
-         in_channel: Optional[int],
-         hidden_activation: str = "Tanh",
-         hidden_activation_kwargs: Dict = None,
-         complex_mask: bool = True,
-         **kwargs,
-     ) -> None:
-         check_nonzero_bandwidth(band_specs)
-         check_no_gap(band_specs)
-         check_no_overlap(band_specs)
-         super().__init__(
-             in_channel=in_channel,
-             band_specs=band_specs,
-             freq_weights=None,
-             n_freq=None,
-             emb_dim=emb_dim,
-             mlp_dim=mlp_dim,
-             hidden_activation=hidden_activation,
-             hidden_activation_kwargs=hidden_activation_kwargs,
-             complex_mask=complex_mask,
-         )
-
-     def forward(self, q, cond=None):
-         # q = (batch, n_bands, n_time, emb_dim)
-
-         masks = self.compute_masks(
-             q
-         )  # [n_bands * (batch, in_channel, bandwidth, n_time)]
-
-         # TODO: currently this requires band specs to have no gap and no overlap
-         masks = torch.concat(masks, dim=2)  # (batch, in_channel, n_freq, n_time)
-
-         return masks
 
+ import warnings
+ from typing import Dict, List, Optional, Tuple, Type
+
+ import torch
+ from torch import nn
+ from torch.nn.modules import activation
+
+ from .utils import (
+     band_widths_from_specs,
+     check_no_gap,
+     check_no_overlap,
+     check_nonzero_bandwidth,
+ )
+
+
+ class BaseNormMLP(nn.Module):
+     def __init__(
+         self,
+         emb_dim: int,
+         mlp_dim: int,
+         bandwidth: int,
+         in_channel: Optional[int],
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs=None,
+         complex_mask: bool = True,
+     ):
+
+         super().__init__()
+         if hidden_activation_kwargs is None:
+             hidden_activation_kwargs = {}
+         self.hidden_activation_kwargs = hidden_activation_kwargs
+         self.norm = nn.LayerNorm(emb_dim)
+         self.hidden = torch.jit.script(
+             nn.Sequential(
+                 nn.Linear(in_features=emb_dim, out_features=mlp_dim),
+                 activation.__dict__[hidden_activation](**self.hidden_activation_kwargs),
+             )
+         )
+
+         self.bandwidth = bandwidth
+         self.in_channel = in_channel
+
+         self.complex_mask = complex_mask
+         self.reim = 2 if complex_mask else 1
+         self.glu_mult = 2
+
+
+ class NormMLP(BaseNormMLP):
+     def __init__(
+         self,
+         emb_dim: int,
+         mlp_dim: int,
+         bandwidth: int,
+         in_channel: Optional[int],
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs=None,
+         complex_mask: bool = True,
+     ) -> None:
+         super().__init__(
+             emb_dim=emb_dim,
+             mlp_dim=mlp_dim,
+             bandwidth=bandwidth,
+             in_channel=in_channel,
+             hidden_activation=hidden_activation,
+             hidden_activation_kwargs=hidden_activation_kwargs,
+             complex_mask=complex_mask,
+         )
+
+         self.output = torch.jit.script(
+             nn.Sequential(
+                 nn.Linear(
+                     in_features=mlp_dim,
+                     out_features=bandwidth * in_channel * self.reim * 2,
+                 ),
+                 nn.GLU(dim=-1),
+             )
+         )
+
+     def reshape_output(self, mb):
+         batch, n_time, _ = mb.shape
+         if self.complex_mask:
+             mb = mb.reshape(
+                 batch, n_time, self.in_channel, self.bandwidth, self.reim
+             ).contiguous()
+             mb = torch.view_as_complex(mb)
+         else:
+             mb = mb.reshape(batch, n_time, self.in_channel, self.bandwidth)
+
+         mb = torch.permute(mb, (0, 2, 3, 1))
+
+         return mb
+
+     def forward(self, qb):
+
+         qb = self.norm(qb)
+
+         qb = self.hidden(qb)
+         mb = self.output(qb)
+         mb = self.reshape_output(mb)
+
+         return mb
+
+
+ class MultAddNormMLP(NormMLP):
+     def __init__(
+         self,
+         emb_dim: int,
+         mlp_dim: int,
+         bandwidth: int,
+         in_channel: "int | None",
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs=None,
+         complex_mask: bool = True,
+     ) -> None:
+         super().__init__(
+             emb_dim,
+             mlp_dim,
+             bandwidth,
+             in_channel,
+             hidden_activation,
+             hidden_activation_kwargs,
+             complex_mask,
+         )
+
+         self.output2 = torch.jit.script(
+             nn.Sequential(
+                 nn.Linear(
+                     in_features=mlp_dim,
+                     out_features=bandwidth * in_channel * self.reim * 2,
+                 ),
+                 nn.GLU(dim=-1),
+             )
+         )
+
+     def forward(self, qb):
+
+         qb = self.norm(qb)
+         qb = self.hidden(qb)
+         mmb = self.output(qb)
+         mmb = self.reshape_output(mmb)
+         amb = self.output2(qb)
+         amb = self.reshape_output(amb)
+
+         return mmb, amb
+
+
+ class MaskEstimationModuleSuperBase(nn.Module):
+     pass
+
+
+ class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
+     def __init__(
+         self,
+         band_specs: List[Tuple[float, float]],
+         emb_dim: int,
+         mlp_dim: int,
+         in_channel: Optional[int],
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Dict = None,
+         complex_mask: bool = True,
+         norm_mlp_cls: Type[nn.Module] = NormMLP,
+         norm_mlp_kwargs: Dict = None,
+     ) -> None:
+         super().__init__()
+
+         self.band_widths = band_widths_from_specs(band_specs)
+         self.n_bands = len(band_specs)
+
+         if hidden_activation_kwargs is None:
+             hidden_activation_kwargs = {}
+
+         if norm_mlp_kwargs is None:
+             norm_mlp_kwargs = {}
+
+         self.norm_mlp = nn.ModuleList(
+             [
+                 (
+                     norm_mlp_cls(
+                         bandwidth=self.band_widths[b],
+                         emb_dim=emb_dim,
+                         mlp_dim=mlp_dim,
+                         in_channel=in_channel,
+                         hidden_activation=hidden_activation,
+                         hidden_activation_kwargs=hidden_activation_kwargs,
+                         complex_mask=complex_mask,
+                         **norm_mlp_kwargs,
+                     )
+                 )
+                 for b in range(self.n_bands)
+             ]
+         )
+
+     def compute_masks(self, q):
+         batch, n_bands, n_time, emb_dim = q.shape
+
+         masks = []
+
+         for b, nmlp in enumerate(self.norm_mlp):
+             qb = q[:, b, :, :]
+             mb = nmlp(qb)
+             masks.append(mb)
+
+         return masks
+
+
+ class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
+     def __init__(
+         self,
+         in_channel: int,
+         band_specs: List[Tuple[float, float]],
+         freq_weights: List[torch.Tensor],
+         n_freq: int,
+         emb_dim: int,
+         mlp_dim: int,
+         cond_dim: int = 0,
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Dict = None,
+         complex_mask: bool = True,
+         norm_mlp_cls: Type[nn.Module] = NormMLP,
+         norm_mlp_kwargs: Dict = None,
+         use_freq_weights: bool = True,
+     ) -> None:
+         check_nonzero_bandwidth(band_specs)
+         check_no_gap(band_specs)
+
+         super().__init__(
+             band_specs=band_specs,
+             emb_dim=emb_dim + cond_dim,
+             mlp_dim=mlp_dim,
+             in_channel=in_channel,
+             hidden_activation=hidden_activation,
+             hidden_activation_kwargs=hidden_activation_kwargs,
+             complex_mask=complex_mask,
+             norm_mlp_cls=norm_mlp_cls,
+             norm_mlp_kwargs=norm_mlp_kwargs,
+         )
+
+         self.n_freq = n_freq
+         self.band_specs = band_specs
+         self.in_channel = in_channel
+
+         if freq_weights is not None:
+             for i, fw in enumerate(freq_weights):
+                 self.register_buffer(f"freq_weights/{i}", fw)
+
+             self.use_freq_weights = use_freq_weights
+         else:
+             self.use_freq_weights = False
+
+         self.cond_dim = cond_dim
+
+     def forward(self, q, cond=None):
+
+         batch, n_bands, n_time, emb_dim = q.shape
+
+         if cond is not None:
+             print(cond)
+             if cond.ndim == 2:
+                 cond = cond[:, None, None, :].expand(-1, n_bands, n_time, -1)
+             elif cond.ndim == 3:
+                 assert cond.shape[1] == n_time
+             else:
+                 raise ValueError(f"Invalid cond shape: {cond.shape}")
+
+             q = torch.cat([q, cond], dim=-1)
+         elif self.cond_dim > 0:
+             cond = torch.ones(
+                 (batch, n_bands, n_time, self.cond_dim),
+                 device=q.device,
+                 dtype=q.dtype,
+             )
+             q = torch.cat([q, cond], dim=-1)
+         else:
+             pass
+
+         mask_list = self.compute_masks(q)
+
+         masks = torch.zeros(
+             (batch, self.in_channel, self.n_freq, n_time),
+             device=q.device,
+             dtype=mask_list[0].dtype,
+         )
+
+         for im, mask in enumerate(mask_list):
+             fstart, fend = self.band_specs[im]
+             if self.use_freq_weights:
+                 fw = self.get_buffer(f"freq_weights/{im}")[:, None]
+                 mask = mask * fw
+             masks[:, :, fstart:fend, :] += mask
+
+         return masks
+
+
+ class MaskEstimationModule(OverlappingMaskEstimationModule):
+     def __init__(
+         self,
+         band_specs: List[Tuple[float, float]],
+         emb_dim: int,
+         mlp_dim: int,
+         in_channel: Optional[int],
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Dict = None,
+         complex_mask: bool = True,
+         **kwargs,
+     ) -> None:
+         check_nonzero_bandwidth(band_specs)
+         check_no_gap(band_specs)
+         check_no_overlap(band_specs)
+         super().__init__(
+             in_channel=in_channel,
+             band_specs=band_specs,
+             freq_weights=None,
+             n_freq=None,
+             emb_dim=emb_dim,
+             mlp_dim=mlp_dim,
+             hidden_activation=hidden_activation,
+             hidden_activation_kwargs=hidden_activation_kwargs,
+             complex_mask=complex_mask,
+         )
+
+     def forward(self, q, cond=None):
+
+         masks = self.compute_masks(q)
+
+         masks = torch.concat(masks, dim=2)
+
+         return masks
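
In NormMLP.reshape_output above, the complex mask is produced by pairing the last axis into real/imaginary parts with torch.view_as_complex. A short sketch with hypothetical sizes (not taken from any config in this repo):

    import torch

    # Illustrative sizes only.
    batch, n_time, in_channel, bandwidth = 2, 5, 2, 4
    reim = 2  # real + imaginary parts for a complex mask

    mb = torch.randn(batch, n_time, in_channel * bandwidth * reim)
    mb = mb.reshape(batch, n_time, in_channel, bandwidth, reim).contiguous()
    mb = torch.view_as_complex(mb)        # (batch, n_time, in_channel, bandwidth)
    mb = torch.permute(mb, (0, 2, 3, 1))  # (batch, in_channel, bandwidth, n_time)
    print(mb.shape, mb.dtype)             # torch.Size([2, 2, 4, 5]) torch.complex64

view_as_complex requires a contiguous float tensor whose last dimension has size 2, which is why the reshape is followed by .contiguous() in the module.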
 
mvsepless/models/bandit/core/model/bsrnn/tfmodel.py CHANGED
@@ -1,320 +1,287 @@
- import warnings
-
- import torch
- from torch import nn
- from torch.nn import functional as F
- from torch.nn.modules import rnn
-
- import torch.backends.cuda
-
-
- class TimeFrequencyModellingModule(nn.Module):
-     def __init__(self) -> None:
-         super().__init__()
-
-
- class ResidualRNN(nn.Module):
-     def __init__(
-         self,
-         emb_dim: int,
-         rnn_dim: int,
-         bidirectional: bool = True,
-         rnn_type: str = "LSTM",
-         use_batch_trick: bool = True,
-         use_layer_norm: bool = True,
-     ) -> None:
-         # n_group is the size of the 2nd dim
-         super().__init__()
-
-         self.use_layer_norm = use_layer_norm
-         if use_layer_norm:
-             self.norm = nn.LayerNorm(emb_dim)
-         else:
-             self.norm = nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim)
-
-         self.rnn = rnn.__dict__[rnn_type](
-             input_size=emb_dim,
-             hidden_size=rnn_dim,
-             num_layers=1,
-             batch_first=True,
-             bidirectional=bidirectional,
-         )
-
-         self.fc = nn.Linear(
-             in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim
-         )
-
-         self.use_batch_trick = use_batch_trick
-         if not self.use_batch_trick:
-             warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")
-
-     def forward(self, z):
-         # z = (batch, n_uncrossed, n_across, emb_dim)
-
-         z0 = torch.clone(z)
-
-         # print(z.device)
-
-         if self.use_layer_norm:
-             z = self.norm(z)  # (batch, n_uncrossed, n_across, emb_dim)
-         else:
-             z = torch.permute(
-                 z, (0, 3, 1, 2)
-             )  # (batch, emb_dim, n_uncrossed, n_across)
-
-             z = self.norm(z)  # (batch, emb_dim, n_uncrossed, n_across)
-
-             z = torch.permute(
-                 z, (0, 2, 3, 1)
-             )  # (batch, n_uncrossed, n_across, emb_dim)
-
-         batch, n_uncrossed, n_across, emb_dim = z.shape
-
-         if self.use_batch_trick:
-             z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
-
-             z = self.rnn(z.contiguous())[
-                 0
-             ]  # (batch * n_uncrossed, n_across, dir_rnn_dim)
-
-             z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
-             # (batch, n_uncrossed, n_across, dir_rnn_dim)
-         else:
-             # Note: this is EXTREMELY SLOW
-             zlist = []
-             for i in range(n_uncrossed):
-                 zi = self.rnn(z[:, i, :, :])[0]  # (batch, n_across, emb_dim)
-                 zlist.append(zi)
-
-             z = torch.stack(zlist, dim=1)  # (batch, n_uncrossed, n_across, dir_rnn_dim)
-
-         z = self.fc(z)  # (batch, n_uncrossed, n_across, emb_dim)
-
-         z = z + z0
-
-         return z
-
-
- class SeqBandModellingModule(TimeFrequencyModellingModule):
-     def __init__(
-         self,
-         n_modules: int = 12,
-         emb_dim: int = 128,
-         rnn_dim: int = 256,
-         bidirectional: bool = True,
-         rnn_type: str = "LSTM",
-         parallel_mode=False,
-     ) -> None:
-         super().__init__()
-         self.seqband = nn.ModuleList([])
-
-         if parallel_mode:
-             for _ in range(n_modules):
-                 self.seqband.append(
-                     nn.ModuleList(
-                         [
-                             ResidualRNN(
-                                 emb_dim=emb_dim,
-                                 rnn_dim=rnn_dim,
-                                 bidirectional=bidirectional,
-                                 rnn_type=rnn_type,
-                             ),
-                             ResidualRNN(
-                                 emb_dim=emb_dim,
-                                 rnn_dim=rnn_dim,
-                                 bidirectional=bidirectional,
-                                 rnn_type=rnn_type,
-                             ),
-                         ]
-                     )
-                 )
-         else:
-
-             for _ in range(2 * n_modules):
-                 self.seqband.append(
-                     ResidualRNN(
-                         emb_dim=emb_dim,
-                         rnn_dim=rnn_dim,
-                         bidirectional=bidirectional,
-                         rnn_type=rnn_type,
-                     )
-                 )
-
-         self.parallel_mode = parallel_mode
-
-     def forward(self, z):
-         # z = (batch, n_bands, n_time, emb_dim)
-
-         if self.parallel_mode:
-             for sbm_pair in self.seqband:
-                 # z: (batch, n_bands, n_time, emb_dim)
-                 sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
-                 zt = sbm_t(z)  # (batch, n_bands, n_time, emb_dim)
-                 zf = sbm_f(z.transpose(1, 2))  # (batch, n_time, n_bands, emb_dim)
-                 z = zt + zf.transpose(1, 2)
-         else:
-             for sbm in self.seqband:
-                 z = sbm(z)
-                 z = z.transpose(1, 2)
-
-                 # (batch, n_bands, n_time, emb_dim)
-                 # --> (batch, n_time, n_bands, emb_dim)
-                 # OR
-                 # (batch, n_time, n_bands, emb_dim)
-                 # --> (batch, n_bands, n_time, emb_dim)
-
-         q = z
-         return q  # (batch, n_bands, n_time, emb_dim)
-
-
- class ResidualTransformer(nn.Module):
-     def __init__(
-         self,
-         emb_dim: int = 128,
-         rnn_dim: int = 256,
-         bidirectional: bool = True,
-         dropout: float = 0.0,
-     ) -> None:
-         # n_group is the size of the 2nd dim
-         super().__init__()
-
-         self.tf = nn.TransformerEncoderLayer(
-             d_model=emb_dim, nhead=4, dim_feedforward=rnn_dim, batch_first=True
-         )
-
-         self.is_causal = not bidirectional
-         self.dropout = dropout
-
-     def forward(self, z):
-         batch, n_uncrossed, n_across, emb_dim = z.shape
-         z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
-         z = self.tf(
-             z, is_causal=self.is_causal
-         )  # (batch, n_uncrossed, n_across, emb_dim)
-         z = torch.reshape(z, (batch, n_uncrossed, n_across, emb_dim))
-
-         return z
-
-
- class TransformerTimeFreqModule(TimeFrequencyModellingModule):
-     def __init__(
-         self,
-         n_modules: int = 12,
-         emb_dim: int = 128,
-         rnn_dim: int = 256,
-         bidirectional: bool = True,
-         dropout: float = 0.0,
-     ) -> None:
-         super().__init__()
-         self.norm = nn.LayerNorm(emb_dim)
-         self.seqband = nn.ModuleList([])
-
-         for _ in range(2 * n_modules):
-             self.seqband.append(
-                 ResidualTransformer(
-                     emb_dim=emb_dim,
-                     rnn_dim=rnn_dim,
-                     bidirectional=bidirectional,
-                     dropout=dropout,
-                 )
-             )
-
-     def forward(self, z):
-         # z = (batch, n_bands, n_time, emb_dim)
-         z = self.norm(z)  # (batch, n_bands, n_time, emb_dim)
-
-         for sbm in self.seqband:
-             z = sbm(z)
-             z = z.transpose(1, 2)
-
-             # (batch, n_bands, n_time, emb_dim)
-             # --> (batch, n_time, n_bands, emb_dim)
-             # OR
-             # (batch, n_time, n_bands, emb_dim)
-             # --> (batch, n_bands, n_time, emb_dim)
-
-         q = z
-         return q  # (batch, n_bands, n_time, emb_dim)
-
-
- class ResidualConvolution(nn.Module):
-     def __init__(
-         self,
-         emb_dim: int = 128,
-         rnn_dim: int = 256,
-         bidirectional: bool = True,
-         dropout: float = 0.0,
-     ) -> None:
-         # n_group is the size of the 2nd dim
-         super().__init__()
-         self.norm = nn.InstanceNorm2d(emb_dim, affine=True)
-
-         self.conv = nn.Sequential(
-             nn.Conv2d(
-                 in_channels=emb_dim,
-                 out_channels=rnn_dim,
-                 kernel_size=(3, 3),
-                 padding="same",
-                 stride=(1, 1),
-             ),
-             nn.Tanhshrink(),
-         )
-
-         self.is_causal = not bidirectional
-         self.dropout = dropout
-
-         self.fc = nn.Conv2d(
-             in_channels=rnn_dim,
-             out_channels=emb_dim,
-             kernel_size=(1, 1),
-             padding="same",
-             stride=(1, 1),
-         )
-
-     def forward(self, z):
-         # z = (batch, n_uncrossed, n_across, emb_dim)
-
-         z0 = torch.clone(z)
-
-         z = self.norm(z)  # (batch, n_uncrossed, n_across, emb_dim)
-         z = self.conv(z)  # (batch, n_uncrossed, n_across, emb_dim)
-         z = self.fc(z)  # (batch, n_uncrossed, n_across, emb_dim)
-         z = z + z0
-
-         return z
-
-
- class ConvolutionalTimeFreqModule(TimeFrequencyModellingModule):
-     def __init__(
-         self,
-         n_modules: int = 12,
-         emb_dim: int = 128,
-         rnn_dim: int = 256,
-         bidirectional: bool = True,
-         dropout: float = 0.0,
-     ) -> None:
-         super().__init__()
-         self.seqband = torch.jit.script(
-             nn.Sequential(
-                 *[
-                     ResidualConvolution(
-                         emb_dim=emb_dim,
-                         rnn_dim=rnn_dim,
-                         bidirectional=bidirectional,
-                         dropout=dropout,
-                     )
-                     for _ in range(2 * n_modules)
-                 ]
-             )
-         )
-
-     def forward(self, z):
-         # z = (batch, n_bands, n_time, emb_dim)
-
-         z = torch.permute(z, (0, 3, 1, 2))  # (batch, emb_dim, n_bands, n_time)
-
-         z = self.seqband(z)  # (batch, emb_dim, n_bands, n_time)
-
-         z = torch.permute(z, (0, 2, 3, 1))  # (batch, n_bands, n_time, emb_dim)
-
-         return z
 
+import warnings
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.modules import rnn
+
+import torch.backends.cuda
+
+
+class TimeFrequencyModellingModule(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+
+class ResidualRNN(nn.Module):
+    def __init__(
+        self,
+        emb_dim: int,
+        rnn_dim: int,
+        bidirectional: bool = True,
+        rnn_type: str = "LSTM",
+        use_batch_trick: bool = True,
+        use_layer_norm: bool = True,
+    ) -> None:
+        super().__init__()
+
+        self.use_layer_norm = use_layer_norm
+        if use_layer_norm:
+            self.norm = nn.LayerNorm(emb_dim)
+        else:
+            self.norm = nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim)
+
+        self.rnn = rnn.__dict__[rnn_type](
+            input_size=emb_dim,
+            hidden_size=rnn_dim,
+            num_layers=1,
+            batch_first=True,
+            bidirectional=bidirectional,
+        )
+
+        self.fc = nn.Linear(
+            in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim
+        )
+
+        self.use_batch_trick = use_batch_trick
+        if not self.use_batch_trick:
+            warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")
+
+    def forward(self, z):
+
+        z0 = torch.clone(z)
+
+        if self.use_layer_norm:
+            z = self.norm(z)
+        else:
+            z = torch.permute(z, (0, 3, 1, 2))
+
+            z = self.norm(z)
+
+            z = torch.permute(z, (0, 2, 3, 1))
+
+        batch, n_uncrossed, n_across, emb_dim = z.shape
+
+        if self.use_batch_trick:
+            z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
+
+            z = self.rnn(z.contiguous())[0]
+
+            z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
+        else:
+            zlist = []
+            for i in range(n_uncrossed):
+                zi = self.rnn(z[:, i, :, :])[0]
+                zlist.append(zi)
+
+            z = torch.stack(zlist, dim=1)
+
+        z = self.fc(z)
+
+        z = z + z0
+
+        return z
+
+
+class SeqBandModellingModule(TimeFrequencyModellingModule):
+    def __init__(
+        self,
+        n_modules: int = 12,
+        emb_dim: int = 128,
+        rnn_dim: int = 256,
+        bidirectional: bool = True,
+        rnn_type: str = "LSTM",
+        parallel_mode=False,
+    ) -> None:
+        super().__init__()
+        self.seqband = nn.ModuleList([])
+
+        if parallel_mode:
+            for _ in range(n_modules):
+                self.seqband.append(
+                    nn.ModuleList(
+                        [
+                            ResidualRNN(
+                                emb_dim=emb_dim,
+                                rnn_dim=rnn_dim,
+                                bidirectional=bidirectional,
+                                rnn_type=rnn_type,
+                            ),
+                            ResidualRNN(
+                                emb_dim=emb_dim,
+                                rnn_dim=rnn_dim,
+                                bidirectional=bidirectional,
+                                rnn_type=rnn_type,
+                            ),
+                        ]
+                    )
+                )
+        else:
+
+            for _ in range(2 * n_modules):
+                self.seqband.append(
+                    ResidualRNN(
+                        emb_dim=emb_dim,
+                        rnn_dim=rnn_dim,
+                        bidirectional=bidirectional,
+                        rnn_type=rnn_type,
+                    )
+                )
+
+        self.parallel_mode = parallel_mode
+
+    def forward(self, z):
+
+        if self.parallel_mode:
+            for sbm_pair in self.seqband:
+                sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
+                zt = sbm_t(z)
+                zf = sbm_f(z.transpose(1, 2))
+                z = zt + zf.transpose(1, 2)
+        else:
+            for sbm in self.seqband:
+                z = sbm(z)
+                z = z.transpose(1, 2)
+
+        q = z
+        return q
+
+
+class ResidualTransformer(nn.Module):
+    def __init__(
+        self,
+        emb_dim: int = 128,
+        rnn_dim: int = 256,
+        bidirectional: bool = True,
+        dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+
+        self.tf = nn.TransformerEncoderLayer(
+            d_model=emb_dim, nhead=4, dim_feedforward=rnn_dim, batch_first=True
+        )
+
+        self.is_causal = not bidirectional
+        self.dropout = dropout
+
+    def forward(self, z):
+        batch, n_uncrossed, n_across, emb_dim = z.shape
+        z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
+        z = self.tf(z, is_causal=self.is_causal)
+        z = torch.reshape(z, (batch, n_uncrossed, n_across, emb_dim))
+
+        return z
+
+
+class TransformerTimeFreqModule(TimeFrequencyModellingModule):
+    def __init__(
+        self,
+        n_modules: int = 12,
+        emb_dim: int = 128,
+        rnn_dim: int = 256,
+        bidirectional: bool = True,
+        dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.norm = nn.LayerNorm(emb_dim)
+        self.seqband = nn.ModuleList([])
+
+        for _ in range(2 * n_modules):
+            self.seqband.append(
+                ResidualTransformer(
+                    emb_dim=emb_dim,
+                    rnn_dim=rnn_dim,
+                    bidirectional=bidirectional,
+                    dropout=dropout,
+                )
+            )
+
+    def forward(self, z):
+        z = self.norm(z)
+
+        for sbm in self.seqband:
+            z = sbm(z)
+            z = z.transpose(1, 2)
+
+        q = z
+        return q
+
+
+class ResidualConvolution(nn.Module):
+    def __init__(
+        self,
+        emb_dim: int = 128,
+        rnn_dim: int = 256,
+        bidirectional: bool = True,
+        dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.norm = nn.InstanceNorm2d(emb_dim, affine=True)
+
+        self.conv = nn.Sequential(
+            nn.Conv2d(
+                in_channels=emb_dim,
+                out_channels=rnn_dim,
+                kernel_size=(3, 3),
+                padding="same",
+                stride=(1, 1),
+            ),
+            nn.Tanhshrink(),
+        )
+
+        self.is_causal = not bidirectional
+        self.dropout = dropout
+
+        self.fc = nn.Conv2d(
+            in_channels=rnn_dim,
+            out_channels=emb_dim,
+            kernel_size=(1, 1),
+            padding="same",
+            stride=(1, 1),
+        )
+
+    def forward(self, z):
+
+        z0 = torch.clone(z)
+
+        z = self.norm(z)
+        z = self.conv(z)
+        z = self.fc(z)
+        z = z + z0
+
+        return z
+
+
+class ConvolutionalTimeFreqModule(TimeFrequencyModellingModule):
+    def __init__(
+        self,
+        n_modules: int = 12,
+        emb_dim: int = 128,
+        rnn_dim: int = 256,
+        bidirectional: bool = True,
+        dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.seqband = torch.jit.script(
+            nn.Sequential(
+                *[
+                    ResidualConvolution(
+                        emb_dim=emb_dim,
+                        rnn_dim=rnn_dim,
+                        bidirectional=bidirectional,
+                        dropout=dropout,
+                    )
+                    for _ in range(2 * n_modules)
+                ]
+            )
+        )
+
+    def forward(self, z):
+
+        z = torch.permute(z, (0, 3, 1, 2))
+
+        z = self.seqband(z)
+
+        z = torch.permute(z, (0, 2, 3, 1))
+
+        return z
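
Note: the refactor above only strips comments, so behavior should be unchanged. As a quick, illustrative smoke test (not part of this commit; the module path is assumed from the repo layout), the shape contract of the refactored blocks can be checked like this:

# Minimal smoke test (illustrative only; import path assumed from the file layout).
import torch

from mvsepless.models.bandit.core.model.bsrnn.tfmodel import (
    ResidualRNN,
    SeqBandModellingModule,
)

# Inputs follow the (batch, n_bands, n_time, emb_dim) layout used by forward().
z = torch.randn(2, 64, 100, 128)

# A single residual RNN block is shape-preserving: LayerNorm -> LSTM -> Linear -> skip.
block = ResidualRNN(emb_dim=128, rnn_dim=256)
assert block(z).shape == z.shape

# Sequential mode stacks 2 * n_modules blocks, transposing the band/time axes
# between blocks; an even number of transposes restores the original layout.
tf_model = SeqBandModellingModule(n_modules=2, emb_dim=128, rnn_dim=256)
assert tf_model(z).shape == z.shape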
 
mvsepless/models/bandit/core/model/bsrnn/utils.py CHANGED
@@ -1,525 +1,518 @@
-import os
-from abc import abstractmethod
-from typing import Any, Callable
-
-import numpy as np
-import torch
-from librosa import hz_to_midi, midi_to_hz
-from torch import Tensor
-from torchaudio import functional as taF
-from spafe.fbanks import bark_fbanks
-from spafe.utils.converters import erb2hz, hz2bark, hz2erb
-from torchaudio.functional.functional import _create_triangular_filterbank
-
-
-def band_widths_from_specs(band_specs):
-    return [e - i for i, e in band_specs]
-
-
-def check_nonzero_bandwidth(band_specs):
-    # pprint(band_specs)
-    for fstart, fend in band_specs:
-        if fend - fstart <= 0:
-            raise ValueError("Bands cannot be zero-width")
-
-
-def check_no_overlap(band_specs):
-    fend_prev = -1
-    for fstart_curr, fend_curr in band_specs:
-        if fstart_curr <= fend_prev:
-            raise ValueError("Bands cannot overlap")
-
-
-def check_no_gap(band_specs):
-    fstart, _ = band_specs[0]
-    assert fstart == 0
-
-    fend_prev = -1
-    for fstart_curr, fend_curr in band_specs:
-        if fstart_curr - fend_prev > 1:
-            raise ValueError("Bands cannot leave gap")
-        fend_prev = fend_curr
-
-
-class BandsplitSpecification:
-    def __init__(self, nfft: int, fs: int) -> None:
-        self.fs = fs
-        self.nfft = nfft
-        self.nyquist = fs / 2
-        self.max_index = nfft // 2 + 1
-
-        self.split500 = self.hertz_to_index(500)
-        self.split1k = self.hertz_to_index(1000)
-        self.split2k = self.hertz_to_index(2000)
-        self.split4k = self.hertz_to_index(4000)
-        self.split8k = self.hertz_to_index(8000)
-        self.split16k = self.hertz_to_index(16000)
-        self.split20k = self.hertz_to_index(20000)
-
-        self.above20k = [(self.split20k, self.max_index)]
-        self.above16k = [(self.split16k, self.split20k)] + self.above20k
-
-    def index_to_hertz(self, index: int):
-        return index * self.fs / self.nfft
-
-    def hertz_to_index(self, hz: float, round: bool = True):
-        index = hz * self.nfft / self.fs
-
-        if round:
-            index = int(np.round(index))
-
-        return index
-
-    def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
-        band_specs = []
-        lower = start_index
-
-        while lower < end_index:
-            upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
-            upper = min(upper, end_index)
-
-            band_specs.append((lower, upper))
-            lower = upper
-
-        return band_specs
-
-    @abstractmethod
-    def get_band_specs(self):
-        raise NotImplementedError
-
-
-class VocalBandsplitSpecification(BandsplitSpecification):
-    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
-        super().__init__(nfft=nfft, fs=fs)
-
-        self.version = version
-
-    def get_band_specs(self):
-        return getattr(self, f"version{self.version}")()
-
-    @property
-    def version1(self):
-        return self.get_band_specs_with_bandwidth(
-            start_index=0, end_index=self.max_index, bandwidth_hz=1000
-        )
-
-    def version2(self):
-        below16k = self.get_band_specs_with_bandwidth(
-            start_index=0, end_index=self.split16k, bandwidth_hz=1000
-        )
-        below20k = self.get_band_specs_with_bandwidth(
-            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
-        )
-
-        return below16k + below20k + self.above20k
-
-    def version3(self):
-        below8k = self.get_band_specs_with_bandwidth(
-            start_index=0, end_index=self.split8k, bandwidth_hz=1000
-        )
-        below16k = self.get_band_specs_with_bandwidth(
-            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
-        )
-
-        return below8k + below16k + self.above16k
-
-    def version4(self):
-        below1k = self.get_band_specs_with_bandwidth(
-            start_index=0, end_index=self.split1k, bandwidth_hz=100
-        )
-        below8k = self.get_band_specs_with_bandwidth(
-            start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
-        )
-        below16k = self.get_band_specs_with_bandwidth(
-            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
-        )
-
-        return below1k + below8k + below16k + self.above16k
-
-    def version5(self):
-        below1k = self.get_band_specs_with_bandwidth(
-            start_index=0, end_index=self.split1k, bandwidth_hz=100
-        )
-        below16k = self.get_band_specs_with_bandwidth(
-            start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
-        )
-        below20k = self.get_band_specs_with_bandwidth(
-            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
-        )
-        return below1k + below16k + below20k + self.above20k
-
-    def version6(self):
-        below1k = self.get_band_specs_with_bandwidth(
-            start_index=0, end_index=self.split1k, bandwidth_hz=100
-        )
-        below4k = self.get_band_specs_with_bandwidth(
-            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
-        )
-        below8k = self.get_band_specs_with_bandwidth(
-            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
-        )
-        below16k = self.get_band_specs_with_bandwidth(
-            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
-        )
-        return below1k + below4k + below8k + below16k + self.above16k
-
-    def version7(self):
-        below1k = self.get_band_specs_with_bandwidth(
-            start_index=0, end_index=self.split1k, bandwidth_hz=100
-        )
-        below4k = self.get_band_specs_with_bandwidth(
-            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
-        )
-        below8k = self.get_band_specs_with_bandwidth(
-            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
-        )
-        below16k = self.get_band_specs_with_bandwidth(
-            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
-        )
-        below20k = self.get_band_specs_with_bandwidth(
-            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
-        )
-        return below1k + below4k + below8k + below16k + below20k + self.above20k
-
-
-class OtherBandsplitSpecification(VocalBandsplitSpecification):
-    def __init__(self, nfft: int, fs: int) -> None:
-        super().__init__(nfft=nfft, fs=fs, version="7")
-
-
-class BassBandsplitSpecification(BandsplitSpecification):
-    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
-        super().__init__(nfft=nfft, fs=fs)
-
-    def get_band_specs(self):
-        below500 = self.get_band_specs_with_bandwidth(
-            start_index=0, end_index=self.split500, bandwidth_hz=50
-        )
-        below1k = self.get_band_specs_with_bandwidth(
-            start_index=self.split500, end_index=self.split1k, bandwidth_hz=100
-        )
-        below4k = self.get_band_specs_with_bandwidth(
-            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
-        )
-        below8k = self.get_band_specs_with_bandwidth(
-            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
-        )
-        below16k = self.get_band_specs_with_bandwidth(
-            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
-        )
-        above16k = [(self.split16k, self.max_index)]
-
-        return below500 + below1k + below4k + below8k + below16k + above16k
-
-
-class DrumBandsplitSpecification(BandsplitSpecification):
-    def __init__(self, nfft: int, fs: int) -> None:
-        super().__init__(nfft=nfft, fs=fs)
-
-    def get_band_specs(self):
-        below1k = self.get_band_specs_with_bandwidth(
-            start_index=0, end_index=self.split1k, bandwidth_hz=50
-        )
-        below2k = self.get_band_specs_with_bandwidth(
-            start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100
-        )
-        below4k = self.get_band_specs_with_bandwidth(
-            start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250
-        )
-        below8k = self.get_band_specs_with_bandwidth(
-            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
-        )
-        below16k = self.get_band_specs_with_bandwidth(
-            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
-        )
-        above16k = [(self.split16k, self.max_index)]
-
-        return below1k + below2k + below4k + below8k + below16k + above16k
-
-
-class PerceptualBandsplitSpecification(BandsplitSpecification):
-    def __init__(
-        self,
-        nfft: int,
-        fs: int,
-        fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
-        n_bands: int,
-        f_min: float = 0.0,
-        f_max: float = None,
-    ) -> None:
-        super().__init__(nfft=nfft, fs=fs)
-        self.n_bands = n_bands
-        if f_max is None:
-            f_max = fs / 2
-
-        self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)
-
-        weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True)  # (1, n_freqs)
-        normalized_mel_fb = self.filterbank / weight_per_bin  # (n_mels, n_freqs)
-
-        freq_weights = []
-        band_specs = []
-        for i in range(self.n_bands):
-            active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
-            if isinstance(active_bins, int):
-                active_bins = (active_bins, active_bins)
-            if len(active_bins) == 0:
-                continue
-            start_index = active_bins[0]
-            end_index = active_bins[-1] + 1
-            band_specs.append((start_index, end_index))
-            freq_weights.append(normalized_mel_fb[i, start_index:end_index])
-
-        self.freq_weights = freq_weights
-        self.band_specs = band_specs
-
-    def get_band_specs(self):
-        return self.band_specs
-
-    def get_freq_weights(self):
-        return self.freq_weights
-
-    def save_to_file(self, dir_path: str) -> None:
-
-        os.makedirs(dir_path, exist_ok=True)
-
-        import pickle
-
-        with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
-            pickle.dump(
-                {
-                    "band_specs": self.band_specs,
-                    "freq_weights": self.freq_weights,
-                    "filterbank": self.filterbank,
-                },
-                f,
-            )
-
-
-def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
-    fb = taF.melscale_fbanks(
-        n_mels=n_bands,
-        sample_rate=fs,
-        f_min=f_min,
-        f_max=f_max,
-        n_freqs=n_freqs,
-    ).T
-
-    fb[0, 0] = 1.0
-
-    return fb
-
-
-class MelBandsplitSpecification(PerceptualBandsplitSpecification):
-    def __init__(
-        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
-    ) -> None:
-        super().__init__(
-            fbank_fn=mel_filterbank,
-            nfft=nfft,
-            fs=fs,
-            n_bands=n_bands,
-            f_min=f_min,
-            f_max=f_max,
-        )
-
-
-def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
-
-    nfft = 2 * (n_freqs - 1)
-    df = fs / nfft
-    # init freqs
-    f_max = f_max or fs / 2
-    f_min = f_min or 0
-    f_min = fs / nfft
-
-    n_octaves = np.log2(f_max / f_min)
-    n_octaves_per_band = n_octaves / n_bands
-    bandwidth_mult = np.power(2.0, n_octaves_per_band)
-
-    low_midi = max(0, hz_to_midi(f_min))
-    high_midi = hz_to_midi(f_max)
-    midi_points = np.linspace(low_midi, high_midi, n_bands)
-    hz_pts = midi_to_hz(midi_points)
-
-    low_pts = hz_pts / bandwidth_mult
-    high_pts = hz_pts * bandwidth_mult
-
-    low_bins = np.floor(low_pts / df).astype(int)
-    high_bins = np.ceil(high_pts / df).astype(int)
-
-    fb = np.zeros((n_bands, n_freqs))
-
-    for i in range(n_bands):
-        fb[i, low_bins[i] : high_bins[i] + 1] = 1.0
-
-    fb[0, : low_bins[0]] = 1.0
-    fb[-1, high_bins[-1] + 1 :] = 1.0
-
-    return torch.as_tensor(fb)
-
-
-class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
-    def __init__(
-        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
-    ) -> None:
-        super().__init__(
-            fbank_fn=musical_filterbank,
-            nfft=nfft,
-            fs=fs,
-            n_bands=n_bands,
-            f_min=f_min,
-            f_max=f_max,
-        )
-
-
-def bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
-    nfft = 2 * (n_freqs - 1)
-    fb, _ = bark_fbanks.bark_filter_banks(
-        nfilts=n_bands,
-        nfft=nfft,
-        fs=fs,
-        low_freq=f_min,
-        high_freq=f_max,
-        scale="constant",
-    )
-
-    return torch.as_tensor(fb)
-
-
-class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
-    def __init__(
-        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
-    ) -> None:
-        super().__init__(
-            fbank_fn=bark_filterbank,
-            nfft=nfft,
-            fs=fs,
-            n_bands=n_bands,
-            f_min=f_min,
-            f_max=f_max,
-        )
-
-
-def triangular_bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
-
-    all_freqs = torch.linspace(0, fs // 2, n_freqs)
-
-    # calculate mel freq bins
-    m_min = hz2bark(f_min)
-    m_max = hz2bark(f_max)
-
-    m_pts = torch.linspace(m_min, m_max, n_bands + 2)
-    f_pts = 600 * torch.sinh(m_pts / 6)
-
-    # create filterbank
-    fb = _create_triangular_filterbank(all_freqs, f_pts)
-
-    fb = fb.T
-
-    first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
-    first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
-
-    fb[first_active_band, :first_active_bin] = 1.0
-
-    return fb
-
-
-class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
-    def __init__(
-        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
-    ) -> None:
-        super().__init__(
-            fbank_fn=triangular_bark_filterbank,
-            nfft=nfft,
-            fs=fs,
-            n_bands=n_bands,
-            f_min=f_min,
-            f_max=f_max,
-        )
-
-
-def minibark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
-    fb = bark_filterbank(n_bands, fs, f_min, f_max, n_freqs)
-
-    fb[fb < np.sqrt(0.5)] = 0.0
-
-    return fb
-
-
-class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
-    def __init__(
-        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
-    ) -> None:
-        super().__init__(
-            fbank_fn=minibark_filterbank,
-            nfft=nfft,
-            fs=fs,
-            n_bands=n_bands,
-            f_min=f_min,
-            f_max=f_max,
-        )
-
-
-def erb_filterbank(
-    n_bands: int,
-    fs: int,
-    f_min: float,
-    f_max: float,
-    n_freqs: int,
-) -> Tensor:
-    # freq bins
-    A = (1000 * np.log(10)) / (24.7 * 4.37)
-    all_freqs = torch.linspace(0, fs // 2, n_freqs)
-
-    # calculate mel freq bins
-    m_min = hz2erb(f_min)
-    m_max = hz2erb(f_max)
-
-    m_pts = torch.linspace(m_min, m_max, n_bands + 2)
-    f_pts = (torch.pow(10, (m_pts / A)) - 1) / 0.00437
-
-    # create filterbank
-    fb = _create_triangular_filterbank(all_freqs, f_pts)
-
-    fb = fb.T
-
-    first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
-    first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
-
-    fb[first_active_band, :first_active_bin] = 1.0
-
-    return fb
-
-
-class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
-    def __init__(
-        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
-    ) -> None:
-        super().__init__(
-            fbank_fn=erb_filterbank,
-            nfft=nfft,
-            fs=fs,
-            n_bands=n_bands,
-            f_min=f_min,
-            f_max=f_max,
-        )
-
-
-if __name__ == "__main__":
-    import pandas as pd
-
-    band_defs = []
-
-    for bands in [VocalBandsplitSpecification]:
-        band_name = bands.__name__.replace("BandsplitSpecification", "")
-
-        mbs = bands(nfft=2048, fs=44100).get_band_specs()
-
-        for i, (f_min, f_max) in enumerate(mbs):
-            band_defs.append(
-                {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
-            )
-
-    df = pd.DataFrame(band_defs)
-    df.to_csv("vox7bands.csv", index=False)
 
+import os
+from abc import abstractmethod
+from typing import Any, Callable
+
+import numpy as np
+import torch
+from librosa import hz_to_midi, midi_to_hz
+from torch import Tensor
+from torchaudio import functional as taF
+from spafe.fbanks import bark_fbanks
+from spafe.utils.converters import erb2hz, hz2bark, hz2erb
+from torchaudio.functional.functional import _create_triangular_filterbank
+
+
+def band_widths_from_specs(band_specs):
+    return [e - i for i, e in band_specs]
+
+
+def check_nonzero_bandwidth(band_specs):
+    for fstart, fend in band_specs:
+        if fend - fstart <= 0:
+            raise ValueError("Bands cannot be zero-width")
+
+
+def check_no_overlap(band_specs):
+    fend_prev = -1
+    for fstart_curr, fend_curr in band_specs:
+        if fstart_curr <= fend_prev:
+            raise ValueError("Bands cannot overlap")
+
+
+def check_no_gap(band_specs):
+    fstart, _ = band_specs[0]
+    assert fstart == 0
+
+    fend_prev = -1
+    for fstart_curr, fend_curr in band_specs:
+        if fstart_curr - fend_prev > 1:
+            raise ValueError("Bands cannot leave gap")
+        fend_prev = fend_curr
+
+
+class BandsplitSpecification:
+    def __init__(self, nfft: int, fs: int) -> None:
+        self.fs = fs
+        self.nfft = nfft
+        self.nyquist = fs / 2
+        self.max_index = nfft // 2 + 1
+
+        self.split500 = self.hertz_to_index(500)
+        self.split1k = self.hertz_to_index(1000)
+        self.split2k = self.hertz_to_index(2000)
+        self.split4k = self.hertz_to_index(4000)
+        self.split8k = self.hertz_to_index(8000)
+        self.split16k = self.hertz_to_index(16000)
+        self.split20k = self.hertz_to_index(20000)
+
+        self.above20k = [(self.split20k, self.max_index)]
+        self.above16k = [(self.split16k, self.split20k)] + self.above20k
+
+    def index_to_hertz(self, index: int):
+        return index * self.fs / self.nfft
+
+    def hertz_to_index(self, hz: float, round: bool = True):
+        index = hz * self.nfft / self.fs
+
+        if round:
+            index = int(np.round(index))
+
+        return index
+
+    def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
+        band_specs = []
+        lower = start_index
+
+        while lower < end_index:
+            upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
+            upper = min(upper, end_index)
+
+            band_specs.append((lower, upper))
+            lower = upper
+
+        return band_specs
+
+    @abstractmethod
+    def get_band_specs(self):
+        raise NotImplementedError
+
+
+class VocalBandsplitSpecification(BandsplitSpecification):
+    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
+        super().__init__(nfft=nfft, fs=fs)
+
+        self.version = version
+
+    def get_band_specs(self):
+        return getattr(self, f"version{self.version}")()
+
+    @property
+    def version1(self):
+        return self.get_band_specs_with_bandwidth(
+            start_index=0, end_index=self.max_index, bandwidth_hz=1000
+        )
+
+    def version2(self):
+        below16k = self.get_band_specs_with_bandwidth(
+            start_index=0, end_index=self.split16k, bandwidth_hz=1000
+        )
+        below20k = self.get_band_specs_with_bandwidth(
+            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+        )
+
+        return below16k + below20k + self.above20k
+
+    def version3(self):
+        below8k = self.get_band_specs_with_bandwidth(
+            start_index=0, end_index=self.split8k, bandwidth_hz=1000
+        )
+        below16k = self.get_band_specs_with_bandwidth(
+            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+        )
+
+        return below8k + below16k + self.above16k
+
+    def version4(self):
+        below1k = self.get_band_specs_with_bandwidth(
+            start_index=0, end_index=self.split1k, bandwidth_hz=100
+        )
+        below8k = self.get_band_specs_with_bandwidth(
+            start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
+        )
+        below16k = self.get_band_specs_with_bandwidth(
+            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+        )
+
+        return below1k + below8k + below16k + self.above16k
+
+    def version5(self):
+        below1k = self.get_band_specs_with_bandwidth(
+            start_index=0, end_index=self.split1k, bandwidth_hz=100
+        )
+        below16k = self.get_band_specs_with_bandwidth(
+            start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
+        )
+        below20k = self.get_band_specs_with_bandwidth(
+            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+        )
+        return below1k + below16k + below20k + self.above20k
+
+    def version6(self):
+        below1k = self.get_band_specs_with_bandwidth(
+            start_index=0, end_index=self.split1k, bandwidth_hz=100
+        )
+        below4k = self.get_band_specs_with_bandwidth(
+            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
+        )
+        below8k = self.get_band_specs_with_bandwidth(
+            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
+        )
+        below16k = self.get_band_specs_with_bandwidth(
+            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+        )
+        return below1k + below4k + below8k + below16k + self.above16k
+
+    def version7(self):
+        below1k = self.get_band_specs_with_bandwidth(
+            start_index=0, end_index=self.split1k, bandwidth_hz=100
+        )
+        below4k = self.get_band_specs_with_bandwidth(
+            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
+        )
+        below8k = self.get_band_specs_with_bandwidth(
+            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
+        )
+        below16k = self.get_band_specs_with_bandwidth(
+            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
+        )
+        below20k = self.get_band_specs_with_bandwidth(
+            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+        )
+        return below1k + below4k + below8k + below16k + below20k + self.above20k
+
+
+class OtherBandsplitSpecification(VocalBandsplitSpecification):
+    def __init__(self, nfft: int, fs: int) -> None:
+        super().__init__(nfft=nfft, fs=fs, version="7")
+
+
+class BassBandsplitSpecification(BandsplitSpecification):
+    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
+        super().__init__(nfft=nfft, fs=fs)
+
+    def get_band_specs(self):
+        below500 = self.get_band_specs_with_bandwidth(
+            start_index=0, end_index=self.split500, bandwidth_hz=50
+        )
+        below1k = self.get_band_specs_with_bandwidth(
+            start_index=self.split500, end_index=self.split1k, bandwidth_hz=100
+        )
+        below4k = self.get_band_specs_with_bandwidth(
+            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
+        )
+        below8k = self.get_band_specs_with_bandwidth(
+            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
+        )
+        below16k = self.get_band_specs_with_bandwidth(
+            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+        )
+        above16k = [(self.split16k, self.max_index)]
+
+        return below500 + below1k + below4k + below8k + below16k + above16k
+
+
+class DrumBandsplitSpecification(BandsplitSpecification):
+    def __init__(self, nfft: int, fs: int) -> None:
+        super().__init__(nfft=nfft, fs=fs)
+
+    def get_band_specs(self):
+        below1k = self.get_band_specs_with_bandwidth(
+            start_index=0, end_index=self.split1k, bandwidth_hz=50
+        )
+        below2k = self.get_band_specs_with_bandwidth(
+            start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100
+        )
+        below4k = self.get_band_specs_with_bandwidth(
+            start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250
+        )
+        below8k = self.get_band_specs_with_bandwidth(
+            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
+        )
+        below16k = self.get_band_specs_with_bandwidth(
+            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
+        )
+        above16k = [(self.split16k, self.max_index)]
+
+        return below1k + below2k + below4k + below8k + below16k + above16k
+
+
+class PerceptualBandsplitSpecification(BandsplitSpecification):
+    def __init__(
+        self,
+        nfft: int,
+        fs: int,
+        fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
+        n_bands: int,
+        f_min: float = 0.0,
+        f_max: float = None,
+    ) -> None:
+        super().__init__(nfft=nfft, fs=fs)
+        self.n_bands = n_bands
+        if f_max is None:
+            f_max = fs / 2
+
+        self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)
+
+        weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True)
+        normalized_mel_fb = self.filterbank / weight_per_bin
+
+        freq_weights = []
+        band_specs = []
+        for i in range(self.n_bands):
+            active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
+            if isinstance(active_bins, int):
+                active_bins = (active_bins, active_bins)
+            if len(active_bins) == 0:
+                continue
+            start_index = active_bins[0]
+            end_index = active_bins[-1] + 1
+            band_specs.append((start_index, end_index))
+            freq_weights.append(normalized_mel_fb[i, start_index:end_index])
+
+        self.freq_weights = freq_weights
+        self.band_specs = band_specs
+
+    def get_band_specs(self):
+        return self.band_specs
+
+    def get_freq_weights(self):
+        return self.freq_weights
+
+    def save_to_file(self, dir_path: str) -> None:
+
+        os.makedirs(dir_path, exist_ok=True)
+
+        import pickle
+
+        with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
+            pickle.dump(
+                {
+                    "band_specs": self.band_specs,
+                    "freq_weights": self.freq_weights,
+                    "filterbank": self.filterbank,
+                },
+                f,
+            )
+
+
+def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
+    fb = taF.melscale_fbanks(
+        n_mels=n_bands,
+        sample_rate=fs,
+        f_min=f_min,
+        f_max=f_max,
+        n_freqs=n_freqs,
+    ).T
+
+    fb[0, 0] = 1.0
+
+    return fb
+
+
+class MelBandsplitSpecification(PerceptualBandsplitSpecification):
+    def __init__(
+        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+    ) -> None:
+        super().__init__(
+            fbank_fn=mel_filterbank,
+            nfft=nfft,
+            fs=fs,
+            n_bands=n_bands,
+            f_min=f_min,
+            f_max=f_max,
+        )
+
+
+def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
+
+    nfft = 2 * (n_freqs - 1)
+    df = fs / nfft
+    f_max = f_max or fs / 2
+    f_min = f_min or 0
+    f_min = fs / nfft
+
+    n_octaves = np.log2(f_max / f_min)
+    n_octaves_per_band = n_octaves / n_bands
+    bandwidth_mult = np.power(2.0, n_octaves_per_band)
+
+    low_midi = max(0, hz_to_midi(f_min))
+    high_midi = hz_to_midi(f_max)
+    midi_points = np.linspace(low_midi, high_midi, n_bands)
+    hz_pts = midi_to_hz(midi_points)
+
+    low_pts = hz_pts / bandwidth_mult
+    high_pts = hz_pts * bandwidth_mult
+
+    low_bins = np.floor(low_pts / df).astype(int)
+    high_bins = np.ceil(high_pts / df).astype(int)
+
+    fb = np.zeros((n_bands, n_freqs))
+
+    for i in range(n_bands):
+        fb[i, low_bins[i] : high_bins[i] + 1] = 1.0
+
+    fb[0, : low_bins[0]] = 1.0
+    fb[-1, high_bins[-1] + 1 :] = 1.0
+
+    return torch.as_tensor(fb)
+
+
+class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
+    def __init__(
+        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+    ) -> None:
+        super().__init__(
+            fbank_fn=musical_filterbank,
+            nfft=nfft,
+            fs=fs,
+            n_bands=n_bands,
+            f_min=f_min,
+            f_max=f_max,
+        )
+
+
+def bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
+    nfft = 2 * (n_freqs - 1)
+    fb, _ = bark_fbanks.bark_filter_banks(
+        nfilts=n_bands,
+        nfft=nfft,
+        fs=fs,
+        low_freq=f_min,
+        high_freq=f_max,
+        scale="constant",
+    )
+
+    return torch.as_tensor(fb)
+
+
+class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
+    def __init__(
+        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+    ) -> None:
+        super().__init__(
+            fbank_fn=bark_filterbank,
+            nfft=nfft,
+            fs=fs,
+            n_bands=n_bands,
+            f_min=f_min,
+            f_max=f_max,
+        )
+
+
+def triangular_bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
+
+    all_freqs = torch.linspace(0, fs // 2, n_freqs)
+
+    m_min = hz2bark(f_min)
+    m_max = hz2bark(f_max)
+
+    m_pts = torch.linspace(m_min, m_max, n_bands + 2)
+    f_pts = 600 * torch.sinh(m_pts / 6)
+
+    fb = _create_triangular_filterbank(all_freqs, f_pts)
+
+    fb = fb.T
+
+    first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
+    first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
+
+    fb[first_active_band, :first_active_bin] = 1.0
+
+    return fb
+
+
+class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
+    def __init__(
+        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+    ) -> None:
+        super().__init__(
+            fbank_fn=triangular_bark_filterbank,
+            nfft=nfft,
+            fs=fs,
+            n_bands=n_bands,
+            f_min=f_min,
+            f_max=f_max,
+        )
+
+
+def minibark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
+    fb = bark_filterbank(n_bands, fs, f_min, f_max, n_freqs)
+
+    fb[fb < np.sqrt(0.5)] = 0.0
+
+    return fb
+
+
+class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
+    def __init__(
+        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+    ) -> None:
+        super().__init__(
+            fbank_fn=minibark_filterbank,
+            nfft=nfft,
+            fs=fs,
+            n_bands=n_bands,
+            f_min=f_min,
+            f_max=f_max,
+        )
+
+
+def erb_filterbank(
+    n_bands: int,
+    fs: int,
+    f_min: float,
+    f_max: float,
+    n_freqs: int,
+) -> Tensor:
+    A = (1000 * np.log(10)) / (24.7 * 4.37)
+    all_freqs = torch.linspace(0, fs // 2, n_freqs)
+
+    m_min = hz2erb(f_min)
+    m_max = hz2erb(f_max)
+
+    m_pts = torch.linspace(m_min, m_max, n_bands + 2)
+    f_pts = (torch.pow(10, (m_pts / A)) - 1) / 0.00437
+
+    fb = _create_triangular_filterbank(all_freqs, f_pts)
+
+    fb = fb.T
+
+    first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
+    first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
+
+    fb[first_active_band, :first_active_bin] = 1.0
+
+    return fb
+
+
+class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
+    def __init__(
+        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+    ) -> None:
+        super().__init__(
+            fbank_fn=erb_filterbank,
+            nfft=nfft,
+            fs=fs,
+            n_bands=n_bands,
+            f_min=f_min,
+            f_max=f_max,
+        )
+
+
+if __name__ == "__main__":
+    import pandas as pd
+
+    band_defs = []
+
+    for bands in [VocalBandsplitSpecification]:
+        band_name = bands.__name__.replace("BandsplitSpecification", "")
+
+        mbs = bands(nfft=2048, fs=44100).get_band_specs()
+
+        for i, (f_min, f_max) in enumerate(mbs):
+            band_defs.append(
+                {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
+            )
+
+    df = pd.DataFrame(band_defs)
+    df.to_csv("vox7bands.csv", index=False)
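
For reference, a minimal, hypothetical usage sketch of the band-split helpers above (not part of this commit; the import path is assumed from this file's location in the repo):

# Illustrative usage sketch only.
from mvsepless.models.bandit.core.model.bsrnn.utils import (
    VocalBandsplitSpecification,
    check_no_gap,
    check_nonzero_bandwidth,
)

# The "vox7" layout: 100 Hz bands up to 1 kHz, widening to 2 kHz bands by 20 kHz.
spec = VocalBandsplitSpecification(nfft=2048, fs=44100, version="7")
bands = spec.get_band_specs()  # list of (start_bin, end_bin) STFT-bin pairs

check_nonzero_bandwidth(bands)  # every band spans at least one bin
check_no_gap(bands)             # bands tile the spectrum from bin 0 upward

print(f"{len(bands)} bands, last band ends at bin {bands[-1][-1]}")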
 
mvsepless/models/bandit/core/model/bsrnn/wrapper.py CHANGED
@@ -1,829 +1,828 @@
-from pprint import pprint
-from typing import Dict, List, Optional, Tuple, Union
-
-import torch
-from torch import nn
-
-from .._spectral import _SpectralComponent
-from .utils import (
-    BarkBandsplitSpecification,
-    BassBandsplitSpecification,
-    DrumBandsplitSpecification,
-    EquivalentRectangularBandsplitSpecification,
-    MelBandsplitSpecification,
-    MusicalBandsplitSpecification,
-    OtherBandsplitSpecification,
-    TriangularBarkBandsplitSpecification,
-    VocalBandsplitSpecification,
-)
-from .core import (
-    MultiSourceMultiMaskBandSplitCoreConv,
-    MultiSourceMultiMaskBandSplitCoreRNN,
-    MultiSourceMultiMaskBandSplitCoreTransformer,
-    MultiSourceMultiPatchingMaskBandSplitCoreRNN,
-    SingleMaskBandsplitCoreRNN,
-    SingleMaskBandsplitCoreTransformer,
-)
-
-import pytorch_lightning as pl
-
-
-def get_band_specs(band_specs, n_fft, fs, n_bands=None):
-    if band_specs in ["dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"]:
-        bsm = VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs()
-        freq_weights = None
-        overlapping_band = False
-    elif "tribark" in band_specs:
-        assert n_bands is not None
-        specs = TriangularBarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
-        bsm = specs.get_band_specs()
-        freq_weights = specs.get_freq_weights()
-        overlapping_band = True
-    elif "bark" in band_specs:
-        assert n_bands is not None
-        specs = BarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
-        bsm = specs.get_band_specs()
-        freq_weights = specs.get_freq_weights()
-        overlapping_band = True
-    elif "erb" in band_specs:
-        assert n_bands is not None
-        specs = EquivalentRectangularBandsplitSpecification(
-            nfft=n_fft, fs=fs, n_bands=n_bands
-        )
-        bsm = specs.get_band_specs()
-        freq_weights = specs.get_freq_weights()
-        overlapping_band = True
-    elif "musical" in band_specs:
-        assert n_bands is not None
-        specs = MusicalBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
-        bsm = specs.get_band_specs()
-        freq_weights = specs.get_freq_weights()
-        overlapping_band = True
-    elif band_specs == "dnr:mel" or "mel" in band_specs:
-        assert n_bands is not None
-        specs = MelBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
-        bsm = specs.get_band_specs()
-        freq_weights = specs.get_freq_weights()
-        overlapping_band = True
-    else:
-        raise NameError
-
-    return bsm, freq_weights, overlapping_band
-
-
-def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None):
-    if band_specs_map == "musdb:all":
-        bsm = {
-            "vocals": VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
-            "drums": DrumBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
-            "bass": BassBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
-            "other": OtherBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
-        }
-        freq_weights = None
-        overlapping_band = False
-    elif band_specs_map == "dnr:vox7":
-        bsm_, freq_weights, overlapping_band = get_band_specs(
-            "dnr:speech", n_fft, fs, n_bands
-        )
-        bsm = {"speech": bsm_, "music": bsm_, "effects": bsm_}
-    elif "dnr:vox7:" in band_specs_map:
-        stem = band_specs_map.split(":")[-1]
-        bsm_, freq_weights, overlapping_band = get_band_specs(
-            "dnr:speech", n_fft, fs, n_bands
-        )
-        bsm = {stem: bsm_}
-    else:
-        raise NameError
-
-    return bsm, freq_weights, overlapping_band
-
-
-class BandSplitWrapperBase(pl.LightningModule):
-    bsrnn: nn.Module
-
-    def __init__(self, **kwargs):
-        super().__init__()
-
-
-class SingleMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
-    def __init__(
-        self,
-        band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
-        fs: int = 44100,
-        n_fft: int = 2048,
-        win_length: Optional[int] = 2048,
-        hop_length: int = 512,
-        window_fn: str = "hann_window",
-        wkwargs: Optional[Dict] = None,
-        power: Optional[int] = None,
-        center: bool = True,
-        normalized: bool = True,
-        pad_mode: str = "constant",
-        onesided: bool = True,
-        n_bands: int = None,
-    ) -> None:
-        super().__init__(
-            n_fft=n_fft,
-            win_length=win_length,
-            hop_length=hop_length,
-            window_fn=window_fn,
-            wkwargs=wkwargs,
-            power=power,
-            center=center,
-            normalized=normalized,
-            pad_mode=pad_mode,
-            onesided=onesided,
-        )
-
-        if isinstance(band_specs_map, str):
-            self.band_specs_map, self.freq_weights, self.overlapping_band = (
-                get_band_specs_map(band_specs_map, n_fft, fs, n_bands=n_bands)
-            )
-
-        self.stems = list(self.band_specs_map.keys())
-
-    def forward(self, batch):
-        audio = batch["audio"]
-
-        with torch.no_grad():
-            batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}
-
-        X = batch["spectrogram"]["mixture"]
-        length = batch["audio"]["mixture"].shape[-1]
-
-        output = {"spectrogram": {}, "audio": {}}
-
-        for stem, bsrnn in self.bsrnn.items():
-            S = bsrnn(X)
-            s = self.istft(S, length)
-            output["spectrogram"][stem] = S
-            output["audio"][stem] = s
-
-        return batch, output
-
-
-class MultiMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
-    def __init__(
-        self,
-        stems: List[str],
-        band_specs: Union[str, List[Tuple[float, float]]],
-        fs: int = 44100,
-        n_fft: int = 2048,
-        win_length: Optional[int] = 2048,
-        hop_length: int = 512,
-        window_fn: str = "hann_window",
-        wkwargs: Optional[Dict] = None,
-        power: Optional[int] = None,
-        center: bool = True,
-        normalized: bool = True,
-        pad_mode: str = "constant",
-        onesided: bool = True,
-        n_bands: int = None,
-    ) -> None:
-        super().__init__(
-            n_fft=n_fft,
-            win_length=win_length,
-            hop_length=hop_length,
-            window_fn=window_fn,
-            wkwargs=wkwargs,
-            power=power,
-            center=center,
-            normalized=normalized,
-            pad_mode=pad_mode,
-            onesided=onesided,
-        )
-
-        if isinstance(band_specs, str):
-            self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
-                band_specs, n_fft, fs, n_bands
-            )
-
-        self.stems = stems
-
-    def forward(self, batch):
-        # with torch.no_grad():
-        audio = batch["audio"]
-        cond = batch.get("condition", None)
-        with torch.no_grad():
-            batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}
-
-        X = batch["spectrogram"]["mixture"]
-        length = batch["audio"]["mixture"].shape[-1]
-
-        output = self.bsrnn(X, cond=cond)
-        output["audio"] = {}
-
-        for stem, S in output["spectrogram"].items():
-            s = self.istft(S, length)
-            output["audio"][stem] = s
-
-        return batch, output
-
-
-class MultiMaskMultiSourceBandSplitBaseSimple(BandSplitWrapperBase, _SpectralComponent):
-    def __init__(
-        self,
-        stems: List[str],
-        band_specs: Union[str, List[Tuple[float, float]]],
-        fs: int = 44100,
-        n_fft: int = 2048,
-        win_length: Optional[int] = 2048,
-        hop_length: int = 512,
-        window_fn: str = "hann_window",
-        wkwargs: Optional[Dict] = None,
-        power: Optional[int] = None,
-        center: bool = True,
-        normalized: bool = True,
-        pad_mode: str = "constant",
-        onesided: bool = True,
-        n_bands: int = None,
-    ) -> None:
-        super().__init__(
-            n_fft=n_fft,
-            win_length=win_length,
-            hop_length=hop_length,
-            window_fn=window_fn,
-            wkwargs=wkwargs,
-            power=power,
-            center=center,
-            normalized=normalized,
-            pad_mode=pad_mode,
-            onesided=onesided,
-        )
-
-        if isinstance(band_specs, str):
-            self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
-                band_specs, n_fft, fs, n_bands
-            )
-
-        self.stems = stems
-
-    def forward(self, batch):
-        with torch.no_grad():
-            X = self.stft(batch)
-            length = batch.shape[-1]
-        output = self.bsrnn(X, cond=None)
-        res = []
-        for stem, S in output["spectrogram"].items():
-            s = self.istft(S, length)
-            res.append(s)
-        res = torch.stack(res, dim=1)
-        return res
-
-
-class SingleMaskMultiSourceBandSplitRNN(SingleMaskMultiSourceBandSplitBase):
-    def __init__(
-        self,
-        in_channel: int,
-        band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
-        fs: int = 44100,
-        require_no_overlap: bool = False,
-        require_no_gap: bool = True,
-        normalize_channel_independently: bool = False,
-        treat_channel_as_feature: bool = True,
-        n_sqm_modules: int = 12,
-        emb_dim: int = 128,
-        rnn_dim: int = 256,
-        bidirectional: bool = True,
-        rnn_type: str = "LSTM",
-        mlp_dim: int = 512,
-        hidden_activation: str = "Tanh",
-        hidden_activation_kwargs: Optional[Dict] = None,
-        complex_mask: bool = True,
-        n_fft: int = 2048,
-        win_length: Optional[int] = 2048,
-        hop_length: int = 512,
-        window_fn: str = "hann_window",
-        wkwargs: Optional[Dict] = None,
-        power: Optional[int] = None,
-        center: bool = True,
-        normalized: bool = True,
-        pad_mode: str = "constant",
-        onesided: bool = True,
-    ) -> None:
-        super().__init__(
-            band_specs_map=band_specs_map,
-            fs=fs,
-            n_fft=n_fft,
-            win_length=win_length,
-            hop_length=hop_length,
-            window_fn=window_fn,
-            wkwargs=wkwargs,
-            power=power,
-            center=center,
-            normalized=normalized,
-            pad_mode=pad_mode,
-            onesided=onesided,
-        )
-
-        self.bsrnn = nn.ModuleDict(
-            {
-                src: SingleMaskBandsplitCoreRNN(
-                    band_specs=specs,
-                    in_channel=in_channel,
-                    require_no_overlap=require_no_overlap,
-                    require_no_gap=require_no_gap,
-                    normalize_channel_independently=normalize_channel_independently,
-                    treat_channel_as_feature=treat_channel_as_feature,
-                    n_sqm_modules=n_sqm_modules,
-                    emb_dim=emb_dim,
-                    rnn_dim=rnn_dim,
-                    bidirectional=bidirectional,
-                    rnn_type=rnn_type,
-                    mlp_dim=mlp_dim,
-                    hidden_activation=hidden_activation,
-                    hidden_activation_kwargs=hidden_activation_kwargs,
-                    complex_mask=complex_mask,
-                )
-                for src, specs in self.band_specs_map.items()
-            }
-        )
-
-
-class SingleMaskMultiSourceBandSplitTransformer(SingleMaskMultiSourceBandSplitBase):
-    def __init__(
-        self,
-        in_channel: int,
-        band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
-        fs: int = 44100,
-        require_no_overlap: bool = False,
-        require_no_gap: bool = True,
-        normalize_channel_independently: bool = False,
-        treat_channel_as_feature: bool = True,
-        n_sqm_modules: int = 12,
-        emb_dim: int = 128,
-        rnn_dim: int = 256,
-        bidirectional: bool = True,
-        tf_dropout: float = 0.0,
-        mlp_dim: int = 512,
-        hidden_activation: str = "Tanh",
-        hidden_activation_kwargs: Optional[Dict] = None,
-        complex_mask: bool = True,
-        n_fft: int = 2048,
-        win_length: Optional[int] = 2048,
-        hop_length: int = 512,
-        window_fn: str = "hann_window",
-        wkwargs: Optional[Dict] = None,
-        power: Optional[int] = None,
-        center: bool = True,
-        normalized: bool = True,
-        pad_mode: str = "constant",
-        onesided: bool = True,
-    ) -> None:
-        super().__init__(
-            band_specs_map=band_specs_map,
-            fs=fs,
-            n_fft=n_fft,
-            win_length=win_length,
-            hop_length=hop_length,
-            window_fn=window_fn,
-            wkwargs=wkwargs,
-            power=power,
-            center=center,
-            normalized=normalized,
-            pad_mode=pad_mode,
-            onesided=onesided,
-        )
-
-        self.bsrnn = nn.ModuleDict(
-            {
-                src: SingleMaskBandsplitCoreTransformer(
-                    band_specs=specs,
-                    in_channel=in_channel,
-                    require_no_overlap=require_no_overlap,
-                    require_no_gap=require_no_gap,
-                    normalize_channel_independently=normalize_channel_independently,
-                    treat_channel_as_feature=treat_channel_as_feature,
-                    n_sqm_modules=n_sqm_modules,
-                    emb_dim=emb_dim,
-                    rnn_dim=rnn_dim,
-                    bidirectional=bidirectional,
-                    tf_dropout=tf_dropout,
-                    mlp_dim=mlp_dim,
-                    hidden_activation=hidden_activation,
-                    hidden_activation_kwargs=hidden_activation_kwargs,
-                    complex_mask=complex_mask,
-                )
-                for src, specs in self.band_specs_map.items()
-            }
-        )
-
-
-class MultiMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
-    def __init__(
-        self,
-        in_channel: int,
-        stems: List[str],
-        band_specs: Union[str, List[Tuple[float, float]]],
-        fs: int = 44100,
-        require_no_overlap: bool = False,
-        require_no_gap: bool = True,
-        normalize_channel_independently: bool = False,
-        treat_channel_as_feature: bool = True,
-        n_sqm_modules: int = 12,
-        emb_dim: int = 128,
-        rnn_dim: int = 256,
-        cond_dim: int = 0,
-        bidirectional: bool = True,
-        rnn_type: str = "LSTM",
-        mlp_dim: int = 512,
-        hidden_activation: str = "Tanh",
-        hidden_activation_kwargs: Optional[Dict] = None,
-        complex_mask: bool = True,
-        n_fft: int = 2048,
-        win_length: Optional[int] = 2048,
-        hop_length: int = 512,
-        window_fn: str = "hann_window",
-        wkwargs: Optional[Dict] = None,
-        power: Optional[int] = None,
-        center: bool = True,
-        normalized: bool = True,
-        pad_mode: str = "constant",
-        onesided: bool = True,
-        n_bands: int = None,
-        use_freq_weights: bool = True,
-        normalize_input: bool = False,
-        mult_add_mask: bool = False,
-        freeze_encoder: bool = False,
-    ) -> None:
-        super().__init__(
-            stems=stems,
-            band_specs=band_specs,
-            fs=fs,
-            n_fft=n_fft,
-            win_length=win_length,
-            hop_length=hop_length,
-            window_fn=window_fn,
-            wkwargs=wkwargs,
-            power=power,
-            center=center,
-            normalized=normalized,
-            pad_mode=pad_mode,
-            onesided=onesided,
-            n_bands=n_bands,
-        )
-
-        self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
-            stems=stems,
-            band_specs=self.band_specs,
-            in_channel=in_channel,
-            require_no_overlap=require_no_overlap,
-            require_no_gap=require_no_gap,
-            normalize_channel_independently=normalize_channel_independently,
-            treat_channel_as_feature=treat_channel_as_feature,
-            n_sqm_modules=n_sqm_modules,
-            emb_dim=emb_dim,
-            rnn_dim=rnn_dim,
-            bidirectional=bidirectional,
-            rnn_type=rnn_type,
-            mlp_dim=mlp_dim,
-            cond_dim=cond_dim,
-            hidden_activation=hidden_activation,
-            hidden_activation_kwargs=hidden_activation_kwargs,
-            complex_mask=complex_mask,
-            overlapping_band=self.overlapping_band,
-            freq_weights=self.freq_weights,
-            n_freq=n_fft // 2 + 1,
-            use_freq_weights=use_freq_weights,
-            mult_add_mask=mult_add_mask,
-        )
-
-        self.normalize_input = normalize_input
-        self.cond_dim = cond_dim
-
-        if freeze_encoder:
-            for param in self.bsrnn.band_split.parameters():
-                param.requires_grad = False
-
-            for param in self.bsrnn.tf_model.parameters():
-                param.requires_grad = False
-
-
-class MultiMaskMultiSourceBandSplitRNNSimple(MultiMaskMultiSourceBandSplitBaseSimple):
-    def __init__(
-        self,
-        in_channel: int,
-        stems: List[str],
-        band_specs: Union[str, List[Tuple[float, float]]],
-        fs: int = 44100,
-        require_no_overlap: bool = False,
-        require_no_gap: bool = True,
-        normalize_channel_independently: bool = False,
-        treat_channel_as_feature: bool = True,
-        n_sqm_modules: int = 12,
-        emb_dim: int = 128,
-        rnn_dim: int = 256,
-        cond_dim: int = 0,
-        bidirectional: bool = True,
-        rnn_type: str = "LSTM",
-        mlp_dim: int = 512,
-        hidden_activation: str = "Tanh",
-        hidden_activation_kwargs: Optional[Dict] = None,
-        complex_mask: bool = True,
-        n_fft: int = 2048,
-        win_length: Optional[int] = 2048,
-        hop_length: int = 512,
-        window_fn: str = "hann_window",
-        wkwargs: Optional[Dict] = None,
-        power: Optional[int] = None,
-        center: bool = True,
-        normalized: bool = True,
-        pad_mode: str = "constant",
-        onesided: bool = True,
-        n_bands: int = None,
-        use_freq_weights: bool = True,
-        normalize_input: bool = False,
-        mult_add_mask: bool = False,
-        freeze_encoder: bool = False,
-    ) -> None:
-        super().__init__(
-            stems=stems,
-            band_specs=band_specs,
-            fs=fs,
-            n_fft=n_fft,
-            win_length=win_length,
-            hop_length=hop_length,
-            window_fn=window_fn,
-            wkwargs=wkwargs,
-            power=power,
-            center=center,
-            normalized=normalized,
-            pad_mode=pad_mode,
-            onesided=onesided,
-            n_bands=n_bands,
-        )
-
-        self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
-            stems=stems,
-            band_specs=self.band_specs,
-            in_channel=in_channel,
-            require_no_overlap=require_no_overlap,
-            require_no_gap=require_no_gap,
-            normalize_channel_independently=normalize_channel_independently,
-            treat_channel_as_feature=treat_channel_as_feature,
-            n_sqm_modules=n_sqm_modules,
-            emb_dim=emb_dim,
-            rnn_dim=rnn_dim,
-            bidirectional=bidirectional,
-            rnn_type=rnn_type,
-            mlp_dim=mlp_dim,
-            cond_dim=cond_dim,
-            hidden_activation=hidden_activation,
-            hidden_activation_kwargs=hidden_activation_kwargs,
-            complex_mask=complex_mask,
-            overlapping_band=self.overlapping_band,
-            freq_weights=self.freq_weights,
-            n_freq=n_fft // 2 + 1,
-            use_freq_weights=use_freq_weights,
-            mult_add_mask=mult_add_mask,
-        )
-
-        self.normalize_input = normalize_input
-        self.cond_dim = cond_dim
-
-        if freeze_encoder:
-            for param in self.bsrnn.band_split.parameters():
-                param.requires_grad = False
-
-            for param in self.bsrnn.tf_model.parameters():
-                param.requires_grad = False
-
-
-class MultiMaskMultiSourceBandSplitTransformer(MultiMaskMultiSourceBandSplitBase):
-    def __init__(
-        self,
-        in_channel: int,
-        stems: List[str],
-        band_specs: Union[str, List[Tuple[float, float]]],
-        fs: int = 44100,
-        require_no_overlap: bool = False,
-        require_no_gap: bool = True,
-        normalize_channel_independently: bool = False,
-        treat_channel_as_feature: bool = True,
-        n_sqm_modules: int = 12,
-        emb_dim: int = 128,
-        rnn_dim: int = 256,
-        cond_dim: int = 0,
-        bidirectional: bool = True,
-        rnn_type: str = "LSTM",
-        mlp_dim: int = 512,
-        hidden_activation: str = "Tanh",
-        hidden_activation_kwargs: Optional[Dict] = None,
-        complex_mask: bool = True,
-        n_fft: int = 2048,
-        win_length: Optional[int] = 2048,
-        hop_length: int = 512,
-        window_fn: str = "hann_window",
-        wkwargs: Optional[Dict] = None,
-        power: Optional[int] = None,
-        center: bool = True,
-        normalized: bool = True,
-        pad_mode: str = "constant",
-        onesided: bool = True,
-        n_bands: int = None,
-        use_freq_weights: bool = True,
-        normalize_input: bool = False,
-        mult_add_mask: bool = False,
-    ) -> None:
-        super().__init__(
-            stems=stems,
-            band_specs=band_specs,
-            fs=fs,
-            n_fft=n_fft,
-            win_length=win_length,
-            hop_length=hop_length,
-            window_fn=window_fn,
-            wkwargs=wkwargs,
-            power=power,
-            center=center,
-            normalized=normalized,
-            pad_mode=pad_mode,
-            onesided=onesided,
-            n_bands=n_bands,
-        )
-
-        self.bsrnn = MultiSourceMultiMaskBandSplitCoreTransformer(
-            stems=stems,
-            band_specs=self.band_specs,
-            in_channel=in_channel,
-            require_no_overlap=require_no_overlap,
-            require_no_gap=require_no_gap,
-            normalize_channel_independently=normalize_channel_independently,
-            treat_channel_as_feature=treat_channel_as_feature,
-            n_sqm_modules=n_sqm_modules,
-            emb_dim=emb_dim,
-            rnn_dim=rnn_dim,
-            bidirectional=bidirectional,
-            rnn_type=rnn_type,
-            mlp_dim=mlp_dim,
-            cond_dim=cond_dim,
-            hidden_activation=hidden_activation,
-            hidden_activation_kwargs=hidden_activation_kwargs,
-            complex_mask=complex_mask,
-            overlapping_band=self.overlapping_band,
-            freq_weights=self.freq_weights,
-            n_freq=n_fft // 2 + 1,
-            use_freq_weights=use_freq_weights,
-            mult_add_mask=mult_add_mask,
-        )
-
-
-class MultiMaskMultiSourceBandSplitConv(MultiMaskMultiSourceBandSplitBase):
-    def __init__(
-        self,
-        in_channel: int,
-        stems: List[str],
-        band_specs: Union[str, List[Tuple[float, float]]],
-        fs: int = 44100,
-        require_no_overlap: bool = False,
-        require_no_gap: bool = True,
680
- normalize_channel_independently: bool = False,
681
- treat_channel_as_feature: bool = True,
682
- n_sqm_modules: int = 12,
683
- emb_dim: int = 128,
684
- rnn_dim: int = 256,
685
- cond_dim: int = 0,
686
- bidirectional: bool = True,
687
- rnn_type: str = "LSTM",
688
- mlp_dim: int = 512,
689
- hidden_activation: str = "Tanh",
690
- hidden_activation_kwargs: Optional[Dict] = None,
691
- complex_mask: bool = True,
692
- n_fft: int = 2048,
693
- win_length: Optional[int] = 2048,
694
- hop_length: int = 512,
695
- window_fn: str = "hann_window",
696
- wkwargs: Optional[Dict] = None,
697
- power: Optional[int] = None,
698
- center: bool = True,
699
- normalized: bool = True,
700
- pad_mode: str = "constant",
701
- onesided: bool = True,
702
- n_bands: int = None,
703
- use_freq_weights: bool = True,
704
- normalize_input: bool = False,
705
- mult_add_mask: bool = False,
706
- ) -> None:
707
- super().__init__(
708
- stems=stems,
709
- band_specs=band_specs,
710
- fs=fs,
711
- n_fft=n_fft,
712
- win_length=win_length,
713
- hop_length=hop_length,
714
- window_fn=window_fn,
715
- wkwargs=wkwargs,
716
- power=power,
717
- center=center,
718
- normalized=normalized,
719
- pad_mode=pad_mode,
720
- onesided=onesided,
721
- n_bands=n_bands,
722
- )
723
-
724
- self.bsrnn = MultiSourceMultiMaskBandSplitCoreConv(
725
- stems=stems,
726
- band_specs=self.band_specs,
727
- in_channel=in_channel,
728
- require_no_overlap=require_no_overlap,
729
- require_no_gap=require_no_gap,
730
- normalize_channel_independently=normalize_channel_independently,
731
- treat_channel_as_feature=treat_channel_as_feature,
732
- n_sqm_modules=n_sqm_modules,
733
- emb_dim=emb_dim,
734
- rnn_dim=rnn_dim,
735
- bidirectional=bidirectional,
736
- rnn_type=rnn_type,
737
- mlp_dim=mlp_dim,
738
- cond_dim=cond_dim,
739
- hidden_activation=hidden_activation,
740
- hidden_activation_kwargs=hidden_activation_kwargs,
741
- complex_mask=complex_mask,
742
- overlapping_band=self.overlapping_band,
743
- freq_weights=self.freq_weights,
744
- n_freq=n_fft // 2 + 1,
745
- use_freq_weights=use_freq_weights,
746
- mult_add_mask=mult_add_mask,
747
- )
748
-
749
-
750
- class PatchingMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
751
- def __init__(
752
- self,
753
- in_channel: int,
754
- stems: List[str],
755
- band_specs: Union[str, List[Tuple[float, float]]],
756
- kernel_norm_mlp_version: int = 1,
757
- mask_kernel_freq: int = 3,
758
- mask_kernel_time: int = 3,
759
- conv_kernel_freq: int = 1,
760
- conv_kernel_time: int = 1,
761
- fs: int = 44100,
762
- require_no_overlap: bool = False,
763
- require_no_gap: bool = True,
764
- normalize_channel_independently: bool = False,
765
- treat_channel_as_feature: bool = True,
766
- n_sqm_modules: int = 12,
767
- emb_dim: int = 128,
768
- rnn_dim: int = 256,
769
- bidirectional: bool = True,
770
- rnn_type: str = "LSTM",
771
- mlp_dim: int = 512,
772
- hidden_activation: str = "Tanh",
773
- hidden_activation_kwargs: Optional[Dict] = None,
774
- complex_mask: bool = True,
775
- n_fft: int = 2048,
776
- win_length: Optional[int] = 2048,
777
- hop_length: int = 512,
778
- window_fn: str = "hann_window",
779
- wkwargs: Optional[Dict] = None,
780
- power: Optional[int] = None,
781
- center: bool = True,
782
- normalized: bool = True,
783
- pad_mode: str = "constant",
784
- onesided: bool = True,
785
- n_bands: int = None,
786
- ) -> None:
787
- super().__init__(
788
- stems=stems,
789
- band_specs=band_specs,
790
- fs=fs,
791
- n_fft=n_fft,
792
- win_length=win_length,
793
- hop_length=hop_length,
794
- window_fn=window_fn,
795
- wkwargs=wkwargs,
796
- power=power,
797
- center=center,
798
- normalized=normalized,
799
- pad_mode=pad_mode,
800
- onesided=onesided,
801
- n_bands=n_bands,
802
- )
803
-
804
- self.bsrnn = MultiSourceMultiPatchingMaskBandSplitCoreRNN(
805
- stems=stems,
806
- band_specs=self.band_specs,
807
- in_channel=in_channel,
808
- require_no_overlap=require_no_overlap,
809
- require_no_gap=require_no_gap,
810
- normalize_channel_independently=normalize_channel_independently,
811
- treat_channel_as_feature=treat_channel_as_feature,
812
- n_sqm_modules=n_sqm_modules,
813
- emb_dim=emb_dim,
814
- rnn_dim=rnn_dim,
815
- bidirectional=bidirectional,
816
- rnn_type=rnn_type,
817
- mlp_dim=mlp_dim,
818
- hidden_activation=hidden_activation,
819
- hidden_activation_kwargs=hidden_activation_kwargs,
820
- complex_mask=complex_mask,
821
- overlapping_band=self.overlapping_band,
822
- freq_weights=self.freq_weights,
823
- n_freq=n_fft // 2 + 1,
824
- mask_kernel_freq=mask_kernel_freq,
825
- mask_kernel_time=mask_kernel_time,
826
- conv_kernel_freq=conv_kernel_freq,
827
- conv_kernel_time=conv_kernel_time,
828
- kernel_norm_mlp_version=kernel_norm_mlp_version,
829
- )
 
+from pprint import pprint
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from .._spectral import _SpectralComponent
+from .utils import (
+    BarkBandsplitSpecification,
+    BassBandsplitSpecification,
+    DrumBandsplitSpecification,
+    EquivalentRectangularBandsplitSpecification,
+    MelBandsplitSpecification,
+    MusicalBandsplitSpecification,
+    OtherBandsplitSpecification,
+    TriangularBarkBandsplitSpecification,
+    VocalBandsplitSpecification,
+)
+from .core import (
+    MultiSourceMultiMaskBandSplitCoreConv,
+    MultiSourceMultiMaskBandSplitCoreRNN,
+    MultiSourceMultiMaskBandSplitCoreTransformer,
+    MultiSourceMultiPatchingMaskBandSplitCoreRNN,
+    SingleMaskBandsplitCoreRNN,
+    SingleMaskBandsplitCoreTransformer,
+)
+
+import pytorch_lightning as pl
+
+
+def get_band_specs(band_specs, n_fft, fs, n_bands=None):
+    if band_specs in ["dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"]:
+        bsm = VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs()
+        freq_weights = None
+        overlapping_band = False
+    elif "tribark" in band_specs:
+        assert n_bands is not None
+        specs = TriangularBarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
+        bsm = specs.get_band_specs()
+        freq_weights = specs.get_freq_weights()
+        overlapping_band = True
+    elif "bark" in band_specs:
+        assert n_bands is not None
+        specs = BarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
+        bsm = specs.get_band_specs()
+        freq_weights = specs.get_freq_weights()
+        overlapping_band = True
+    elif "erb" in band_specs:
+        assert n_bands is not None
+        specs = EquivalentRectangularBandsplitSpecification(
+            nfft=n_fft, fs=fs, n_bands=n_bands
+        )
+        bsm = specs.get_band_specs()
+        freq_weights = specs.get_freq_weights()
+        overlapping_band = True
+    elif "musical" in band_specs:
+        assert n_bands is not None
+        specs = MusicalBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
+        bsm = specs.get_band_specs()
+        freq_weights = specs.get_freq_weights()
+        overlapping_band = True
+    elif band_specs == "dnr:mel" or "mel" in band_specs:
+        assert n_bands is not None
+        specs = MelBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
+        bsm = specs.get_band_specs()
+        freq_weights = specs.get_freq_weights()
+        overlapping_band = True
+    else:
+        raise NameError
+
+    return bsm, freq_weights, overlapping_band
+
+
+def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None):
+    if band_specs_map == "musdb:all":
+        bsm = {
+            "vocals": VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
+            "drums": DrumBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
+            "bass": BassBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
+            "other": OtherBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
+        }
+        freq_weights = None
+        overlapping_band = False
+    elif band_specs_map == "dnr:vox7":
+        bsm_, freq_weights, overlapping_band = get_band_specs(
+            "dnr:speech", n_fft, fs, n_bands
+        )
+        bsm = {"speech": bsm_, "music": bsm_, "effects": bsm_}
+    elif "dnr:vox7:" in band_specs_map:
+        stem = band_specs_map.split(":")[-1]
+        bsm_, freq_weights, overlapping_band = get_band_specs(
+            "dnr:speech", n_fft, fs, n_bands
+        )
+        bsm = {stem: bsm_}
+    else:
+        raise NameError
+
+    return bsm, freq_weights, overlapping_band
+
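
A minimal usage sketch for the preset helpers above; the absolute import path is an assumption based on this repo's layout, and mel-family presets require n_bands:

    from mvsepless.models.bandit.core.model.bsrnn.wrapper import get_band_specs

    # 64 overlapping mel bands over a 2048-point FFT at 44.1 kHz
    band_specs, freq_weights, overlapping_band = get_band_specs(
        "mel", n_fft=2048, fs=44100, n_bands=64
    )
    print(len(band_specs), overlapping_band)  # expected: one (start, end) per band, True
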
+
+class BandSplitWrapperBase(pl.LightningModule):
+    bsrnn: nn.Module
+
+    def __init__(self, **kwargs):
+        super().__init__()
+
+
+class SingleMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
+    def __init__(
+        self,
+        band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
+        fs: int = 44100,
+        n_fft: int = 2048,
+        win_length: Optional[int] = 2048,
+        hop_length: int = 512,
+        window_fn: str = "hann_window",
+        wkwargs: Optional[Dict] = None,
+        power: Optional[int] = None,
+        center: bool = True,
+        normalized: bool = True,
+        pad_mode: str = "constant",
+        onesided: bool = True,
+        n_bands: int = None,
+    ) -> None:
+        super().__init__(
+            n_fft=n_fft,
+            win_length=win_length,
+            hop_length=hop_length,
+            window_fn=window_fn,
+            wkwargs=wkwargs,
+            power=power,
+            center=center,
+            normalized=normalized,
+            pad_mode=pad_mode,
+            onesided=onesided,
+        )
+
+        if isinstance(band_specs_map, str):
+            self.band_specs_map, self.freq_weights, self.overlapping_band = (
+                get_band_specs_map(band_specs_map, n_fft, fs, n_bands=n_bands)
+            )
+
+        self.stems = list(self.band_specs_map.keys())
+
+    def forward(self, batch):
+        audio = batch["audio"]
+
+        with torch.no_grad():
+            batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}
+
+        X = batch["spectrogram"]["mixture"]
+        length = batch["audio"]["mixture"].shape[-1]
+
+        output = {"spectrogram": {}, "audio": {}}
+
+        for stem, bsrnn in self.bsrnn.items():
+            S = bsrnn(X)
+            s = self.istft(S, length)
+            output["spectrogram"][stem] = S
+            output["audio"][stem] = s
+
+        return batch, output
+
+
+class MultiMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
+    def __init__(
+        self,
+        stems: List[str],
+        band_specs: Union[str, List[Tuple[float, float]]],
+        fs: int = 44100,
+        n_fft: int = 2048,
+        win_length: Optional[int] = 2048,
+        hop_length: int = 512,
+        window_fn: str = "hann_window",
+        wkwargs: Optional[Dict] = None,
+        power: Optional[int] = None,
+        center: bool = True,
+        normalized: bool = True,
+        pad_mode: str = "constant",
+        onesided: bool = True,
+        n_bands: int = None,
+    ) -> None:
+        super().__init__(
+            n_fft=n_fft,
+            win_length=win_length,
+            hop_length=hop_length,
+            window_fn=window_fn,
+            wkwargs=wkwargs,
+            power=power,
+            center=center,
+            normalized=normalized,
+            pad_mode=pad_mode,
+            onesided=onesided,
+        )
+
+        if isinstance(band_specs, str):
+            self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
+                band_specs, n_fft, fs, n_bands
+            )
+
+        self.stems = stems
+
+    def forward(self, batch):
+        audio = batch["audio"]
+        cond = batch.get("condition", None)
+        with torch.no_grad():
+            batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}
+
+        X = batch["spectrogram"]["mixture"]
+        length = batch["audio"]["mixture"].shape[-1]
+
+        output = self.bsrnn(X, cond=cond)
+        output["audio"] = {}
+
+        for stem, S in output["spectrogram"].items():
+            s = self.istft(S, length)
+            output["audio"][stem] = s
+
+        return batch, output
+
+
+class MultiMaskMultiSourceBandSplitBaseSimple(BandSplitWrapperBase, _SpectralComponent):
+    def __init__(
+        self,
+        stems: List[str],
+        band_specs: Union[str, List[Tuple[float, float]]],
+        fs: int = 44100,
+        n_fft: int = 2048,
+        win_length: Optional[int] = 2048,
+        hop_length: int = 512,
+        window_fn: str = "hann_window",
+        wkwargs: Optional[Dict] = None,
+        power: Optional[int] = None,
+        center: bool = True,
+        normalized: bool = True,
+        pad_mode: str = "constant",
+        onesided: bool = True,
+        n_bands: int = None,
+    ) -> None:
+        super().__init__(
+            n_fft=n_fft,
+            win_length=win_length,
+            hop_length=hop_length,
+            window_fn=window_fn,
+            wkwargs=wkwargs,
+            power=power,
+            center=center,
+            normalized=normalized,
+            pad_mode=pad_mode,
+            onesided=onesided,
+        )
+
+        if isinstance(band_specs, str):
+            self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
+                band_specs, n_fft, fs, n_bands
+            )
+
+        self.stems = stems
+
+    def forward(self, batch):
+        with torch.no_grad():
+            X = self.stft(batch)
+            length = batch.shape[-1]
+        output = self.bsrnn(X, cond=None)
+        res = []
+        for stem, S in output["spectrogram"].items():
+            s = self.istft(S, length)
+            res.append(s)
+        res = torch.stack(res, dim=1)
+        return res
+
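
For orientation, the Simple forward stacks one (batch, channel, time) tensor per stem along dim=1. A tiny runnable illustration of that stacking, using toy tensors rather than a real model:

    import torch

    # two mono "stems" the way the Simple forward collects them
    res = [torch.zeros(1, 2, 4), torch.ones(1, 2, 4)]  # one (batch, channel, time) per stem
    stacked = torch.stack(res, dim=1)
    print(stacked.shape)  # torch.Size([1, 2, 2, 4]) -> (batch, n_stems, channel, time)
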
+
+class SingleMaskMultiSourceBandSplitRNN(SingleMaskMultiSourceBandSplitBase):
+    def __init__(
+        self,
+        in_channel: int,
+        band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
+        fs: int = 44100,
+        require_no_overlap: bool = False,
+        require_no_gap: bool = True,
+        normalize_channel_independently: bool = False,
+        treat_channel_as_feature: bool = True,
+        n_sqm_modules: int = 12,
+        emb_dim: int = 128,
+        rnn_dim: int = 256,
+        bidirectional: bool = True,
+        rnn_type: str = "LSTM",
+        mlp_dim: int = 512,
+        hidden_activation: str = "Tanh",
+        hidden_activation_kwargs: Optional[Dict] = None,
+        complex_mask: bool = True,
+        n_fft: int = 2048,
+        win_length: Optional[int] = 2048,
+        hop_length: int = 512,
+        window_fn: str = "hann_window",
+        wkwargs: Optional[Dict] = None,
+        power: Optional[int] = None,
+        center: bool = True,
+        normalized: bool = True,
+        pad_mode: str = "constant",
+        onesided: bool = True,
+    ) -> None:
+        super().__init__(
+            band_specs_map=band_specs_map,
+            fs=fs,
+            n_fft=n_fft,
+            win_length=win_length,
+            hop_length=hop_length,
+            window_fn=window_fn,
+            wkwargs=wkwargs,
+            power=power,
+            center=center,
+            normalized=normalized,
+            pad_mode=pad_mode,
+            onesided=onesided,
+        )
+
+        self.bsrnn = nn.ModuleDict(
+            {
+                src: SingleMaskBandsplitCoreRNN(
+                    band_specs=specs,
+                    in_channel=in_channel,
+                    require_no_overlap=require_no_overlap,
+                    require_no_gap=require_no_gap,
+                    normalize_channel_independently=normalize_channel_independently,
+                    treat_channel_as_feature=treat_channel_as_feature,
+                    n_sqm_modules=n_sqm_modules,
+                    emb_dim=emb_dim,
+                    rnn_dim=rnn_dim,
+                    bidirectional=bidirectional,
+                    rnn_type=rnn_type,
+                    mlp_dim=mlp_dim,
+                    hidden_activation=hidden_activation,
+                    hidden_activation_kwargs=hidden_activation_kwargs,
+                    complex_mask=complex_mask,
+                )
+                for src, specs in self.band_specs_map.items()
+            }
+        )
+
+
+class SingleMaskMultiSourceBandSplitTransformer(SingleMaskMultiSourceBandSplitBase):
+    def __init__(
+        self,
+        in_channel: int,
+        band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
+        fs: int = 44100,
+        require_no_overlap: bool = False,
+        require_no_gap: bool = True,
+        normalize_channel_independently: bool = False,
+        treat_channel_as_feature: bool = True,
+        n_sqm_modules: int = 12,
+        emb_dim: int = 128,
+        rnn_dim: int = 256,
+        bidirectional: bool = True,
+        tf_dropout: float = 0.0,
+        mlp_dim: int = 512,
+        hidden_activation: str = "Tanh",
+        hidden_activation_kwargs: Optional[Dict] = None,
+        complex_mask: bool = True,
+        n_fft: int = 2048,
+        win_length: Optional[int] = 2048,
+        hop_length: int = 512,
+        window_fn: str = "hann_window",
+        wkwargs: Optional[Dict] = None,
+        power: Optional[int] = None,
+        center: bool = True,
+        normalized: bool = True,
+        pad_mode: str = "constant",
+        onesided: bool = True,
+    ) -> None:
+        super().__init__(
+            band_specs_map=band_specs_map,
+            fs=fs,
+            n_fft=n_fft,
+            win_length=win_length,
+            hop_length=hop_length,
+            window_fn=window_fn,
+            wkwargs=wkwargs,
+            power=power,
+            center=center,
+            normalized=normalized,
+            pad_mode=pad_mode,
+            onesided=onesided,
+        )
+
+        self.bsrnn = nn.ModuleDict(
+            {
+                src: SingleMaskBandsplitCoreTransformer(
+                    band_specs=specs,
+                    in_channel=in_channel,
+                    require_no_overlap=require_no_overlap,
+                    require_no_gap=require_no_gap,
+                    normalize_channel_independently=normalize_channel_independently,
+                    treat_channel_as_feature=treat_channel_as_feature,
+                    n_sqm_modules=n_sqm_modules,
+                    emb_dim=emb_dim,
+                    rnn_dim=rnn_dim,
+                    bidirectional=bidirectional,
+                    tf_dropout=tf_dropout,
+                    mlp_dim=mlp_dim,
+                    hidden_activation=hidden_activation,
+                    hidden_activation_kwargs=hidden_activation_kwargs,
+                    complex_mask=complex_mask,
+                )
+                for src, specs in self.band_specs_map.items()
+            }
+        )
+
+
+class MultiMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
+    def __init__(
+        self,
+        in_channel: int,
+        stems: List[str],
+        band_specs: Union[str, List[Tuple[float, float]]],
+        fs: int = 44100,
+        require_no_overlap: bool = False,
+        require_no_gap: bool = True,
+        normalize_channel_independently: bool = False,
+        treat_channel_as_feature: bool = True,
+        n_sqm_modules: int = 12,
+        emb_dim: int = 128,
+        rnn_dim: int = 256,
+        cond_dim: int = 0,
+        bidirectional: bool = True,
+        rnn_type: str = "LSTM",
+        mlp_dim: int = 512,
+        hidden_activation: str = "Tanh",
+        hidden_activation_kwargs: Optional[Dict] = None,
+        complex_mask: bool = True,
+        n_fft: int = 2048,
+        win_length: Optional[int] = 2048,
+        hop_length: int = 512,
+        window_fn: str = "hann_window",
+        wkwargs: Optional[Dict] = None,
+        power: Optional[int] = None,
+        center: bool = True,
+        normalized: bool = True,
+        pad_mode: str = "constant",
+        onesided: bool = True,
+        n_bands: int = None,
+        use_freq_weights: bool = True,
+        normalize_input: bool = False,
+        mult_add_mask: bool = False,
+        freeze_encoder: bool = False,
+    ) -> None:
+        super().__init__(
+            stems=stems,
+            band_specs=band_specs,
+            fs=fs,
+            n_fft=n_fft,
+            win_length=win_length,
+            hop_length=hop_length,
+            window_fn=window_fn,
+            wkwargs=wkwargs,
+            power=power,
+            center=center,
+            normalized=normalized,
+            pad_mode=pad_mode,
+            onesided=onesided,
+            n_bands=n_bands,
+        )
+
+        self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
+            stems=stems,
+            band_specs=self.band_specs,
+            in_channel=in_channel,
+            require_no_overlap=require_no_overlap,
+            require_no_gap=require_no_gap,
+            normalize_channel_independently=normalize_channel_independently,
+            treat_channel_as_feature=treat_channel_as_feature,
+            n_sqm_modules=n_sqm_modules,
+            emb_dim=emb_dim,
+            rnn_dim=rnn_dim,
+            bidirectional=bidirectional,
+            rnn_type=rnn_type,
+            mlp_dim=mlp_dim,
+            cond_dim=cond_dim,
+            hidden_activation=hidden_activation,
+            hidden_activation_kwargs=hidden_activation_kwargs,
+            complex_mask=complex_mask,
+            overlapping_band=self.overlapping_band,
+            freq_weights=self.freq_weights,
+            n_freq=n_fft // 2 + 1,
+            use_freq_weights=use_freq_weights,
+            mult_add_mask=mult_add_mask,
+        )
+
+        self.normalize_input = normalize_input
+        self.cond_dim = cond_dim
+
+        if freeze_encoder:
+            for param in self.bsrnn.band_split.parameters():
+                param.requires_grad = False
+
+            for param in self.bsrnn.tf_model.parameters():
+                param.requires_grad = False
+
+
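
freeze_encoder=True stops gradients through bsrnn.band_split and bsrnn.tf_model only, so the mask estimator keeps training. A hedged fine-tuning sketch under that reading; the import path and the "dnr:vox7" preset are assumptions based on this repo's layout and the strings accepted by get_band_specs above:

    import torch
    from mvsepless.models.bandit.core.model import MultiMaskMultiSourceBandSplitRNN

    model = MultiMaskMultiSourceBandSplitRNN(
        in_channel=2,
        stems=["speech", "music", "effects"],
        band_specs="dnr:vox7",
        freeze_encoder=True,
    )
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=1e-4)  # only mask-estimation weights update
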
+class MultiMaskMultiSourceBandSplitRNNSimple(MultiMaskMultiSourceBandSplitBaseSimple):
+    def __init__(
+        self,
+        in_channel: int,
+        stems: List[str],
+        band_specs: Union[str, List[Tuple[float, float]]],
+        fs: int = 44100,
+        require_no_overlap: bool = False,
+        require_no_gap: bool = True,
+        normalize_channel_independently: bool = False,
+        treat_channel_as_feature: bool = True,
+        n_sqm_modules: int = 12,
+        emb_dim: int = 128,
+        rnn_dim: int = 256,
+        cond_dim: int = 0,
+        bidirectional: bool = True,
+        rnn_type: str = "LSTM",
+        mlp_dim: int = 512,
+        hidden_activation: str = "Tanh",
+        hidden_activation_kwargs: Optional[Dict] = None,
+        complex_mask: bool = True,
+        n_fft: int = 2048,
+        win_length: Optional[int] = 2048,
+        hop_length: int = 512,
+        window_fn: str = "hann_window",
+        wkwargs: Optional[Dict] = None,
+        power: Optional[int] = None,
+        center: bool = True,
+        normalized: bool = True,
+        pad_mode: str = "constant",
+        onesided: bool = True,
+        n_bands: int = None,
+        use_freq_weights: bool = True,
+        normalize_input: bool = False,
+        mult_add_mask: bool = False,
+        freeze_encoder: bool = False,
+    ) -> None:
+        super().__init__(
+            stems=stems,
+            band_specs=band_specs,
+            fs=fs,
+            n_fft=n_fft,
+            win_length=win_length,
+            hop_length=hop_length,
+            window_fn=window_fn,
+            wkwargs=wkwargs,
+            power=power,
+            center=center,
+            normalized=normalized,
+            pad_mode=pad_mode,
+            onesided=onesided,
+            n_bands=n_bands,
+        )
+
+        self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
+            stems=stems,
+            band_specs=self.band_specs,
+            in_channel=in_channel,
+            require_no_overlap=require_no_overlap,
+            require_no_gap=require_no_gap,
+            normalize_channel_independently=normalize_channel_independently,
+            treat_channel_as_feature=treat_channel_as_feature,
+            n_sqm_modules=n_sqm_modules,
+            emb_dim=emb_dim,
+            rnn_dim=rnn_dim,
+            bidirectional=bidirectional,
+            rnn_type=rnn_type,
+            mlp_dim=mlp_dim,
+            cond_dim=cond_dim,
+            hidden_activation=hidden_activation,
+            hidden_activation_kwargs=hidden_activation_kwargs,
+            complex_mask=complex_mask,
+            overlapping_band=self.overlapping_band,
+            freq_weights=self.freq_weights,
+            n_freq=n_fft // 2 + 1,
+            use_freq_weights=use_freq_weights,
+            mult_add_mask=mult_add_mask,
+        )
+
+        self.normalize_input = normalize_input
+        self.cond_dim = cond_dim
+
+        if freeze_encoder:
+            for param in self.bsrnn.band_split.parameters():
+                param.requires_grad = False
+
+            for param in self.bsrnn.tf_model.parameters():
+                param.requires_grad = False
+
+
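
A hedged construction sketch for the Simple variant (waveform tensor in, stacked stems out). The import path mirrors the relative import used in model_from_config.py further below; the stem names and the "dnr:vox7" preset are illustrative DnR values, and all other hyperparameters fall back to the defaults above:

    import torch
    from mvsepless.models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple

    model = MultiMaskMultiSourceBandSplitRNNSimple(
        in_channel=2,
        stems=["speech", "music", "effects"],
        band_specs="dnr:vox7",
        fs=44100,
    )
    with torch.no_grad():
        out = model(torch.randn(1, 2, 44100 * 2))  # -> (1, 3, 2, 88200)
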
+class MultiMaskMultiSourceBandSplitTransformer(MultiMaskMultiSourceBandSplitBase):
+    def __init__(
+        self,
+        in_channel: int,
+        stems: List[str],
+        band_specs: Union[str, List[Tuple[float, float]]],
+        fs: int = 44100,
+        require_no_overlap: bool = False,
+        require_no_gap: bool = True,
+        normalize_channel_independently: bool = False,
+        treat_channel_as_feature: bool = True,
+        n_sqm_modules: int = 12,
+        emb_dim: int = 128,
+        rnn_dim: int = 256,
+        cond_dim: int = 0,
+        bidirectional: bool = True,
+        rnn_type: str = "LSTM",
+        mlp_dim: int = 512,
+        hidden_activation: str = "Tanh",
+        hidden_activation_kwargs: Optional[Dict] = None,
+        complex_mask: bool = True,
+        n_fft: int = 2048,
+        win_length: Optional[int] = 2048,
+        hop_length: int = 512,
+        window_fn: str = "hann_window",
+        wkwargs: Optional[Dict] = None,
+        power: Optional[int] = None,
+        center: bool = True,
+        normalized: bool = True,
+        pad_mode: str = "constant",
+        onesided: bool = True,
+        n_bands: int = None,
+        use_freq_weights: bool = True,
+        normalize_input: bool = False,
+        mult_add_mask: bool = False,
+    ) -> None:
+        super().__init__(
+            stems=stems,
+            band_specs=band_specs,
+            fs=fs,
+            n_fft=n_fft,
+            win_length=win_length,
+            hop_length=hop_length,
+            window_fn=window_fn,
+            wkwargs=wkwargs,
+            power=power,
+            center=center,
+            normalized=normalized,
+            pad_mode=pad_mode,
+            onesided=onesided,
+            n_bands=n_bands,
+        )
+
+        self.bsrnn = MultiSourceMultiMaskBandSplitCoreTransformer(
+            stems=stems,
+            band_specs=self.band_specs,
+            in_channel=in_channel,
+            require_no_overlap=require_no_overlap,
+            require_no_gap=require_no_gap,
+            normalize_channel_independently=normalize_channel_independently,
+            treat_channel_as_feature=treat_channel_as_feature,
+            n_sqm_modules=n_sqm_modules,
+            emb_dim=emb_dim,
+            rnn_dim=rnn_dim,
+            bidirectional=bidirectional,
+            rnn_type=rnn_type,
+            mlp_dim=mlp_dim,
+            cond_dim=cond_dim,
+            hidden_activation=hidden_activation,
+            hidden_activation_kwargs=hidden_activation_kwargs,
+            complex_mask=complex_mask,
+            overlapping_band=self.overlapping_band,
+            freq_weights=self.freq_weights,
+            n_freq=n_fft // 2 + 1,
+            use_freq_weights=use_freq_weights,
+            mult_add_mask=mult_add_mask,
+        )
+
+
+class MultiMaskMultiSourceBandSplitConv(MultiMaskMultiSourceBandSplitBase):
+    def __init__(
+        self,
+        in_channel: int,
+        stems: List[str],
+        band_specs: Union[str, List[Tuple[float, float]]],
+        fs: int = 44100,
+        require_no_overlap: bool = False,
+        require_no_gap: bool = True,
+        normalize_channel_independently: bool = False,
+        treat_channel_as_feature: bool = True,
+        n_sqm_modules: int = 12,
+        emb_dim: int = 128,
+        rnn_dim: int = 256,
+        cond_dim: int = 0,
+        bidirectional: bool = True,
+        rnn_type: str = "LSTM",
+        mlp_dim: int = 512,
+        hidden_activation: str = "Tanh",
+        hidden_activation_kwargs: Optional[Dict] = None,
+        complex_mask: bool = True,
+        n_fft: int = 2048,
+        win_length: Optional[int] = 2048,
+        hop_length: int = 512,
+        window_fn: str = "hann_window",
+        wkwargs: Optional[Dict] = None,
+        power: Optional[int] = None,
+        center: bool = True,
+        normalized: bool = True,
+        pad_mode: str = "constant",
+        onesided: bool = True,
+        n_bands: int = None,
+        use_freq_weights: bool = True,
+        normalize_input: bool = False,
+        mult_add_mask: bool = False,
+    ) -> None:
+        super().__init__(
+            stems=stems,
+            band_specs=band_specs,
+            fs=fs,
+            n_fft=n_fft,
+            win_length=win_length,
+            hop_length=hop_length,
+            window_fn=window_fn,
+            wkwargs=wkwargs,
+            power=power,
+            center=center,
+            normalized=normalized,
+            pad_mode=pad_mode,
+            onesided=onesided,
+            n_bands=n_bands,
+        )
+
+        self.bsrnn = MultiSourceMultiMaskBandSplitCoreConv(
+            stems=stems,
+            band_specs=self.band_specs,
+            in_channel=in_channel,
+            require_no_overlap=require_no_overlap,
+            require_no_gap=require_no_gap,
+            normalize_channel_independently=normalize_channel_independently,
+            treat_channel_as_feature=treat_channel_as_feature,
+            n_sqm_modules=n_sqm_modules,
+            emb_dim=emb_dim,
+            rnn_dim=rnn_dim,
+            bidirectional=bidirectional,
+            rnn_type=rnn_type,
+            mlp_dim=mlp_dim,
+            cond_dim=cond_dim,
+            hidden_activation=hidden_activation,
+            hidden_activation_kwargs=hidden_activation_kwargs,
+            complex_mask=complex_mask,
+            overlapping_band=self.overlapping_band,
+            freq_weights=self.freq_weights,
+            n_freq=n_fft // 2 + 1,
+            use_freq_weights=use_freq_weights,
+            mult_add_mask=mult_add_mask,
+        )
+
+
+class PatchingMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
+    def __init__(
+        self,
+        in_channel: int,
+        stems: List[str],
+        band_specs: Union[str, List[Tuple[float, float]]],
+        kernel_norm_mlp_version: int = 1,
+        mask_kernel_freq: int = 3,
+        mask_kernel_time: int = 3,
+        conv_kernel_freq: int = 1,
+        conv_kernel_time: int = 1,
+        fs: int = 44100,
+        require_no_overlap: bool = False,
+        require_no_gap: bool = True,
+        normalize_channel_independently: bool = False,
+        treat_channel_as_feature: bool = True,
+        n_sqm_modules: int = 12,
+        emb_dim: int = 128,
+        rnn_dim: int = 256,
+        bidirectional: bool = True,
+        rnn_type: str = "LSTM",
+        mlp_dim: int = 512,
+        hidden_activation: str = "Tanh",
+        hidden_activation_kwargs: Optional[Dict] = None,
+        complex_mask: bool = True,
+        n_fft: int = 2048,
+        win_length: Optional[int] = 2048,
+        hop_length: int = 512,
+        window_fn: str = "hann_window",
+        wkwargs: Optional[Dict] = None,
+        power: Optional[int] = None,
+        center: bool = True,
+        normalized: bool = True,
+        pad_mode: str = "constant",
+        onesided: bool = True,
+        n_bands: int = None,
+    ) -> None:
+        super().__init__(
+            stems=stems,
+            band_specs=band_specs,
+            fs=fs,
+            n_fft=n_fft,
+            win_length=win_length,
+            hop_length=hop_length,
+            window_fn=window_fn,
+            wkwargs=wkwargs,
+            power=power,
+            center=center,
+            normalized=normalized,
+            pad_mode=pad_mode,
+            onesided=onesided,
+            n_bands=n_bands,
+        )
+
+        self.bsrnn = MultiSourceMultiPatchingMaskBandSplitCoreRNN(
+            stems=stems,
+            band_specs=self.band_specs,
+            in_channel=in_channel,
+            require_no_overlap=require_no_overlap,
+            require_no_gap=require_no_gap,
+            normalize_channel_independently=normalize_channel_independently,
+            treat_channel_as_feature=treat_channel_as_feature,
+            n_sqm_modules=n_sqm_modules,
+            emb_dim=emb_dim,
+            rnn_dim=rnn_dim,
+            bidirectional=bidirectional,
+            rnn_type=rnn_type,
+            mlp_dim=mlp_dim,
+            hidden_activation=hidden_activation,
+            hidden_activation_kwargs=hidden_activation_kwargs,
+            complex_mask=complex_mask,
+            overlapping_band=self.overlapping_band,
+            freq_weights=self.freq_weights,
+            n_freq=n_fft // 2 + 1,
+            mask_kernel_freq=mask_kernel_freq,
+            mask_kernel_time=mask_kernel_time,
+            conv_kernel_freq=conv_kernel_freq,
+            conv_kernel_time=conv_kernel_time,
+            kernel_norm_mlp_version=kernel_norm_mlp_version,
+        )

mvsepless/models/bandit/core/utils/audio.py CHANGED
@@ -1,412 +1,324 @@
 from collections import defaultdict
 
 from tqdm.auto import tqdm
 from typing import Callable, Dict, List, Optional, Tuple
 
 import numpy as np
 import torch
 from torch import nn
 from torch.nn import functional as F
 
 
 @torch.jit.script
 def merge(
     combined: torch.Tensor,
     original_batch_size: int,
     n_channel: int,
     n_chunks: int,
     chunk_size: int,
 ):
     combined = torch.reshape(
         combined, (original_batch_size, n_chunks, n_channel, chunk_size)
     )
     combined = torch.permute(combined, (0, 2, 3, 1)).reshape(
         original_batch_size * n_channel, chunk_size, n_chunks
     )
 
     return combined
 
 
 @torch.jit.script
 def unfold(
     padded_audio: torch.Tensor,
     original_batch_size: int,
     n_channel: int,
     chunk_size: int,
     hop_size: int,
 ) -> torch.Tensor:
 
     unfolded_input = F.unfold(
         padded_audio[:, :, None, :], kernel_size=(1, chunk_size), stride=(1, hop_size)
     )
 
     _, _, n_chunks = unfolded_input.shape
     unfolded_input = unfolded_input.view(
         original_batch_size, n_channel, chunk_size, n_chunks
     )
     unfolded_input = torch.permute(unfolded_input, (0, 3, 1, 2)).reshape(
         original_batch_size * n_chunks, n_channel, chunk_size
     )
 
     return unfolded_input
 
 
 @torch.jit.script
-# @torch.compile
 def merge_chunks_all(
     combined: torch.Tensor,
     original_batch_size: int,
     n_channel: int,
     n_samples: int,
     n_padded_samples: int,
     n_chunks: int,
     chunk_size: int,
     hop_size: int,
     edge_frame_pad_sizes: Tuple[int, int],
     standard_window: torch.Tensor,
     first_window: torch.Tensor,
     last_window: torch.Tensor,
 ):
     combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size)
 
     combined = combined * standard_window[:, None].to(combined.device)
 
     combined = F.fold(
         combined.to(torch.float32),
         output_size=(1, n_padded_samples),
         kernel_size=(1, chunk_size),
         stride=(1, hop_size),
     )
 
     combined = combined.view(original_batch_size, n_channel, n_padded_samples)
 
     pad_front, pad_back = edge_frame_pad_sizes
     combined = combined[..., pad_front:-pad_back]
 
     combined = combined[..., :n_samples]
 
     return combined
 
-# @torch.jit.script
-
 
 def merge_chunks_edge(
     combined: torch.Tensor,
     original_batch_size: int,
     n_channel: int,
     n_samples: int,
     n_padded_samples: int,
     n_chunks: int,
     chunk_size: int,
     hop_size: int,
     edge_frame_pad_sizes: Tuple[int, int],
     standard_window: torch.Tensor,
     first_window: torch.Tensor,
     last_window: torch.Tensor,
 ):
     combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size)
 
     combined[..., 0] = combined[..., 0] * first_window
     combined[..., -1] = combined[..., -1] * last_window
     combined[..., 1:-1] = combined[..., 1:-1] * standard_window[:, None]
 
     combined = F.fold(
         combined,
         output_size=(1, n_padded_samples),
         kernel_size=(1, chunk_size),
         stride=(1, hop_size),
     )
 
     combined = combined.view(original_batch_size, n_channel, n_padded_samples)
 
     combined = combined[..., :n_samples]
 
     return combined
 
 
 class BaseFader(nn.Module):
     def __init__(
         self,
         chunk_size_second: float,
         hop_size_second: float,
         fs: int,
         fade_edge_frames: bool,
         batch_size: int,
     ) -> None:
         super().__init__()
 
         self.chunk_size = int(chunk_size_second * fs)
         self.hop_size = int(hop_size_second * fs)
         self.overlap_size = self.chunk_size - self.hop_size
         self.fade_edge_frames = fade_edge_frames
         self.batch_size = batch_size
 
-    # @torch.jit.script
     def prepare(self, audio):
 
         if self.fade_edge_frames:
             audio = F.pad(audio, self.edge_frame_pad_sizes, mode="reflect")
 
         n_samples = audio.shape[-1]
         n_chunks = int(np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1)
 
         padded_size = (n_chunks - 1) * self.hop_size + self.chunk_size
         pad_size = padded_size - n_samples
 
         padded_audio = F.pad(audio, (0, pad_size))
 
         return padded_audio, n_chunks
 
     def forward(
         self,
         audio: torch.Tensor,
         model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
     ):
 
         original_dtype = audio.dtype
         original_device = audio.device
 
         audio = audio.to("cpu")
 
         original_batch_size, n_channel, n_samples = audio.shape
         padded_audio, n_chunks = self.prepare(audio)
         del audio
         n_padded_samples = padded_audio.shape[-1]
 
         if n_channel > 1:
             padded_audio = padded_audio.view(
                 original_batch_size * n_channel, 1, n_padded_samples
             )
 
         unfolded_input = unfold(
             padded_audio, original_batch_size, n_channel, self.chunk_size, self.hop_size
         )
 
         n_total_chunks, n_channel, chunk_size = unfolded_input.shape
 
         n_batch = np.ceil(n_total_chunks / self.batch_size).astype(int)
 
         chunks_in = [
             unfolded_input[b * self.batch_size : (b + 1) * self.batch_size, ...].clone()
             for b in range(n_batch)
         ]
 
         all_chunks_out = defaultdict(
             lambda: torch.zeros_like(unfolded_input, device="cpu")
         )
 
-        # for b, cin in enumerate(tqdm(chunks_in)):
         for b, cin in enumerate(chunks_in):
             if torch.allclose(cin, torch.tensor(0.0)):
                 del cin
                 continue
 
             chunks_out = model_fn(cin.to(original_device))
             del cin
             for s, c in chunks_out.items():
                 all_chunks_out[s][
                     b * self.batch_size : (b + 1) * self.batch_size, ...
                 ] = c.cpu()
             del chunks_out
 
         del unfolded_input
         del padded_audio
 
         if self.fade_edge_frames:
             fn = merge_chunks_all
         else:
             fn = merge_chunks_edge
         outputs = {}
 
         torch.cuda.empty_cache()
 
         for s, c in all_chunks_out.items():
             combined: torch.Tensor = fn(
                 c,
                 original_batch_size,
                 n_channel,
                 n_samples,
                 n_padded_samples,
                 n_chunks,
                 self.chunk_size,
                 self.hop_size,
                 self.edge_frame_pad_sizes,
                 self.standard_window,
                 self.__dict__.get("first_window", self.standard_window),
                 self.__dict__.get("last_window", self.standard_window),
             )
 
             outputs[s] = combined.to(dtype=original_dtype, device=original_device)
 
         return {"audio": outputs}
 
-    #
-    # def old_forward(
-    #         self,
-    #         audio: torch.Tensor,
-    #         model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
-    # ):
-    #
-    #     n_samples = audio.shape[-1]
-    #     original_batch_size = audio.shape[0]
-    #
-    #     padded_audio, n_chunks = self.prepare(audio)
-    #
-    #     ndim = padded_audio.ndim
-    #     broadcaster = [1 for _ in range(ndim - 1)] + [self.chunk_size]
-    #
-    #     outputs = defaultdict(
-    #             lambda: torch.zeros_like(
-    #                     padded_audio, device=audio.device, dtype=torch.float64
-    #             )
-    #     )
-    #
-    #     all_chunks_out = []
-    #     len_chunks_in = []
-    #
-    #     batch_size_ = int(self.batch_size // original_batch_size)
-    #     for b in range(int(np.ceil(n_chunks / batch_size_))):
-    #         chunks_in = []
-    #         for j in range(batch_size_):
-    #             i = b * batch_size_ + j
-    #             if i == n_chunks:
-    #                 break
-    #
-    #             start = i * hop_size
-    #             end = start + self.chunk_size
-    #             chunk_in = padded_audio[..., start:end]
-    #             chunks_in.append(chunk_in)
-    #
-    #         chunks_in = torch.concat(chunks_in, dim=0)
-    #         chunks_out = model_fn(chunks_in)
-    #         all_chunks_out.append(chunks_out)
-    #         len_chunks_in.append(len(chunks_in))
-    #
-    #     for b, (chunks_out, lci) in enumerate(
-    #         zip(all_chunks_out, len_chunks_in)
-    #     ):
-    #         for stem in chunks_out:
-    #             for j in range(lci // original_batch_size):
-    #                 i = b * batch_size_ + j
-    #
-    #                 if self.fade_edge_frames:
-    #                     window = self.standard_window
-    #                 else:
-    #                     if i == 0:
-    #                         window = self.first_window
-    #                     elif i == n_chunks - 1:
-    #                         window = self.last_window
-    #                     else:
-    #                         window = self.standard_window
-    #
-    #                 start = i * hop_size
-    #                 end = start + self.chunk_size
-    #
-    #                 chunk_out = chunks_out[stem][j * original_batch_size: (j + 1) * original_batch_size,
-    #                 ...]
-    #                 contrib = window.view(*broadcaster) * chunk_out
-    #                 outputs[stem][..., start:end] = (
-    #                     outputs[stem][..., start:end] + contrib
-    #                 )
-    #
-    #     if self.fade_edge_frames:
-    #         pad_front, pad_back = self.edge_frame_pad_sizes
-    #         outputs = {k: v[..., pad_front:-pad_back] for k, v in
-    #                    outputs.items()}
-    #
-    #     outputs = {k: v[..., :n_samples].to(audio.dtype) for k, v in
-    #                outputs.items()}
-    #
-    #     return {
-    #         "audio": outputs
-    #     }
-
 
 class LinearFader(BaseFader):
     def __init__(
         self,
         chunk_size_second: float,
         hop_size_second: float,
         fs: int,
         fade_edge_frames: bool = False,
         batch_size: int = 1,
     ) -> None:
 
         assert hop_size_second >= chunk_size_second / 2
 
         super().__init__(
             chunk_size_second=chunk_size_second,
             hop_size_second=hop_size_second,
             fs=fs,
             fade_edge_frames=fade_edge_frames,
             batch_size=batch_size,
         )
 
         in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1]
         out_fade = torch.linspace(1.0, 0.0, self.overlap_size + 1)[1:]
         center_ones = torch.ones(self.chunk_size - 2 * self.overlap_size)
         inout_ones = torch.ones(self.overlap_size)
 
-        # using nn.Parameters allows lightning to take care of devices for us
         self.register_buffer(
             "standard_window", torch.concat([in_fade, center_ones, out_fade])
         )
 
         self.fade_edge_frames = fade_edge_frames
         self.edge_frame_pad_sizes = (self.overlap_size, self.overlap_size)
 
         if not self.fade_edge_frames:
             self.first_window = nn.Parameter(
                 torch.concat([inout_ones, center_ones, out_fade]), requires_grad=False
             )
             self.last_window = nn.Parameter(
                 torch.concat([in_fade, center_ones, inout_ones]), requires_grad=False
             )
 
 
 class OverlapAddFader(BaseFader):
     def __init__(
         self,
         window_type: str,
         chunk_size_second: float,
         hop_size_second: float,
         fs: int,
         batch_size: int = 1,
     ) -> None:
         assert (chunk_size_second / hop_size_second) % 2 == 0
         assert int(chunk_size_second * fs) % 2 == 0
 
         super().__init__(
             chunk_size_second=chunk_size_second,
             hop_size_second=hop_size_second,
             fs=fs,
             fade_edge_frames=True,
             batch_size=batch_size,
         )
 
         self.hop_multiplier = self.chunk_size / (2 * self.hop_size)
-        # print(f"hop multiplier: {self.hop_multiplier}")
 
         self.edge_frame_pad_sizes = (2 * self.overlap_size, 2 * self.overlap_size)
 
         self.register_buffer(
             "standard_window",
             torch.windows.__dict__[window_type](
                 self.chunk_size,
-                sym=False,  # dtype=torch.float64
+                sym=False,
             )
             / self.hop_multiplier,
         )
 
 
 if __name__ == "__main__":
     import torchaudio as ta
 
     fs = 44100
     ola = OverlapAddFader("hann", 6.0, 1.0, fs, batch_size=16)
     audio_, _ = ta.load(
         "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too " "Much/vocals.wav"
     )
     audio_ = audio_[None, ...]
     out = ola(audio_, lambda x: {"stem": x})["audio"]["stem"]
     print(torch.allclose(out, audio_))
 
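Beyond the __main__ identity check above, a hedged sketch of chunked inference with OverlapAddFader: model_fn receives (batch, channel, chunk) tensors and must return a dict of like-shaped stem tensors. The 6 s chunk with 0.5 s hop satisfies the even chunk-to-hop ratio asserted in __init__; the toy model_fn and the import path are placeholders:

    import torch
    from mvsepless.models.bandit.core.utils.audio import OverlapAddFader

    fader = OverlapAddFader(
        "hann", chunk_size_second=6.0, hop_size_second=0.5, fs=44100, batch_size=4
    )

    def model_fn(chunk):                 # stand-in for a real separator
        return {"vocals": 0.5 * chunk}   # any dict of tensors shaped like `chunk`

    audio = torch.randn(1, 2, 44100 * 30)
    vocals = fader(audio, model_fn)["audio"]["vocals"]  # same shape as `audio`
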
mvsepless/models/bandit/model_from_config.py CHANGED
@@ -1,26 +1,26 @@
+import sys
+import os.path
+import torch
+
+import yaml
+from ml_collections import ConfigDict
+
+torch.set_float32_matmul_precision("medium")
+
+
+def get_model(
+    config_path,
+    weights_path,
+    device,
+):
+    from .core.model import MultiMaskMultiSourceBandSplitRNNSimple
+
+    f = open(config_path)
+    config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
+    f.close()
+
+    model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)
+    d = torch.load(weights_path, map_location=device)
+    model.load_state_dict(d)
+    model.to(device)
+    return model, config
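
A minimal calling sketch for get_model; the import path is assumed from the repo layout, and both file paths are placeholders (substitute a real Bandit config and a checkpoint such as model_bandit_plus_dnr_sdr_11.47.chpt):

    from mvsepless.models.bandit.model_from_config import get_model

    model, config = get_model(
        config_path="config.yaml",                            # placeholder
        weights_path="model_bandit_plus_dnr_sdr_11.47.chpt",  # placeholder
        device="cuda",
    )
    model.eval()
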
mvsepless/models/bandit_v2/bandit.py CHANGED
@@ -1,363 +1,360 @@
1
- from typing import Dict, List, Optional
2
-
3
- import torch
4
- import torchaudio as ta
5
- from torch import nn
6
- import pytorch_lightning as pl
7
-
8
- from .bandsplit import BandSplitModule
9
- from .maskestim import OverlappingMaskEstimationModule
10
- from .tfmodel import SeqBandModellingModule
11
- from .utils import MusicalBandsplitSpecification
12
-
13
-
14
- class BaseEndToEndModule(pl.LightningModule):
15
- def __init__(
16
- self,
17
- ) -> None:
18
- super().__init__()
19
-
20
-
21
- class BaseBandit(BaseEndToEndModule):
22
- def __init__(
23
- self,
24
- in_channels: int,
25
- fs: int,
26
- band_type: str = "musical",
27
- n_bands: int = 64,
28
- require_no_overlap: bool = False,
29
- require_no_gap: bool = True,
30
- normalize_channel_independently: bool = False,
31
- treat_channel_as_feature: bool = True,
32
- n_sqm_modules: int = 12,
33
- emb_dim: int = 128,
34
- rnn_dim: int = 256,
35
- bidirectional: bool = True,
36
- rnn_type: str = "LSTM",
37
- n_fft: int = 2048,
38
- win_length: Optional[int] = 2048,
39
- hop_length: int = 512,
40
- window_fn: str = "hann_window",
41
- wkwargs: Optional[Dict] = None,
42
- power: Optional[int] = None,
43
- center: bool = True,
44
- normalized: bool = True,
45
- pad_mode: str = "constant",
46
- onesided: bool = True,
47
- ):
48
- super().__init__()
49
-
50
- self.in_channels = in_channels
51
-
52
- self.instantitate_spectral(
53
- n_fft=n_fft,
54
- win_length=win_length,
55
- hop_length=hop_length,
56
- window_fn=window_fn,
57
- wkwargs=wkwargs,
58
- power=power,
59
- normalized=normalized,
60
- center=center,
61
- pad_mode=pad_mode,
62
- onesided=onesided,
63
- )
64
-
65
- self.instantiate_bandsplit(
66
- in_channels=in_channels,
67
- band_type=band_type,
68
- n_bands=n_bands,
69
- require_no_overlap=require_no_overlap,
70
- require_no_gap=require_no_gap,
71
- normalize_channel_independently=normalize_channel_independently,
72
- treat_channel_as_feature=treat_channel_as_feature,
73
- emb_dim=emb_dim,
74
- n_fft=n_fft,
75
- fs=fs,
76
- )
77
-
78
- self.instantiate_tf_modelling(
79
- n_sqm_modules=n_sqm_modules,
80
- emb_dim=emb_dim,
81
- rnn_dim=rnn_dim,
82
- bidirectional=bidirectional,
83
- rnn_type=rnn_type,
84
- )
85
-
86
- def instantitate_spectral(
87
- self,
88
- n_fft: int = 2048,
89
- win_length: Optional[int] = 2048,
90
- hop_length: int = 512,
91
- window_fn: str = "hann_window",
92
- wkwargs: Optional[Dict] = None,
93
- power: Optional[int] = None,
94
- normalized: bool = True,
95
- center: bool = True,
96
- pad_mode: str = "constant",
97
- onesided: bool = True,
98
- ):
99
- assert power is None
100
-
101
- window_fn = torch.__dict__[window_fn]
102
-
103
- self.stft = ta.transforms.Spectrogram(
104
- n_fft=n_fft,
105
- win_length=win_length,
106
- hop_length=hop_length,
107
- pad_mode=pad_mode,
108
- pad=0,
109
- window_fn=window_fn,
110
- wkwargs=wkwargs,
111
- power=power,
112
- normalized=normalized,
113
- center=center,
114
- onesided=onesided,
115
- )
116
-
117
- self.istft = ta.transforms.InverseSpectrogram(
118
- n_fft=n_fft,
119
- win_length=win_length,
120
- hop_length=hop_length,
121
- pad_mode=pad_mode,
122
- pad=0,
123
- window_fn=window_fn,
124
- wkwargs=wkwargs,
125
- normalized=normalized,
126
- center=center,
127
- onesided=onesided,
128
- )
129
-
130
- def instantiate_bandsplit(
131
- self,
132
- in_channels: int,
133
- band_type: str = "musical",
134
- n_bands: int = 64,
135
- require_no_overlap: bool = False,
136
- require_no_gap: bool = True,
137
- normalize_channel_independently: bool = False,
138
- treat_channel_as_feature: bool = True,
139
- emb_dim: int = 128,
140
- n_fft: int = 2048,
141
- fs: int = 44100,
142
- ):
143
- assert band_type == "musical"
144
-
145
- self.band_specs = MusicalBandsplitSpecification(
146
- nfft=n_fft, fs=fs, n_bands=n_bands
147
- )
148
-
149
- self.band_split = BandSplitModule(
150
- in_channels=in_channels,
151
- band_specs=self.band_specs.get_band_specs(),
152
- require_no_overlap=require_no_overlap,
153
- require_no_gap=require_no_gap,
154
- normalize_channel_independently=normalize_channel_independently,
155
- treat_channel_as_feature=treat_channel_as_feature,
156
- emb_dim=emb_dim,
157
- )
158
-
159
- def instantiate_tf_modelling(
160
- self,
161
- n_sqm_modules: int = 12,
162
- emb_dim: int = 128,
163
- rnn_dim: int = 256,
164
- bidirectional: bool = True,
165
- rnn_type: str = "LSTM",
166
- ):
167
- try:
168
- self.tf_model = torch.compile(
169
- SeqBandModellingModule(
170
- n_modules=n_sqm_modules,
171
- emb_dim=emb_dim,
172
- rnn_dim=rnn_dim,
173
- bidirectional=bidirectional,
174
- rnn_type=rnn_type,
175
- ),
176
- disable=True,
177
- )
178
- except Exception as e:
179
- self.tf_model = SeqBandModellingModule(
180
- n_modules=n_sqm_modules,
181
- emb_dim=emb_dim,
182
- rnn_dim=rnn_dim,
183
- bidirectional=bidirectional,
184
- rnn_type=rnn_type,
185
- )
186
-
187
- def mask(self, x, m):
188
- return x * m
189
-
190
- def forward(self, batch, mode="train"):
191
- # Model takes mono as input we give stereo, so we do process of each channel independently
192
- init_shape = batch.shape
193
- if not isinstance(batch, dict):
194
- mono = batch.view(-1, 1, batch.shape[-1])
195
- batch = {"mixture": {"audio": mono}}
196
-
197
- with torch.no_grad():
198
- mixture = batch["mixture"]["audio"]
199
-
200
- x = self.stft(mixture)
201
- batch["mixture"]["spectrogram"] = x
202
-
203
- if "sources" in batch.keys():
204
- for stem in batch["sources"].keys():
205
- s = batch["sources"][stem]["audio"]
206
- s = self.stft(s)
207
- batch["sources"][stem]["spectrogram"] = s
208
-
209
- batch = self.separate(batch)
210
-
211
- if 1:
212
- b = []
213
- for s in self.stems:
214
- # We need to obtain stereo again
215
- r = batch["estimates"][s]["audio"].view(
216
- -1, init_shape[1], init_shape[2]
217
- )
218
- b.append(r)
219
- # And we need to return back tensor and not independent stems
220
- batch = torch.stack(b, dim=1)
221
- return batch
222
-
223
- def encode(self, batch):
224
- x = batch["mixture"]["spectrogram"]
225
- length = batch["mixture"]["audio"].shape[-1]
226
-
227
- z = self.band_split(x) # (batch, emb_dim, n_band, n_time)
228
- q = self.tf_model(z) # (batch, emb_dim, n_band, n_time)
229
-
230
- return x, q, length
231
-
232
- def separate(self, batch):
233
- raise NotImplementedError
234
-
235
-
236
- class Bandit(BaseBandit):
237
- def __init__(
238
- self,
239
- in_channels: int,
240
- stems: List[str],
241
- band_type: str = "musical",
242
- n_bands: int = 64,
243
- require_no_overlap: bool = False,
244
- require_no_gap: bool = True,
245
- normalize_channel_independently: bool = False,
246
- treat_channel_as_feature: bool = True,
247
- n_sqm_modules: int = 12,
248
- emb_dim: int = 128,
249
- rnn_dim: int = 256,
250
- bidirectional: bool = True,
251
- rnn_type: str = "LSTM",
252
- mlp_dim: int = 512,
253
- hidden_activation: str = "Tanh",
254
- hidden_activation_kwargs: Dict | None = None,
255
- complex_mask: bool = True,
256
- use_freq_weights: bool = True,
257
- n_fft: int = 2048,
258
- win_length: int | None = 2048,
259
- hop_length: int = 512,
260
- window_fn: str = "hann_window",
261
- wkwargs: Dict | None = None,
262
- power: int | None = None,
263
- center: bool = True,
264
- normalized: bool = True,
265
- pad_mode: str = "constant",
266
- onesided: bool = True,
267
- fs: int = 44100,
268
- stft_precisions="32",
269
- bandsplit_precisions="bf16",
270
- tf_model_precisions="bf16",
271
- mask_estim_precisions="bf16",
272
- ):
273
- super().__init__(
274
- in_channels=in_channels,
275
- band_type=band_type,
276
- n_bands=n_bands,
277
- require_no_overlap=require_no_overlap,
278
- require_no_gap=require_no_gap,
279
- normalize_channel_independently=normalize_channel_independently,
280
- treat_channel_as_feature=treat_channel_as_feature,
281
- n_sqm_modules=n_sqm_modules,
282
- emb_dim=emb_dim,
283
- rnn_dim=rnn_dim,
284
- bidirectional=bidirectional,
285
- rnn_type=rnn_type,
286
- n_fft=n_fft,
287
- win_length=win_length,
288
- hop_length=hop_length,
289
- window_fn=window_fn,
290
- wkwargs=wkwargs,
291
- power=power,
292
- center=center,
293
- normalized=normalized,
294
- pad_mode=pad_mode,
295
- onesided=onesided,
296
- fs=fs,
297
- )
298
-
299
- self.stems = stems
300
-
301
- self.instantiate_mask_estim(
302
- in_channels=in_channels,
303
- stems=stems,
304
- emb_dim=emb_dim,
305
- mlp_dim=mlp_dim,
306
- hidden_activation=hidden_activation,
307
- hidden_activation_kwargs=hidden_activation_kwargs,
308
- complex_mask=complex_mask,
309
- n_freq=n_fft // 2 + 1,
310
- use_freq_weights=use_freq_weights,
311
- )
312
-
313
- def instantiate_mask_estim(
314
- self,
315
- in_channels: int,
316
- stems: List[str],
317
- emb_dim: int,
318
- mlp_dim: int,
319
- hidden_activation: str,
320
- hidden_activation_kwargs: Optional[Dict] = None,
321
- complex_mask: bool = True,
322
- n_freq: Optional[int] = None,
323
- use_freq_weights: bool = False,
324
- ):
325
- if hidden_activation_kwargs is None:
326
- hidden_activation_kwargs = {}
327
-
328
- assert n_freq is not None
329
-
330
- self.mask_estim = nn.ModuleDict(
331
- {
332
- stem: OverlappingMaskEstimationModule(
333
- band_specs=self.band_specs.get_band_specs(),
334
- freq_weights=self.band_specs.get_freq_weights(),
335
- n_freq=n_freq,
336
- emb_dim=emb_dim,
337
- mlp_dim=mlp_dim,
338
- in_channels=in_channels,
339
- hidden_activation=hidden_activation,
340
- hidden_activation_kwargs=hidden_activation_kwargs,
341
- complex_mask=complex_mask,
342
- use_freq_weights=use_freq_weights,
343
- )
344
- for stem in stems
345
- }
346
- )
347
-
348
- def separate(self, batch):
349
- batch["estimates"] = {}
350
-
351
- x, q, length = self.encode(batch)
352
-
353
- for stem, mem in self.mask_estim.items():
354
- m = mem(q)
355
-
356
- s = self.mask(x, m.to(x.dtype))
357
- s = torch.reshape(s, x.shape)
358
- batch["estimates"][stem] = {
359
- "audio": self.istft(s, length),
360
- "spectrogram": s,
361
- }
362
-
363
- return batch
 
1
+ from typing import Dict, List, Optional
2
+
3
+ import torch
4
+ import torchaudio as ta
5
+ from torch import nn
6
+ import pytorch_lightning as pl
7
+
8
+ from .bandsplit import BandSplitModule
9
+ from .maskestim import OverlappingMaskEstimationModule
10
+ from .tfmodel import SeqBandModellingModule
11
+ from .utils import MusicalBandsplitSpecification
12
+
13
+
14
+ class BaseEndToEndModule(pl.LightningModule):
15
+ def __init__(
16
+ self,
17
+ ) -> None:
18
+ super().__init__()
19
+
20
+
21
+ class BaseBandit(BaseEndToEndModule):
22
+ def __init__(
23
+ self,
24
+ in_channels: int,
25
+ fs: int,
26
+ band_type: str = "musical",
27
+ n_bands: int = 64,
28
+ require_no_overlap: bool = False,
29
+ require_no_gap: bool = True,
30
+ normalize_channel_independently: bool = False,
31
+ treat_channel_as_feature: bool = True,
32
+ n_sqm_modules: int = 12,
33
+ emb_dim: int = 128,
34
+ rnn_dim: int = 256,
35
+ bidirectional: bool = True,
36
+ rnn_type: str = "LSTM",
37
+ n_fft: int = 2048,
38
+ win_length: Optional[int] = 2048,
39
+ hop_length: int = 512,
40
+ window_fn: str = "hann_window",
41
+ wkwargs: Optional[Dict] = None,
42
+ power: Optional[int] = None,
43
+ center: bool = True,
44
+ normalized: bool = True,
45
+ pad_mode: str = "constant",
46
+ onesided: bool = True,
47
+ ):
48
+ super().__init__()
49
+
50
+ self.in_channels = in_channels
51
+
52
+ self.instantiate_spectral(
53
+ n_fft=n_fft,
54
+ win_length=win_length,
55
+ hop_length=hop_length,
56
+ window_fn=window_fn,
57
+ wkwargs=wkwargs,
58
+ power=power,
59
+ normalized=normalized,
60
+ center=center,
61
+ pad_mode=pad_mode,
62
+ onesided=onesided,
63
+ )
64
+
65
+ self.instantiate_bandsplit(
66
+ in_channels=in_channels,
67
+ band_type=band_type,
68
+ n_bands=n_bands,
69
+ require_no_overlap=require_no_overlap,
70
+ require_no_gap=require_no_gap,
71
+ normalize_channel_independently=normalize_channel_independently,
72
+ treat_channel_as_feature=treat_channel_as_feature,
73
+ emb_dim=emb_dim,
74
+ n_fft=n_fft,
75
+ fs=fs,
76
+ )
77
+
78
+ self.instantiate_tf_modelling(
79
+ n_sqm_modules=n_sqm_modules,
80
+ emb_dim=emb_dim,
81
+ rnn_dim=rnn_dim,
82
+ bidirectional=bidirectional,
83
+ rnn_type=rnn_type,
84
+ )
85
+
86
+ def instantiate_spectral(
87
+ self,
88
+ n_fft: int = 2048,
89
+ win_length: Optional[int] = 2048,
90
+ hop_length: int = 512,
91
+ window_fn: str = "hann_window",
92
+ wkwargs: Optional[Dict] = None,
93
+ power: Optional[int] = None,
94
+ normalized: bool = True,
95
+ center: bool = True,
96
+ pad_mode: str = "constant",
97
+ onesided: bool = True,
98
+ ):
99
+ assert power is None
100
+
101
+ window_fn = torch.__dict__[window_fn]
102
+
103
+ self.stft = ta.transforms.Spectrogram(
104
+ n_fft=n_fft,
105
+ win_length=win_length,
106
+ hop_length=hop_length,
107
+ pad_mode=pad_mode,
108
+ pad=0,
109
+ window_fn=window_fn,
110
+ wkwargs=wkwargs,
111
+ power=power,
112
+ normalized=normalized,
113
+ center=center,
114
+ onesided=onesided,
115
+ )
116
+
117
+ self.istft = ta.transforms.InverseSpectrogram(
118
+ n_fft=n_fft,
119
+ win_length=win_length,
120
+ hop_length=hop_length,
121
+ pad_mode=pad_mode,
122
+ pad=0,
123
+ window_fn=window_fn,
124
+ wkwargs=wkwargs,
125
+ normalized=normalized,
126
+ center=center,
127
+ onesided=onesided,
128
+ )
129
+
130
+ def instantiate_bandsplit(
131
+ self,
132
+ in_channels: int,
133
+ band_type: str = "musical",
134
+ n_bands: int = 64,
135
+ require_no_overlap: bool = False,
136
+ require_no_gap: bool = True,
137
+ normalize_channel_independently: bool = False,
138
+ treat_channel_as_feature: bool = True,
139
+ emb_dim: int = 128,
140
+ n_fft: int = 2048,
141
+ fs: int = 44100,
142
+ ):
143
+ assert band_type == "musical"
144
+
145
+ self.band_specs = MusicalBandsplitSpecification(
146
+ nfft=n_fft, fs=fs, n_bands=n_bands
147
+ )
148
+
149
+ self.band_split = BandSplitModule(
150
+ in_channels=in_channels,
151
+ band_specs=self.band_specs.get_band_specs(),
152
+ require_no_overlap=require_no_overlap,
153
+ require_no_gap=require_no_gap,
154
+ normalize_channel_independently=normalize_channel_independently,
155
+ treat_channel_as_feature=treat_channel_as_feature,
156
+ emb_dim=emb_dim,
157
+ )
158
+
159
+ def instantiate_tf_modelling(
160
+ self,
161
+ n_sqm_modules: int = 12,
162
+ emb_dim: int = 128,
163
+ rnn_dim: int = 256,
164
+ bidirectional: bool = True,
165
+ rnn_type: str = "LSTM",
166
+ ):
167
+ try:
168
+ self.tf_model = torch.compile(
169
+ SeqBandModellingModule(
170
+ n_modules=n_sqm_modules,
171
+ emb_dim=emb_dim,
172
+ rnn_dim=rnn_dim,
173
+ bidirectional=bidirectional,
174
+ rnn_type=rnn_type,
175
+ ),
176
+ disable=True,
177
+ )
178
+ except Exception:
179
+ self.tf_model = SeqBandModellingModule(
180
+ n_modules=n_sqm_modules,
181
+ emb_dim=emb_dim,
182
+ rnn_dim=rnn_dim,
183
+ bidirectional=bidirectional,
184
+ rnn_type=rnn_type,
185
+ )
186
+
187
+ def mask(self, x, m):
188
+ return x * m
189
+
190
+ def forward(self, batch, mode="train"):
191
+ init_shape = batch.shape
192
+ if not isinstance(batch, dict):
193
+ mono = batch.view(-1, 1, batch.shape[-1])
194
+ batch = {"mixture": {"audio": mono}}
195
+
196
+ with torch.no_grad():
197
+ mixture = batch["mixture"]["audio"]
198
+
199
+ x = self.stft(mixture)
200
+ batch["mixture"]["spectrogram"] = x
201
+
202
+ if "sources" in batch.keys():
203
+ for stem in batch["sources"].keys():
204
+ s = batch["sources"][stem]["audio"]
205
+ s = self.stft(s)
206
+ batch["sources"][stem]["spectrogram"] = s
207
+
208
+ batch = self.separate(batch)
209
+
210
+ if 1:
211
+ b = []
212
+ for s in self.stems:
213
+ r = batch["estimates"][s]["audio"].view(
214
+ -1, init_shape[1], init_shape[2]
215
+ )
216
+ b.append(r)
217
+ batch = torch.stack(b, dim=1)
218
+ return batch
219
+
220
+ def encode(self, batch):
221
+ x = batch["mixture"]["spectrogram"]
222
+ length = batch["mixture"]["audio"].shape[-1]
223
+
224
+ z = self.band_split(x)
225
+ q = self.tf_model(z)
226
+
227
+ return x, q, length
228
+
229
+ def separate(self, batch):
230
+ raise NotImplementedError
231
+
232
+
233
+ class Bandit(BaseBandit):
234
+ def __init__(
235
+ self,
236
+ in_channels: int,
237
+ stems: List[str],
238
+ band_type: str = "musical",
239
+ n_bands: int = 64,
240
+ require_no_overlap: bool = False,
241
+ require_no_gap: bool = True,
242
+ normalize_channel_independently: bool = False,
243
+ treat_channel_as_feature: bool = True,
244
+ n_sqm_modules: int = 12,
245
+ emb_dim: int = 128,
246
+ rnn_dim: int = 256,
247
+ bidirectional: bool = True,
248
+ rnn_type: str = "LSTM",
249
+ mlp_dim: int = 512,
250
+ hidden_activation: str = "Tanh",
251
+ hidden_activation_kwargs: Dict | None = None,
252
+ complex_mask: bool = True,
253
+ use_freq_weights: bool = True,
254
+ n_fft: int = 2048,
255
+ win_length: int | None = 2048,
256
+ hop_length: int = 512,
257
+ window_fn: str = "hann_window",
258
+ wkwargs: Dict | None = None,
259
+ power: int | None = None,
260
+ center: bool = True,
261
+ normalized: bool = True,
262
+ pad_mode: str = "constant",
263
+ onesided: bool = True,
264
+ fs: int = 44100,
265
+ stft_precisions="32",
266
+ bandsplit_precisions="bf16",
267
+ tf_model_precisions="bf16",
268
+ mask_estim_precisions="bf16",
269
+ ):
270
+ super().__init__(
271
+ in_channels=in_channels,
272
+ band_type=band_type,
273
+ n_bands=n_bands,
274
+ require_no_overlap=require_no_overlap,
275
+ require_no_gap=require_no_gap,
276
+ normalize_channel_independently=normalize_channel_independently,
277
+ treat_channel_as_feature=treat_channel_as_feature,
278
+ n_sqm_modules=n_sqm_modules,
279
+ emb_dim=emb_dim,
280
+ rnn_dim=rnn_dim,
281
+ bidirectional=bidirectional,
282
+ rnn_type=rnn_type,
283
+ n_fft=n_fft,
284
+ win_length=win_length,
285
+ hop_length=hop_length,
286
+ window_fn=window_fn,
287
+ wkwargs=wkwargs,
288
+ power=power,
289
+ center=center,
290
+ normalized=normalized,
291
+ pad_mode=pad_mode,
292
+ onesided=onesided,
293
+ fs=fs,
294
+ )
295
+
296
+ self.stems = stems
297
+
298
+ self.instantiate_mask_estim(
299
+ in_channels=in_channels,
300
+ stems=stems,
301
+ emb_dim=emb_dim,
302
+ mlp_dim=mlp_dim,
303
+ hidden_activation=hidden_activation,
304
+ hidden_activation_kwargs=hidden_activation_kwargs,
305
+ complex_mask=complex_mask,
306
+ n_freq=n_fft // 2 + 1,
307
+ use_freq_weights=use_freq_weights,
308
+ )
309
+
310
+ def instantiate_mask_estim(
311
+ self,
312
+ in_channels: int,
313
+ stems: List[str],
314
+ emb_dim: int,
315
+ mlp_dim: int,
316
+ hidden_activation: str,
317
+ hidden_activation_kwargs: Optional[Dict] = None,
318
+ complex_mask: bool = True,
319
+ n_freq: Optional[int] = None,
320
+ use_freq_weights: bool = False,
321
+ ):
322
+ if hidden_activation_kwargs is None:
323
+ hidden_activation_kwargs = {}
324
+
325
+ assert n_freq is not None
326
+
327
+ self.mask_estim = nn.ModuleDict(
328
+ {
329
+ stem: OverlappingMaskEstimationModule(
330
+ band_specs=self.band_specs.get_band_specs(),
331
+ freq_weights=self.band_specs.get_freq_weights(),
332
+ n_freq=n_freq,
333
+ emb_dim=emb_dim,
334
+ mlp_dim=mlp_dim,
335
+ in_channels=in_channels,
336
+ hidden_activation=hidden_activation,
337
+ hidden_activation_kwargs=hidden_activation_kwargs,
338
+ complex_mask=complex_mask,
339
+ use_freq_weights=use_freq_weights,
340
+ )
341
+ for stem in stems
342
+ }
343
+ )
344
+
345
+ def separate(self, batch):
346
+ batch["estimates"] = {}
347
+
348
+ x, q, length = self.encode(batch)
349
+
350
+ for stem, mem in self.mask_estim.items():
351
+ m = mem(q)
352
+
353
+ s = self.mask(x, m.to(x.dtype))
354
+ s = torch.reshape(s, x.shape)
355
+ batch["estimates"][stem] = {
356
+ "audio": self.istft(s, length),
357
+ "spectrogram": s,
358
+ }
359
+
360
+ return batch
 
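Shape contract of Bandit.forward as implemented above, shown as a small sketch (the stem names are illustrative; channels are folded into the batch, separated per stem, then re-stacked):

import torch
from mvsepless.models.bandit_v2.bandit import Bandit

model = Bandit(in_channels=1, stems=["speech", "music", "sfx"])
model.eval()

mix = torch.randn(2, 2, 44100)   # (batch, channels, samples), a stereo mixture
with torch.no_grad():
    est = model(mix)             # each channel is processed as a mono mixture
print(est.shape)                 # torch.Size([2, 3, 2, 44100]) = (batch, stems, channels, samples)
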
 
 
mvsepless/models/bandit_v2/bandsplit.py CHANGED
@@ -1,130 +1,127 @@
1
- from typing import List, Tuple
2
-
3
- import torch
4
- from torch import nn
5
- from torch.utils.checkpoint import checkpoint_sequential
6
-
7
- from .utils import (
8
- band_widths_from_specs,
9
- check_no_gap,
10
- check_no_overlap,
11
- check_nonzero_bandwidth,
12
- )
13
-
14
-
15
- class NormFC(nn.Module):
16
- def __init__(
17
- self,
18
- emb_dim: int,
19
- bandwidth: int,
20
- in_channels: int,
21
- normalize_channel_independently: bool = False,
22
- treat_channel_as_feature: bool = True,
23
- ) -> None:
24
- super().__init__()
25
-
26
- if not treat_channel_as_feature:
27
- raise NotImplementedError
28
-
29
- self.treat_channel_as_feature = treat_channel_as_feature
30
-
31
- if normalize_channel_independently:
32
- raise NotImplementedError
33
-
34
- reim = 2
35
-
36
- norm = nn.LayerNorm(in_channels * bandwidth * reim)
37
-
38
- fc_in = bandwidth * reim
39
-
40
- if treat_channel_as_feature:
41
- fc_in *= in_channels
42
- else:
43
- assert emb_dim % in_channels == 0
44
- emb_dim = emb_dim // in_channels
45
-
46
- fc = nn.Linear(fc_in, emb_dim)
47
-
48
- self.combined = nn.Sequential(norm, fc)
49
-
50
- def forward(self, xb):
51
- return checkpoint_sequential(self.combined, 1, xb, use_reentrant=False)
52
-
53
-
54
- class BandSplitModule(nn.Module):
55
- def __init__(
56
- self,
57
- band_specs: List[Tuple[float, float]],
58
- emb_dim: int,
59
- in_channels: int,
60
- require_no_overlap: bool = False,
61
- require_no_gap: bool = True,
62
- normalize_channel_independently: bool = False,
63
- treat_channel_as_feature: bool = True,
64
- ) -> None:
65
- super().__init__()
66
-
67
- check_nonzero_bandwidth(band_specs)
68
-
69
- if require_no_gap:
70
- check_no_gap(band_specs)
71
-
72
- if require_no_overlap:
73
- check_no_overlap(band_specs)
74
-
75
- self.band_specs = band_specs
76
- # list of [fstart, fend) in index.
77
- # Note that fend is exclusive.
78
- self.band_widths = band_widths_from_specs(band_specs)
79
- self.n_bands = len(band_specs)
80
- self.emb_dim = emb_dim
81
-
82
- try:
83
- self.norm_fc_modules = nn.ModuleList(
84
- [ # type: ignore
85
- torch.compile(
86
- NormFC(
87
- emb_dim=emb_dim,
88
- bandwidth=bw,
89
- in_channels=in_channels,
90
- normalize_channel_independently=normalize_channel_independently,
91
- treat_channel_as_feature=treat_channel_as_feature,
92
- ),
93
- disable=True,
94
- )
95
- for bw in self.band_widths
96
- ]
97
- )
98
- except Exception as e:
99
- self.norm_fc_modules = nn.ModuleList(
100
- [ # type: ignore
101
- NormFC(
102
- emb_dim=emb_dim,
103
- bandwidth=bw,
104
- in_channels=in_channels,
105
- normalize_channel_independently=normalize_channel_independently,
106
- treat_channel_as_feature=treat_channel_as_feature,
107
- )
108
- for bw in self.band_widths
109
- ]
110
- )
111
-
112
- def forward(self, x: torch.Tensor):
113
- # x = complex spectrogram (batch, in_chan, n_freq, n_time)
114
-
115
- batch, in_chan, band_width, n_time = x.shape
116
-
117
- z = torch.zeros(
118
- size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
119
- )
120
-
121
- x = torch.permute(x, (0, 3, 1, 2)).contiguous()
122
-
123
- for i, nfm in enumerate(self.norm_fc_modules):
124
- fstart, fend = self.band_specs[i]
125
- xb = x[:, :, :, fstart:fend]
126
- xb = torch.view_as_real(xb)
127
- xb = torch.reshape(xb, (batch, n_time, -1))
128
- z[:, i, :, :] = nfm(xb)
129
-
130
- return z
 
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.utils.checkpoint import checkpoint_sequential
6
+
7
+ from .utils import (
8
+ band_widths_from_specs,
9
+ check_no_gap,
10
+ check_no_overlap,
11
+ check_nonzero_bandwidth,
12
+ )
13
+
14
+
15
+ class NormFC(nn.Module):
16
+ def __init__(
17
+ self,
18
+ emb_dim: int,
19
+ bandwidth: int,
20
+ in_channels: int,
21
+ normalize_channel_independently: bool = False,
22
+ treat_channel_as_feature: bool = True,
23
+ ) -> None:
24
+ super().__init__()
25
+
26
+ if not treat_channel_as_feature:
27
+ raise NotImplementedError
28
+
29
+ self.treat_channel_as_feature = treat_channel_as_feature
30
+
31
+ if normalize_channel_independently:
32
+ raise NotImplementedError
33
+
34
+ reim = 2
35
+
36
+ norm = nn.LayerNorm(in_channels * bandwidth * reim)
37
+
38
+ fc_in = bandwidth * reim
39
+
40
+ if treat_channel_as_feature:
41
+ fc_in *= in_channels
42
+ else:
43
+ assert emb_dim % in_channels == 0
44
+ emb_dim = emb_dim // in_channels
45
+
46
+ fc = nn.Linear(fc_in, emb_dim)
47
+
48
+ self.combined = nn.Sequential(norm, fc)
49
+
50
+ def forward(self, xb):
51
+ return checkpoint_sequential(self.combined, 1, xb, use_reentrant=False)
52
+
53
+
54
+ class BandSplitModule(nn.Module):
55
+ def __init__(
56
+ self,
57
+ band_specs: List[Tuple[float, float]],
58
+ emb_dim: int,
59
+ in_channels: int,
60
+ require_no_overlap: bool = False,
61
+ require_no_gap: bool = True,
62
+ normalize_channel_independently: bool = False,
63
+ treat_channel_as_feature: bool = True,
64
+ ) -> None:
65
+ super().__init__()
66
+
67
+ check_nonzero_bandwidth(band_specs)
68
+
69
+ if require_no_gap:
70
+ check_no_gap(band_specs)
71
+
72
+ if require_no_overlap:
73
+ check_no_overlap(band_specs)
74
+
75
+ self.band_specs = band_specs
76
+ self.band_widths = band_widths_from_specs(band_specs)
77
+ self.n_bands = len(band_specs)
78
+ self.emb_dim = emb_dim
79
+
80
+ try:
81
+ self.norm_fc_modules = nn.ModuleList(
82
+ [ # type: ignore
83
+ torch.compile(
84
+ NormFC(
85
+ emb_dim=emb_dim,
86
+ bandwidth=bw,
87
+ in_channels=in_channels,
88
+ normalize_channel_independently=normalize_channel_independently,
89
+ treat_channel_as_feature=treat_channel_as_feature,
90
+ ),
91
+ disable=True,
92
+ )
93
+ for bw in self.band_widths
94
+ ]
95
+ )
96
+ except Exception:
97
+ self.norm_fc_modules = nn.ModuleList(
98
+ [ # type: ignore
99
+ NormFC(
100
+ emb_dim=emb_dim,
101
+ bandwidth=bw,
102
+ in_channels=in_channels,
103
+ normalize_channel_independently=normalize_channel_independently,
104
+ treat_channel_as_feature=treat_channel_as_feature,
105
+ )
106
+ for bw in self.band_widths
107
+ ]
108
+ )
109
+
110
+ def forward(self, x: torch.Tensor):
111
+
112
+ batch, in_chan, band_width, n_time = x.shape
113
+
114
+ z = torch.zeros(
115
+ size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
116
+ )
117
+
118
+ x = torch.permute(x, (0, 3, 1, 2)).contiguous()
119
+
120
+ for i, nfm in enumerate(self.norm_fc_modules):
121
+ fstart, fend = self.band_specs[i]
122
+ xb = x[:, :, :, fstart:fend]
123
+ xb = torch.view_as_real(xb)
124
+ xb = torch.reshape(xb, (batch, n_time, -1))
125
+ z[:, i, :, :] = nfm(xb)
126
+
127
+ return z
 
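A quick shape check for BandSplitModule (the band_specs below are hand-picked to cover the 1025 bins of a 2048-point STFT with no gaps; in the model they come from MusicalBandsplitSpecification):

import torch
from mvsepless.models.bandit_v2.bandsplit import BandSplitModule

band_specs = [(0, 256), (256, 640), (640, 1025)]
module = BandSplitModule(band_specs=band_specs, emb_dim=128, in_channels=2)

spec = torch.randn(4, 2, 1025, 100, dtype=torch.complex64)  # (batch, ch, freq, time)
z = module(spec)
print(z.shape)  # torch.Size([4, 3, 100, 128]) = (batch, n_bands, n_time, emb_dim)
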
 
 
mvsepless/models/bandit_v2/film.py CHANGED
@@ -1,23 +1,23 @@
1
- from torch import nn
2
- import torch
3
-
4
-
5
- class FiLM(nn.Module):
6
- def __init__(self):
7
- super().__init__()
8
-
9
- def forward(self, x, gamma, beta):
10
- return gamma * x + beta
11
-
12
-
13
- class BTFBroadcastedFiLM(nn.Module):
14
- def __init__(self):
15
- super().__init__()
16
- self.film = FiLM()
17
-
18
- def forward(self, x, gamma, beta):
19
-
20
- gamma = gamma[None, None, None, :]
21
- beta = beta[None, None, None, :]
22
-
23
- return self.film(x, gamma, beta)
 
1
+ from torch import nn
2
+ import torch
3
+
4
+
5
+ class FiLM(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+
9
+ def forward(self, x, gamma, beta):
10
+ return gamma * x + beta
11
+
12
+
13
+ class BTFBroadcastedFiLM(nn.Module):
14
+ def __init__(self):
15
+ super().__init__()
16
+ self.film = FiLM()
17
+
18
+ def forward(self, x, gamma, beta):
19
+
20
+ gamma = gamma[None, None, None, :]
21
+ beta = beta[None, None, None, :]
22
+
23
+ return self.film(x, gamma, beta)
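
FiLM is feature-wise affine modulation; a toy check of the broadcasting used by BTFBroadcastedFiLM (gamma and beta are per-feature vectors applied along the last axis):

import torch
from mvsepless.models.bandit_v2.film import BTFBroadcastedFiLM

film = BTFBroadcastedFiLM()
x = torch.ones(2, 4, 8, 3)              # (batch, n_bands, n_time, emb_dim)
gamma = torch.tensor([1.0, 2.0, 0.5])   # per-feature scale
beta = torch.tensor([0.0, -1.0, 3.0])   # per-feature shift
y = film(x, gamma, beta)
print(y[0, 0, 0])                       # tensor([1.0000, 1.0000, 3.5000])
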
mvsepless/models/bandit_v2/maskestim.py CHANGED
@@ -1,281 +1,269 @@
1
- from typing import Dict, List, Optional, Tuple, Type
2
-
3
- import torch
4
- from torch import nn
5
- from torch.nn.modules import activation
6
- from torch.utils.checkpoint import checkpoint_sequential
7
-
8
- from .utils import (
9
- band_widths_from_specs,
10
- check_no_gap,
11
- check_no_overlap,
12
- check_nonzero_bandwidth,
13
- )
14
-
15
-
16
- class BaseNormMLP(nn.Module):
17
- def __init__(
18
- self,
19
- emb_dim: int,
20
- mlp_dim: int,
21
- bandwidth: int,
22
- in_channels: Optional[int],
23
- hidden_activation: str = "Tanh",
24
- hidden_activation_kwargs=None,
25
- complex_mask: bool = True,
26
- ):
27
- super().__init__()
28
- if hidden_activation_kwargs is None:
29
- hidden_activation_kwargs = {}
30
- self.hidden_activation_kwargs = hidden_activation_kwargs
31
- self.norm = nn.LayerNorm(emb_dim)
32
- self.hidden = nn.Sequential(
33
- nn.Linear(in_features=emb_dim, out_features=mlp_dim),
34
- activation.__dict__[hidden_activation](**self.hidden_activation_kwargs),
35
- )
36
-
37
- self.bandwidth = bandwidth
38
- self.in_channels = in_channels
39
-
40
- self.complex_mask = complex_mask
41
- self.reim = 2 if complex_mask else 1
42
- self.glu_mult = 2
43
-
44
-
45
- class NormMLP(BaseNormMLP):
46
- def __init__(
47
- self,
48
- emb_dim: int,
49
- mlp_dim: int,
50
- bandwidth: int,
51
- in_channels: Optional[int],
52
- hidden_activation: str = "Tanh",
53
- hidden_activation_kwargs=None,
54
- complex_mask: bool = True,
55
- ) -> None:
56
- super().__init__(
57
- emb_dim=emb_dim,
58
- mlp_dim=mlp_dim,
59
- bandwidth=bandwidth,
60
- in_channels=in_channels,
61
- hidden_activation=hidden_activation,
62
- hidden_activation_kwargs=hidden_activation_kwargs,
63
- complex_mask=complex_mask,
64
- )
65
-
66
- self.output = nn.Sequential(
67
- nn.Linear(
68
- in_features=mlp_dim,
69
- out_features=bandwidth * in_channels * self.reim * 2,
70
- ),
71
- nn.GLU(dim=-1),
72
- )
73
-
74
- try:
75
- self.combined = torch.compile(
76
- nn.Sequential(self.norm, self.hidden, self.output), disable=True
77
- )
78
- except Exception as e:
79
- self.combined = nn.Sequential(self.norm, self.hidden, self.output)
80
-
81
- def reshape_output(self, mb):
82
- # print(mb.shape)
83
- batch, n_time, _ = mb.shape
84
- if self.complex_mask:
85
- mb = mb.reshape(
86
- batch, n_time, self.in_channels, self.bandwidth, self.reim
87
- ).contiguous()
88
- # print(mb.shape)
89
- mb = torch.view_as_complex(mb) # (batch, n_time, in_channels, bandwidth)
90
- else:
91
- mb = mb.reshape(batch, n_time, self.in_channels, self.bandwidth)
92
-
93
- mb = torch.permute(mb, (0, 2, 3, 1)) # (batch, in_channels, bandwidth, n_time)
94
-
95
- return mb
96
-
97
- def forward(self, qb):
98
- # qb = (batch, n_time, emb_dim)
99
- # qb = self.norm(qb) # (batch, n_time, emb_dim)
100
- # qb = self.hidden(qb) # (batch, n_time, mlp_dim)
101
- # mb = self.output(qb) # (batch, n_time, bandwidth * in_channels * reim)
102
-
103
- mb = checkpoint_sequential(self.combined, 2, qb, use_reentrant=False)
104
- mb = self.reshape_output(mb) # (batch, in_channels, bandwidth, n_time)
105
-
106
- return mb
107
-
108
-
109
- class MaskEstimationModuleSuperBase(nn.Module):
110
- pass
111
-
112
-
113
- class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
114
- def __init__(
115
- self,
116
- band_specs: List[Tuple[float, float]],
117
- emb_dim: int,
118
- mlp_dim: int,
119
- in_channels: Optional[int],
120
- hidden_activation: str = "Tanh",
121
- hidden_activation_kwargs: Dict = None,
122
- complex_mask: bool = True,
123
- norm_mlp_cls: Type[nn.Module] = NormMLP,
124
- norm_mlp_kwargs: Dict = None,
125
- ) -> None:
126
- super().__init__()
127
-
128
- self.band_widths = band_widths_from_specs(band_specs)
129
- self.n_bands = len(band_specs)
130
-
131
- if hidden_activation_kwargs is None:
132
- hidden_activation_kwargs = {}
133
-
134
- if norm_mlp_kwargs is None:
135
- norm_mlp_kwargs = {}
136
-
137
- self.norm_mlp = nn.ModuleList(
138
- [
139
- norm_mlp_cls(
140
- bandwidth=self.band_widths[b],
141
- emb_dim=emb_dim,
142
- mlp_dim=mlp_dim,
143
- in_channels=in_channels,
144
- hidden_activation=hidden_activation,
145
- hidden_activation_kwargs=hidden_activation_kwargs,
146
- complex_mask=complex_mask,
147
- **norm_mlp_kwargs,
148
- )
149
- for b in range(self.n_bands)
150
- ]
151
- )
152
-
153
- def compute_masks(self, q):
154
- batch, n_bands, n_time, emb_dim = q.shape
155
-
156
- masks = []
157
-
158
- for b, nmlp in enumerate(self.norm_mlp):
159
- # print(f"maskestim/{b:02d}")
160
- qb = q[:, b, :, :]
161
- mb = nmlp(qb)
162
- masks.append(mb)
163
-
164
- return masks
165
-
166
- def compute_mask(self, q, b):
167
- batch, n_bands, n_time, emb_dim = q.shape
168
- qb = q[:, b, :, :]
169
- mb = self.norm_mlp[b](qb)
170
- return mb
171
-
172
-
173
- class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
174
- def __init__(
175
- self,
176
- in_channels: int,
177
- band_specs: List[Tuple[float, float]],
178
- freq_weights: List[torch.Tensor],
179
- n_freq: int,
180
- emb_dim: int,
181
- mlp_dim: int,
182
- cond_dim: int = 0,
183
- hidden_activation: str = "Tanh",
184
- hidden_activation_kwargs: Dict = None,
185
- complex_mask: bool = True,
186
- norm_mlp_cls: Type[nn.Module] = NormMLP,
187
- norm_mlp_kwargs: Dict = None,
188
- use_freq_weights: bool = False,
189
- ) -> None:
190
- check_nonzero_bandwidth(band_specs)
191
- check_no_gap(band_specs)
192
-
193
- if cond_dim > 0:
194
- raise NotImplementedError
195
-
196
- super().__init__(
197
- band_specs=band_specs,
198
- emb_dim=emb_dim + cond_dim,
199
- mlp_dim=mlp_dim,
200
- in_channels=in_channels,
201
- hidden_activation=hidden_activation,
202
- hidden_activation_kwargs=hidden_activation_kwargs,
203
- complex_mask=complex_mask,
204
- norm_mlp_cls=norm_mlp_cls,
205
- norm_mlp_kwargs=norm_mlp_kwargs,
206
- )
207
-
208
- self.n_freq = n_freq
209
- self.band_specs = band_specs
210
- self.in_channels = in_channels
211
-
212
- if freq_weights is not None and use_freq_weights:
213
- for i, fw in enumerate(freq_weights):
214
- self.register_buffer(f"freq_weights/{i}", fw)
215
-
216
- self.use_freq_weights = use_freq_weights
217
- else:
218
- self.use_freq_weights = False
219
-
220
- def forward(self, q):
221
- # q = (batch, n_bands, n_time, emb_dim)
222
-
223
- batch, n_bands, n_time, emb_dim = q.shape
224
-
225
- masks = torch.zeros(
226
- (batch, self.in_channels, self.n_freq, n_time),
227
- device=q.device,
228
- dtype=torch.complex64,
229
- )
230
-
231
- for im in range(n_bands):
232
- fstart, fend = self.band_specs[im]
233
-
234
- mask = self.compute_mask(q, im)
235
-
236
- if self.use_freq_weights:
237
- fw = self.get_buffer(f"freq_weights/{im}")[:, None]
238
- mask = mask * fw
239
- masks[:, :, fstart:fend, :] += mask
240
-
241
- return masks
242
-
243
-
244
- class MaskEstimationModule(OverlappingMaskEstimationModule):
245
- def __init__(
246
- self,
247
- band_specs: List[Tuple[float, float]],
248
- emb_dim: int,
249
- mlp_dim: int,
250
- in_channels: Optional[int],
251
- hidden_activation: str = "Tanh",
252
- hidden_activation_kwargs: Dict = None,
253
- complex_mask: bool = True,
254
- **kwargs,
255
- ) -> None:
256
- check_nonzero_bandwidth(band_specs)
257
- check_no_gap(band_specs)
258
- check_no_overlap(band_specs)
259
- super().__init__(
260
- in_channels=in_channels,
261
- band_specs=band_specs,
262
- freq_weights=None,
263
- n_freq=None,
264
- emb_dim=emb_dim,
265
- mlp_dim=mlp_dim,
266
- hidden_activation=hidden_activation,
267
- hidden_activation_kwargs=hidden_activation_kwargs,
268
- complex_mask=complex_mask,
269
- )
270
-
271
- def forward(self, q, cond=None):
272
- # q = (batch, n_bands, n_time, emb_dim)
273
-
274
- masks = self.compute_masks(
275
- q
276
- ) # [n_bands * (batch, in_channels, bandwidth, n_time)]
277
-
278
- # TODO: currently this requires band specs to have no gap and no overlap
279
- masks = torch.concat(masks, dim=2) # (batch, in_channels, n_freq, n_time)
280
-
281
- return masks
 
1
+ from typing import Dict, List, Optional, Tuple, Type
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn.modules import activation
6
+ from torch.utils.checkpoint import checkpoint_sequential
7
+
8
+ from .utils import (
9
+ band_widths_from_specs,
10
+ check_no_gap,
11
+ check_no_overlap,
12
+ check_nonzero_bandwidth,
13
+ )
14
+
15
+
16
+ class BaseNormMLP(nn.Module):
17
+ def __init__(
18
+ self,
19
+ emb_dim: int,
20
+ mlp_dim: int,
21
+ bandwidth: int,
22
+ in_channels: Optional[int],
23
+ hidden_activation: str = "Tanh",
24
+ hidden_activation_kwargs=None,
25
+ complex_mask: bool = True,
26
+ ):
27
+ super().__init__()
28
+ if hidden_activation_kwargs is None:
29
+ hidden_activation_kwargs = {}
30
+ self.hidden_activation_kwargs = hidden_activation_kwargs
31
+ self.norm = nn.LayerNorm(emb_dim)
32
+ self.hidden = nn.Sequential(
33
+ nn.Linear(in_features=emb_dim, out_features=mlp_dim),
34
+ activation.__dict__[hidden_activation](**self.hidden_activation_kwargs),
35
+ )
36
+
37
+ self.bandwidth = bandwidth
38
+ self.in_channels = in_channels
39
+
40
+ self.complex_mask = complex_mask
41
+ self.reim = 2 if complex_mask else 1
42
+ self.glu_mult = 2
43
+
44
+
45
+ class NormMLP(BaseNormMLP):
46
+ def __init__(
47
+ self,
48
+ emb_dim: int,
49
+ mlp_dim: int,
50
+ bandwidth: int,
51
+ in_channels: Optional[int],
52
+ hidden_activation: str = "Tanh",
53
+ hidden_activation_kwargs=None,
54
+ complex_mask: bool = True,
55
+ ) -> None:
56
+ super().__init__(
57
+ emb_dim=emb_dim,
58
+ mlp_dim=mlp_dim,
59
+ bandwidth=bandwidth,
60
+ in_channels=in_channels,
61
+ hidden_activation=hidden_activation,
62
+ hidden_activation_kwargs=hidden_activation_kwargs,
63
+ complex_mask=complex_mask,
64
+ )
65
+
66
+ self.output = nn.Sequential(
67
+ nn.Linear(
68
+ in_features=mlp_dim,
69
+ out_features=bandwidth * in_channels * self.reim * 2,
70
+ ),
71
+ nn.GLU(dim=-1),
72
+ )
73
+
74
+ try:
75
+ self.combined = torch.compile(
76
+ nn.Sequential(self.norm, self.hidden, self.output), disable=True
77
+ )
78
+ except Exception as e:
79
+ self.combined = nn.Sequential(self.norm, self.hidden, self.output)
80
+
81
+ def reshape_output(self, mb):
82
+ batch, n_time, _ = mb.shape
83
+ if self.complex_mask:
84
+ mb = mb.reshape(
85
+ batch, n_time, self.in_channels, self.bandwidth, self.reim
86
+ ).contiguous()
87
+ mb = torch.view_as_complex(mb)
88
+ else:
89
+ mb = mb.reshape(batch, n_time, self.in_channels, self.bandwidth)
90
+
91
+ mb = torch.permute(mb, (0, 2, 3, 1))
92
+
93
+ return mb
94
+
95
+ def forward(self, qb):
96
+
97
+ mb = checkpoint_sequential(self.combined, 2, qb, use_reentrant=False)
98
+ mb = self.reshape_output(mb)
99
+
100
+ return mb
101
+
102
+
103
+ class MaskEstimationModuleSuperBase(nn.Module):
104
+ pass
105
+
106
+
107
+ class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
108
+ def __init__(
109
+ self,
110
+ band_specs: List[Tuple[float, float]],
111
+ emb_dim: int,
112
+ mlp_dim: int,
113
+ in_channels: Optional[int],
114
+ hidden_activation: str = "Tanh",
115
+ hidden_activation_kwargs: Dict = None,
116
+ complex_mask: bool = True,
117
+ norm_mlp_cls: Type[nn.Module] = NormMLP,
118
+ norm_mlp_kwargs: Dict = None,
119
+ ) -> None:
120
+ super().__init__()
121
+
122
+ self.band_widths = band_widths_from_specs(band_specs)
123
+ self.n_bands = len(band_specs)
124
+
125
+ if hidden_activation_kwargs is None:
126
+ hidden_activation_kwargs = {}
127
+
128
+ if norm_mlp_kwargs is None:
129
+ norm_mlp_kwargs = {}
130
+
131
+ self.norm_mlp = nn.ModuleList(
132
+ [
133
+ norm_mlp_cls(
134
+ bandwidth=self.band_widths[b],
135
+ emb_dim=emb_dim,
136
+ mlp_dim=mlp_dim,
137
+ in_channels=in_channels,
138
+ hidden_activation=hidden_activation,
139
+ hidden_activation_kwargs=hidden_activation_kwargs,
140
+ complex_mask=complex_mask,
141
+ **norm_mlp_kwargs,
142
+ )
143
+ for b in range(self.n_bands)
144
+ ]
145
+ )
146
+
147
+ def compute_masks(self, q):
148
+ batch, n_bands, n_time, emb_dim = q.shape
149
+
150
+ masks = []
151
+
152
+ for b, nmlp in enumerate(self.norm_mlp):
153
+ qb = q[:, b, :, :]
154
+ mb = nmlp(qb)
155
+ masks.append(mb)
156
+
157
+ return masks
158
+
159
+ def compute_mask(self, q, b):
160
+ batch, n_bands, n_time, emb_dim = q.shape
161
+ qb = q[:, b, :, :]
162
+ mb = self.norm_mlp[b](qb)
163
+ return mb
164
+
165
+
166
+ class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
167
+ def __init__(
168
+ self,
169
+ in_channels: int,
170
+ band_specs: List[Tuple[float, float]],
171
+ freq_weights: List[torch.Tensor],
172
+ n_freq: int,
173
+ emb_dim: int,
174
+ mlp_dim: int,
175
+ cond_dim: int = 0,
176
+ hidden_activation: str = "Tanh",
177
+ hidden_activation_kwargs: Dict = None,
178
+ complex_mask: bool = True,
179
+ norm_mlp_cls: Type[nn.Module] = NormMLP,
180
+ norm_mlp_kwargs: Dict = None,
181
+ use_freq_weights: bool = False,
182
+ ) -> None:
183
+ check_nonzero_bandwidth(band_specs)
184
+ check_no_gap(band_specs)
185
+
186
+ if cond_dim > 0:
187
+ raise NotImplementedError
188
+
189
+ super().__init__(
190
+ band_specs=band_specs,
191
+ emb_dim=emb_dim + cond_dim,
192
+ mlp_dim=mlp_dim,
193
+ in_channels=in_channels,
194
+ hidden_activation=hidden_activation,
195
+ hidden_activation_kwargs=hidden_activation_kwargs,
196
+ complex_mask=complex_mask,
197
+ norm_mlp_cls=norm_mlp_cls,
198
+ norm_mlp_kwargs=norm_mlp_kwargs,
199
+ )
200
+
201
+ self.n_freq = n_freq
202
+ self.band_specs = band_specs
203
+ self.in_channels = in_channels
204
+
205
+ if freq_weights is not None and use_freq_weights:
206
+ for i, fw in enumerate(freq_weights):
207
+ self.register_buffer(f"freq_weights/{i}", fw)
208
+
209
+ self.use_freq_weights = use_freq_weights
210
+ else:
211
+ self.use_freq_weights = False
212
+
213
+ def forward(self, q):
214
+
215
+ batch, n_bands, n_time, emb_dim = q.shape
216
+
217
+ masks = torch.zeros(
218
+ (batch, self.in_channels, self.n_freq, n_time),
219
+ device=q.device,
220
+ dtype=torch.complex64,
221
+ )
222
+
223
+ for im in range(n_bands):
224
+ fstart, fend = self.band_specs[im]
225
+
226
+ mask = self.compute_mask(q, im)
227
+
228
+ if self.use_freq_weights:
229
+ fw = self.get_buffer(f"freq_weights/{im}")[:, None]
230
+ mask = mask * fw
231
+ masks[:, :, fstart:fend, :] += mask
232
+
233
+ return masks
234
+
235
+
236
+ class MaskEstimationModule(OverlappingMaskEstimationModule):
237
+ def __init__(
238
+ self,
239
+ band_specs: List[Tuple[float, float]],
240
+ emb_dim: int,
241
+ mlp_dim: int,
242
+ in_channels: Optional[int],
243
+ hidden_activation: str = "Tanh",
244
+ hidden_activation_kwargs: Dict = None,
245
+ complex_mask: bool = True,
246
+ **kwargs,
247
+ ) -> None:
248
+ check_nonzero_bandwidth(band_specs)
249
+ check_no_gap(band_specs)
250
+ check_no_overlap(band_specs)
251
+ super().__init__(
252
+ in_channels=in_channels,
253
+ band_specs=band_specs,
254
+ freq_weights=None,
255
+ n_freq=None,
256
+ emb_dim=emb_dim,
257
+ mlp_dim=mlp_dim,
258
+ hidden_activation=hidden_activation,
259
+ hidden_activation_kwargs=hidden_activation_kwargs,
260
+ complex_mask=complex_mask,
261
+ )
262
+
263
+ def forward(self, q, cond=None):
264
+
265
+ masks = self.compute_masks(q)
266
+
267
+ masks = torch.concat(masks, dim=2)
268
+
269
+ return masks
 
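Per-band shape contract of NormMLP, sketched with illustrative sizes (one band of bandwidth 48, stereo input, complex mask):

import torch
from mvsepless.models.bandit_v2.maskestim import NormMLP

nm = NormMLP(emb_dim=128, mlp_dim=512, bandwidth=48, in_channels=2)
qb = torch.randn(4, 100, 128)   # (batch, n_time, emb_dim) for a single band
mb = nm(qb)
print(mb.shape, mb.dtype)       # torch.Size([4, 2, 48, 100]) torch.complex64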
mvsepless/models/bandit_v2/tfmodel.py CHANGED
@@ -1,145 +1,141 @@
1
- import warnings
2
-
3
- import torch
4
- import torch.backends.cuda
5
- from torch import nn
6
- from torch.nn.modules import rnn
7
- from torch.utils.checkpoint import checkpoint_sequential
8
-
9
-
10
- class TimeFrequencyModellingModule(nn.Module):
11
- def __init__(self) -> None:
12
- super().__init__()
13
-
14
-
15
- class ResidualRNN(nn.Module):
16
- def __init__(
17
- self,
18
- emb_dim: int,
19
- rnn_dim: int,
20
- bidirectional: bool = True,
21
- rnn_type: str = "LSTM",
22
- use_batch_trick: bool = True,
23
- use_layer_norm: bool = True,
24
- ) -> None:
25
- # n_group is the size of the 2nd dim
26
- super().__init__()
27
-
28
- assert use_layer_norm
29
- assert use_batch_trick
30
-
31
- self.use_layer_norm = use_layer_norm
32
- self.norm = nn.LayerNorm(emb_dim)
33
- self.rnn = rnn.__dict__[rnn_type](
34
- input_size=emb_dim,
35
- hidden_size=rnn_dim,
36
- num_layers=1,
37
- batch_first=True,
38
- bidirectional=bidirectional,
39
- )
40
-
41
- self.fc = nn.Linear(
42
- in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim
43
- )
44
-
45
- self.use_batch_trick = use_batch_trick
46
- if not self.use_batch_trick:
47
- warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")
48
-
49
- def forward(self, z):
50
- # z = (batch, n_uncrossed, n_across, emb_dim)
51
-
52
- z0 = torch.clone(z)
53
- z = self.norm(z)
54
-
55
- batch, n_uncrossed, n_across, emb_dim = z.shape
56
- z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
57
- z = self.rnn(z)[0]
58
- z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
59
-
60
- z = self.fc(z) # (batch, n_uncrossed, n_across, emb_dim)
61
-
62
- z = z + z0
63
-
64
- return z
65
-
66
-
67
- class Transpose(nn.Module):
68
- def __init__(self, dim0: int, dim1: int) -> None:
69
- super().__init__()
70
- self.dim0 = dim0
71
- self.dim1 = dim1
72
-
73
- def forward(self, z):
74
- return z.transpose(self.dim0, self.dim1)
75
-
76
-
77
- class SeqBandModellingModule(TimeFrequencyModellingModule):
78
- def __init__(
79
- self,
80
- n_modules: int = 12,
81
- emb_dim: int = 128,
82
- rnn_dim: int = 256,
83
- bidirectional: bool = True,
84
- rnn_type: str = "LSTM",
85
- parallel_mode=False,
86
- ) -> None:
87
- super().__init__()
88
-
89
- self.n_modules = n_modules
90
-
91
- if parallel_mode:
92
- self.seqband = nn.ModuleList([])
93
- for _ in range(n_modules):
94
- self.seqband.append(
95
- nn.ModuleList(
96
- [
97
- ResidualRNN(
98
- emb_dim=emb_dim,
99
- rnn_dim=rnn_dim,
100
- bidirectional=bidirectional,
101
- rnn_type=rnn_type,
102
- ),
103
- ResidualRNN(
104
- emb_dim=emb_dim,
105
- rnn_dim=rnn_dim,
106
- bidirectional=bidirectional,
107
- rnn_type=rnn_type,
108
- ),
109
- ]
110
- )
111
- )
112
- else:
113
- seqband = []
114
- for _ in range(2 * n_modules):
115
- seqband += [
116
- ResidualRNN(
117
- emb_dim=emb_dim,
118
- rnn_dim=rnn_dim,
119
- bidirectional=bidirectional,
120
- rnn_type=rnn_type,
121
- ),
122
- Transpose(1, 2),
123
- ]
124
-
125
- self.seqband = nn.Sequential(*seqband)
126
-
127
- self.parallel_mode = parallel_mode
128
-
129
- def forward(self, z):
130
- # z = (batch, n_bands, n_time, emb_dim)
131
-
132
- if self.parallel_mode:
133
- for sbm_pair in self.seqband:
134
- # z: (batch, n_bands, n_time, emb_dim)
135
- sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
136
- zt = sbm_t(z) # (batch, n_bands, n_time, emb_dim)
137
- zf = sbm_f(z.transpose(1, 2)) # (batch, n_time, n_bands, emb_dim)
138
- z = zt + zf.transpose(1, 2)
139
- else:
140
- z = checkpoint_sequential(
141
- self.seqband, self.n_modules, z, use_reentrant=False
142
- )
143
-
144
- q = z
145
- return q # (batch, n_bands, n_time, emb_dim)
 
1
+ import warnings
2
+
3
+ import torch
4
+ import torch.backends.cuda
5
+ from torch import nn
6
+ from torch.nn.modules import rnn
7
+ from torch.utils.checkpoint import checkpoint_sequential
8
+
9
+
10
+ class TimeFrequencyModellingModule(nn.Module):
11
+ def __init__(self) -> None:
12
+ super().__init__()
13
+
14
+
15
+ class ResidualRNN(nn.Module):
16
+ def __init__(
17
+ self,
18
+ emb_dim: int,
19
+ rnn_dim: int,
20
+ bidirectional: bool = True,
21
+ rnn_type: str = "LSTM",
22
+ use_batch_trick: bool = True,
23
+ use_layer_norm: bool = True,
24
+ ) -> None:
25
+ super().__init__()
26
+
27
+ assert use_layer_norm
28
+ assert use_batch_trick
29
+
30
+ self.use_layer_norm = use_layer_norm
31
+ self.norm = nn.LayerNorm(emb_dim)
32
+ self.rnn = rnn.__dict__[rnn_type](
33
+ input_size=emb_dim,
34
+ hidden_size=rnn_dim,
35
+ num_layers=1,
36
+ batch_first=True,
37
+ bidirectional=bidirectional,
38
+ )
39
+
40
+ self.fc = nn.Linear(
41
+ in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim
42
+ )
43
+
44
+ self.use_batch_trick = use_batch_trick
45
+ if not self.use_batch_trick:
46
+ warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")
47
+
48
+ def forward(self, z):
49
+
50
+ z0 = torch.clone(z)
51
+ z = self.norm(z)
52
+
53
+ batch, n_uncrossed, n_across, emb_dim = z.shape
54
+ z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
55
+ z = self.rnn(z)[0]
56
+ z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
57
+
58
+ z = self.fc(z)
59
+
60
+ z = z + z0
61
+
62
+ return z
63
+
64
+
65
+ class Transpose(nn.Module):
66
+ def __init__(self, dim0: int, dim1: int) -> None:
67
+ super().__init__()
68
+ self.dim0 = dim0
69
+ self.dim1 = dim1
70
+
71
+ def forward(self, z):
72
+ return z.transpose(self.dim0, self.dim1)
73
+
74
+
75
+ class SeqBandModellingModule(TimeFrequencyModellingModule):
76
+ def __init__(
77
+ self,
78
+ n_modules: int = 12,
79
+ emb_dim: int = 128,
80
+ rnn_dim: int = 256,
81
+ bidirectional: bool = True,
82
+ rnn_type: str = "LSTM",
83
+ parallel_mode=False,
84
+ ) -> None:
85
+ super().__init__()
86
+
87
+ self.n_modules = n_modules
88
+
89
+ if parallel_mode:
90
+ self.seqband = nn.ModuleList([])
91
+ for _ in range(n_modules):
92
+ self.seqband.append(
93
+ nn.ModuleList(
94
+ [
95
+ ResidualRNN(
96
+ emb_dim=emb_dim,
97
+ rnn_dim=rnn_dim,
98
+ bidirectional=bidirectional,
99
+ rnn_type=rnn_type,
100
+ ),
101
+ ResidualRNN(
102
+ emb_dim=emb_dim,
103
+ rnn_dim=rnn_dim,
104
+ bidirectional=bidirectional,
105
+ rnn_type=rnn_type,
106
+ ),
107
+ ]
108
+ )
109
+ )
110
+ else:
111
+ seqband = []
112
+ for _ in range(2 * n_modules):
113
+ seqband += [
114
+ ResidualRNN(
115
+ emb_dim=emb_dim,
116
+ rnn_dim=rnn_dim,
117
+ bidirectional=bidirectional,
118
+ rnn_type=rnn_type,
119
+ ),
120
+ Transpose(1, 2),
121
+ ]
122
+
123
+ self.seqband = nn.Sequential(*seqband)
124
+
125
+ self.parallel_mode = parallel_mode
126
+
127
+ def forward(self, z):
128
+
129
+ if self.parallel_mode:
130
+ for sbm_pair in self.seqband:
131
+ sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
132
+ zt = sbm_t(z)
133
+ zf = sbm_f(z.transpose(1, 2))
134
+ z = zt + zf.transpose(1, 2)
135
+ else:
136
+ z = checkpoint_sequential(
137
+ self.seqband, self.n_modules, z, use_reentrant=False
138
+ )
139
+
140
+ q = z
141
+ return q
 
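SeqBandModellingModule alternates RNNs across time and across bands (via Transpose) while preserving the tensor shape; a small sketch with illustrative sizes:

import torch
from mvsepless.models.bandit_v2.tfmodel import SeqBandModellingModule

tf = SeqBandModellingModule(n_modules=2, emb_dim=16, rnn_dim=32)
z = torch.randn(1, 8, 20, 16)   # (batch, n_bands, n_time, emb_dim)
q = tf(z)
print(q.shape)                  # torch.Size([1, 8, 20, 16]), shape unchanged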
 
 
 
mvsepless/models/bandit_v2/utils.py CHANGED
@@ -1,523 +1,384 @@
1
- import os
2
- from abc import abstractmethod
3
- from typing import Callable
4
-
5
- import numpy as np
6
- import torch
7
- from librosa import hz_to_midi, midi_to_hz
8
- from torchaudio import functional as taF
9
-
10
- # from spafe.fbanks import bark_fbanks
11
- # from spafe.utils.converters import erb2hz, hz2bark, hz2erb
12
-
13
-
14
- def band_widths_from_specs(band_specs):
15
- return [e - i for i, e in band_specs]
16
-
17
-
18
- def check_nonzero_bandwidth(band_specs):
19
- # pprint(band_specs)
20
- for fstart, fend in band_specs:
21
- if fend - fstart <= 0:
22
- raise ValueError("Bands cannot be zero-width")
23
-
24
-
25
- def check_no_overlap(band_specs):
26
- fend_prev = -1
27
- for fstart_curr, fend_curr in band_specs:
28
- if fstart_curr <= fend_prev:
29
- raise ValueError("Bands cannot overlap")
30
-
31
-
32
- def check_no_gap(band_specs):
33
- fstart, _ = band_specs[0]
34
- assert fstart == 0
35
-
36
- fend_prev = -1
37
- for fstart_curr, fend_curr in band_specs:
38
- if fstart_curr - fend_prev > 1:
39
- raise ValueError("Bands cannot leave gap")
40
- fend_prev = fend_curr
41
-
42
-
43
- class BandsplitSpecification:
44
- def __init__(self, nfft: int, fs: int) -> None:
45
- self.fs = fs
46
- self.nfft = nfft
47
- self.nyquist = fs / 2
48
- self.max_index = nfft // 2 + 1
49
-
50
- self.split500 = self.hertz_to_index(500)
51
- self.split1k = self.hertz_to_index(1000)
52
- self.split2k = self.hertz_to_index(2000)
53
- self.split4k = self.hertz_to_index(4000)
54
- self.split8k = self.hertz_to_index(8000)
55
- self.split16k = self.hertz_to_index(16000)
56
- self.split20k = self.hertz_to_index(20000)
57
-
58
- self.above20k = [(self.split20k, self.max_index)]
59
- self.above16k = [(self.split16k, self.split20k)] + self.above20k
60
-
61
- def index_to_hertz(self, index: int):
62
- return index * self.fs / self.nfft
63
-
64
- def hertz_to_index(self, hz: float, round: bool = True):
65
- index = hz * self.nfft / self.fs
66
-
67
- if round:
68
- index = int(np.round(index))
69
-
70
- return index
71
-
72
- def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
73
- band_specs = []
74
- lower = start_index
75
-
76
- while lower < end_index:
77
- upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
78
- upper = min(upper, end_index)
79
-
80
- band_specs.append((lower, upper))
81
- lower = upper
82
-
83
- return band_specs
84
-
85
- @abstractmethod
86
- def get_band_specs(self):
87
- raise NotImplementedError
88
-
89
-
90
- class VocalBandsplitSpecification(BandsplitSpecification):
91
- def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
92
- super().__init__(nfft=nfft, fs=fs)
93
-
94
- self.version = version
95
-
96
- def get_band_specs(self):
97
- return getattr(self, f"version{self.version}")()
98
-
99
- @property
100
- def version1(self):
101
- return self.get_band_specs_with_bandwidth(
102
- start_index=0, end_index=self.max_index, bandwidth_hz=1000
103
- )
104
-
105
- def version2(self):
106
- below16k = self.get_band_specs_with_bandwidth(
107
- start_index=0, end_index=self.split16k, bandwidth_hz=1000
108
- )
109
- below20k = self.get_band_specs_with_bandwidth(
110
- start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
111
- )
112
-
113
- return below16k + below20k + self.above20k
114
-
115
- def version3(self):
116
- below8k = self.get_band_specs_with_bandwidth(
117
- start_index=0, end_index=self.split8k, bandwidth_hz=1000
118
- )
119
-         below16k = self.get_band_specs_with_bandwidth(
-             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
-         )
-
-         return below8k + below16k + self.above16k
-
-     def version4(self):
-         below1k = self.get_band_specs_with_bandwidth(
-             start_index=0, end_index=self.split1k, bandwidth_hz=100
-         )
-         below8k = self.get_band_specs_with_bandwidth(
-             start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
-         )
-         below16k = self.get_band_specs_with_bandwidth(
-             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
-         )
-
-         return below1k + below8k + below16k + self.above16k
-
-     def version5(self):
-         below1k = self.get_band_specs_with_bandwidth(
-             start_index=0, end_index=self.split1k, bandwidth_hz=100
-         )
-         below16k = self.get_band_specs_with_bandwidth(
-             start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
-         )
-         below20k = self.get_band_specs_with_bandwidth(
-             start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
-         )
-         return below1k + below16k + below20k + self.above20k
-
-     def version6(self):
-         below1k = self.get_band_specs_with_bandwidth(
-             start_index=0, end_index=self.split1k, bandwidth_hz=100
-         )
-         below4k = self.get_band_specs_with_bandwidth(
-             start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
-         )
-         below8k = self.get_band_specs_with_bandwidth(
-             start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
-         )
-         below16k = self.get_band_specs_with_bandwidth(
-             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
-         )
-         return below1k + below4k + below8k + below16k + self.above16k
-
-     def version7(self):
-         below1k = self.get_band_specs_with_bandwidth(
-             start_index=0, end_index=self.split1k, bandwidth_hz=100
-         )
-         below4k = self.get_band_specs_with_bandwidth(
-             start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
-         )
-         below8k = self.get_band_specs_with_bandwidth(
-             start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
-         )
-         below16k = self.get_band_specs_with_bandwidth(
-             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
-         )
-         below20k = self.get_band_specs_with_bandwidth(
-             start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
-         )
-         return below1k + below4k + below8k + below16k + below20k + self.above20k
-
-
- class OtherBandsplitSpecification(VocalBandsplitSpecification):
-     def __init__(self, nfft: int, fs: int) -> None:
-         super().__init__(nfft=nfft, fs=fs, version="7")
-
-
- class BassBandsplitSpecification(BandsplitSpecification):
-     def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
-         super().__init__(nfft=nfft, fs=fs)
-
-     def get_band_specs(self):
-         below500 = self.get_band_specs_with_bandwidth(
-             start_index=0, end_index=self.split500, bandwidth_hz=50
-         )
-         below1k = self.get_band_specs_with_bandwidth(
-             start_index=self.split500, end_index=self.split1k, bandwidth_hz=100
-         )
-         below4k = self.get_band_specs_with_bandwidth(
-             start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
-         )
-         below8k = self.get_band_specs_with_bandwidth(
-             start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
-         )
-         below16k = self.get_band_specs_with_bandwidth(
-             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
-         )
-         above16k = [(self.split16k, self.max_index)]
-
-         return below500 + below1k + below4k + below8k + below16k + above16k
-
-
- class DrumBandsplitSpecification(BandsplitSpecification):
-     def __init__(self, nfft: int, fs: int) -> None:
-         super().__init__(nfft=nfft, fs=fs)
-
-     def get_band_specs(self):
-         below1k = self.get_band_specs_with_bandwidth(
-             start_index=0, end_index=self.split1k, bandwidth_hz=50
-         )
-         below2k = self.get_band_specs_with_bandwidth(
-             start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100
-         )
-         below4k = self.get_band_specs_with_bandwidth(
-             start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250
-         )
-         below8k = self.get_band_specs_with_bandwidth(
-             start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
-         )
-         below16k = self.get_band_specs_with_bandwidth(
-             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
-         )
-         above16k = [(self.split16k, self.max_index)]
-
-         return below1k + below2k + below4k + below8k + below16k + above16k
-
-
- class PerceptualBandsplitSpecification(BandsplitSpecification):
-     def __init__(
-         self,
-         nfft: int,
-         fs: int,
-         fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
-         n_bands: int,
-         f_min: float = 0.0,
-         f_max: float = None,
-     ) -> None:
-         super().__init__(nfft=nfft, fs=fs)
-         self.n_bands = n_bands
-         if f_max is None:
-             f_max = fs / 2
-
-         self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)
-
-         weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True)  # (1, n_freqs)
-         normalized_mel_fb = self.filterbank / weight_per_bin  # (n_mels, n_freqs)
-
-         freq_weights = []
-         band_specs = []
-         for i in range(self.n_bands):
-             active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
-             if isinstance(active_bins, int):
-                 active_bins = (active_bins, active_bins)
-             if len(active_bins) == 0:
-                 continue
-             start_index = active_bins[0]
-             end_index = active_bins[-1] + 1
-             band_specs.append((start_index, end_index))
-             freq_weights.append(normalized_mel_fb[i, start_index:end_index])
-
-         self.freq_weights = freq_weights
-         self.band_specs = band_specs
-
-     def get_band_specs(self):
-         return self.band_specs
-
-     def get_freq_weights(self):
-         return self.freq_weights
-
-     def save_to_file(self, dir_path: str) -> None:
-         os.makedirs(dir_path, exist_ok=True)
-
-         import pickle
-
-         with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
-             pickle.dump(
-                 {
-                     "band_specs": self.band_specs,
-                     "freq_weights": self.freq_weights,
-                     "filterbank": self.filterbank,
-                 },
-                 f,
-             )
-
-
- def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
-     fb = taF.melscale_fbanks(
-         n_mels=n_bands,
-         sample_rate=fs,
-         f_min=f_min,
-         f_max=f_max,
-         n_freqs=n_freqs,
-     ).T
-
-     fb[0, 0] = 1.0
-
-     return fb
-
-
- class MelBandsplitSpecification(PerceptualBandsplitSpecification):
-     def __init__(
-         self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
-     ) -> None:
-         super().__init__(
-             fbank_fn=mel_filterbank,
-             nfft=nfft,
-             fs=fs,
-             n_bands=n_bands,
-             f_min=f_min,
-             f_max=f_max,
-         )
-
-
- def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
-     nfft = 2 * (n_freqs - 1)
-     df = fs / nfft
-     # init freqs
-     f_max = f_max or fs / 2
-     f_min = f_min or 0
-     f_min = fs / nfft
-
-     n_octaves = np.log2(f_max / f_min)
-     n_octaves_per_band = n_octaves / n_bands
-     bandwidth_mult = np.power(2.0, n_octaves_per_band)
-
-     low_midi = max(0, hz_to_midi(f_min))
-     high_midi = hz_to_midi(f_max)
-     midi_points = np.linspace(low_midi, high_midi, n_bands)
-     hz_pts = midi_to_hz(midi_points)
-
-     low_pts = hz_pts / bandwidth_mult
-     high_pts = hz_pts * bandwidth_mult
-
-     low_bins = np.floor(low_pts / df).astype(int)
-     high_bins = np.ceil(high_pts / df).astype(int)
-
-     fb = np.zeros((n_bands, n_freqs))
-
-     for i in range(n_bands):
-         fb[i, low_bins[i] : high_bins[i] + 1] = 1.0
-
-     fb[0, : low_bins[0]] = 1.0
-     fb[-1, high_bins[-1] + 1 :] = 1.0
-
-     return torch.as_tensor(fb)
-
-
- class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
-     def __init__(
-         self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
-     ) -> None:
-         super().__init__(
-             fbank_fn=musical_filterbank,
-             nfft=nfft,
-             fs=fs,
-             n_bands=n_bands,
-             f_min=f_min,
-             f_max=f_max,
-         )
-
-
- # def bark_filterbank(
- #     n_bands, fs, f_min, f_max, n_freqs
- # ):
- #     nfft = 2 * (n_freqs -1)
- #     fb, _ = bark_fbanks.bark_filter_banks(
- #         nfilts=n_bands,
- #         nfft=nfft,
- #         fs=fs,
- #         low_freq=f_min,
- #         high_freq=f_max,
- #         scale="constant"
- #     )
-
- #     return torch.as_tensor(fb)
-
- # class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
- #     def __init__(
- #         self,
- #         nfft: int,
- #         fs: int,
- #         n_bands: int,
- #         f_min: float = 0.0,
- #         f_max: float = None
- #     ) -> None:
- #         super().__init__(fbank_fn=bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
-
-
- # def triangular_bark_filterbank(
- #     n_bands, fs, f_min, f_max, n_freqs
- # ):
-
- #     all_freqs = torch.linspace(0, fs // 2, n_freqs)
-
- #     # calculate mel freq bins
- #     m_min = hz2bark(f_min)
- #     m_max = hz2bark(f_max)
-
- #     m_pts = torch.linspace(m_min, m_max, n_bands + 2)
- #     f_pts = 600 * torch.sinh(m_pts / 6)
-
- #     # create filterbank
- #     fb = _create_triangular_filterbank(all_freqs, f_pts)
-
- #     fb = fb.T
-
- #     first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
- #     first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
-
- #     fb[first_active_band, :first_active_bin] = 1.0
-
- #     return fb
-
- # class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
- #     def __init__(
- #         self,
- #         nfft: int,
- #         fs: int,
- #         n_bands: int,
- #         f_min: float = 0.0,
- #         f_max: float = None
- #     ) -> None:
- #         super().__init__(fbank_fn=triangular_bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
-
-
- # def minibark_filterbank(
- #     n_bands, fs, f_min, f_max, n_freqs
- # ):
- #     fb = bark_filterbank(
- #         n_bands,
- #         fs,
- #         f_min,
- #         f_max,
- #         n_freqs
- #     )
-
- #     fb[fb < np.sqrt(0.5)] = 0.0
-
- #     return fb
-
- # class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
- #     def __init__(
- #         self,
- #         nfft: int,
- #         fs: int,
- #         n_bands: int,
- #         f_min: float = 0.0,
- #         f_max: float = None
- #     ) -> None:
- #         super().__init__(fbank_fn=minibark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
-
-
- # def erb_filterbank(
- #     n_bands: int,
- #     fs: int,
- #     f_min: float,
- #     f_max: float,
- #     n_freqs: int,
- # ) -> Tensor:
- #     # freq bins
- #     A = (1000 * np.log(10)) / (24.7 * 4.37)
- #     all_freqs = torch.linspace(0, fs // 2, n_freqs)
-
- #     # calculate mel freq bins
- #     m_min = hz2erb(f_min)
- #     m_max = hz2erb(f_max)
-
- #     m_pts = torch.linspace(m_min, m_max, n_bands + 2)
- #     f_pts = (torch.pow(10, (m_pts / A)) - 1)/ 0.00437
-
- #     # create filterbank
- #     fb = _create_triangular_filterbank(all_freqs, f_pts)
-
- #     fb = fb.T
-
-
- #     first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
- #     first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
-
- #     fb[first_active_band, :first_active_bin] = 1.0
-
- #     return fb
-
-
- # class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
- #     def __init__(
- #         self,
- #         nfft: int,
- #         fs: int,
- #         n_bands: int,
- #         f_min: float = 0.0,
- #         f_max: float = None
- #     ) -> None:
- #         super().__init__(fbank_fn=erb_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
-
- if __name__ == "__main__":
-     import pandas as pd
-
-     band_defs = []
-
-     for bands in [VocalBandsplitSpecification]:
-         band_name = bands.__name__.replace("BandsplitSpecification", "")
-
-         mbs = bands(nfft=2048, fs=44100).get_band_specs()
-
-         for i, (f_min, f_max) in enumerate(mbs):
-             band_defs.append(
-                 {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
-             )
-
-     df = pd.DataFrame(band_defs)
-     df.to_csv("vox7bands.csv", index=False)
 
+ import os
+ from abc import abstractmethod
+ from typing import Callable
+
+ import numpy as np
+ import torch
+ from librosa import hz_to_midi, midi_to_hz
+ from torchaudio import functional as taF
+
+
+ def band_widths_from_specs(band_specs):
+     return [e - i for i, e in band_specs]
+
+
+ def check_nonzero_bandwidth(band_specs):
+     for fstart, fend in band_specs:
+         if fend - fstart <= 0:
+             raise ValueError("Bands cannot be zero-width")
+
+
+ def check_no_overlap(band_specs):
+     fend_prev = -1
+     for fstart_curr, fend_curr in band_specs:
+         # strict "<": adjacent half-open bands may share an endpoint
+         if fstart_curr < fend_prev:
+             raise ValueError("Bands cannot overlap")
+         fend_prev = fend_curr  # advance; without this the check could never fire
+
+
+ def check_no_gap(band_specs):
+     fstart, _ = band_specs[0]
+     assert fstart == 0
+
+     fend_prev = -1
+     for fstart_curr, fend_curr in band_specs:
+         if fstart_curr - fend_prev > 1:
+             raise ValueError("Bands cannot leave gap")
+         fend_prev = fend_curr
+
+
+ class BandsplitSpecification:
+     def __init__(self, nfft: int, fs: int) -> None:
+         self.fs = fs
+         self.nfft = nfft
+         self.nyquist = fs / 2
+         self.max_index = nfft // 2 + 1
+
+         self.split500 = self.hertz_to_index(500)
+         self.split1k = self.hertz_to_index(1000)
+         self.split2k = self.hertz_to_index(2000)
+         self.split4k = self.hertz_to_index(4000)
+         self.split8k = self.hertz_to_index(8000)
+         self.split16k = self.hertz_to_index(16000)
+         self.split20k = self.hertz_to_index(20000)
+
+         self.above20k = [(self.split20k, self.max_index)]
+         self.above16k = [(self.split16k, self.split20k)] + self.above20k
+
+     def index_to_hertz(self, index: int):
+         return index * self.fs / self.nfft
+
+     def hertz_to_index(self, hz: float, round: bool = True):
+         index = hz * self.nfft / self.fs
+
+         if round:
+             index = int(np.round(index))
+
+         return index
+
+     def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
+         band_specs = []
+         lower = start_index
+
+         while lower < end_index:
+             upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
+             upper = min(upper, end_index)
+
+             band_specs.append((lower, upper))
+             lower = upper
+
+         return band_specs
+
+     @abstractmethod
+     def get_band_specs(self):
+         raise NotImplementedError
+
+
+ class VocalBandsplitSpecification(BandsplitSpecification):
+     def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
+         super().__init__(nfft=nfft, fs=fs)
+
+         self.version = version
+
+     def get_band_specs(self):
+         return getattr(self, f"version{self.version}")()
+
+     def version1(self):
+         # plain method (no @property) so that get_band_specs() can invoke it
+         # via getattr(...)() exactly like the other version methods
+         return self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.max_index, bandwidth_hz=1000
+         )
+
+     def version2(self):
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split16k, bandwidth_hz=1000
+         )
+         below20k = self.get_band_specs_with_bandwidth(
+             start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+         )
+
+         return below16k + below20k + self.above20k
+
+     def version3(self):
+         below8k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split8k, bandwidth_hz=1000
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+         )
+
+         return below8k + below16k + self.above16k
+
+     def version4(self):
+         below1k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split1k, bandwidth_hz=100
+         )
+         below8k = self.get_band_specs_with_bandwidth(
+             start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+         )
+
+         return below1k + below8k + below16k + self.above16k
+
+     def version5(self):
+         below1k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split1k, bandwidth_hz=100
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
+         )
+         below20k = self.get_band_specs_with_bandwidth(
+             start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+         )
+         return below1k + below16k + below20k + self.above20k
+
+     def version6(self):
+         below1k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split1k, bandwidth_hz=100
+         )
+         below4k = self.get_band_specs_with_bandwidth(
+             start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
+         )
+         below8k = self.get_band_specs_with_bandwidth(
+             start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+         )
+         return below1k + below4k + below8k + below16k + self.above16k
+
+     def version7(self):
+         below1k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split1k, bandwidth_hz=100
+         )
+         below4k = self.get_band_specs_with_bandwidth(
+             start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
+         )
+         below8k = self.get_band_specs_with_bandwidth(
+             start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
+         )
+         below20k = self.get_band_specs_with_bandwidth(
+             start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+         )
+         return below1k + below4k + below8k + below16k + below20k + self.above20k
+
+
+ class OtherBandsplitSpecification(VocalBandsplitSpecification):
+     def __init__(self, nfft: int, fs: int) -> None:
+         super().__init__(nfft=nfft, fs=fs, version="7")
+
+
+ class BassBandsplitSpecification(BandsplitSpecification):
+     def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
+         super().__init__(nfft=nfft, fs=fs)
+
+     def get_band_specs(self):
+         below500 = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split500, bandwidth_hz=50
+         )
+         below1k = self.get_band_specs_with_bandwidth(
+             start_index=self.split500, end_index=self.split1k, bandwidth_hz=100
+         )
+         below4k = self.get_band_specs_with_bandwidth(
+             start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
+         )
+         below8k = self.get_band_specs_with_bandwidth(
+             start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+         )
+         above16k = [(self.split16k, self.max_index)]
+
+         return below500 + below1k + below4k + below8k + below16k + above16k
+
+
+ class DrumBandsplitSpecification(BandsplitSpecification):
+     def __init__(self, nfft: int, fs: int) -> None:
+         super().__init__(nfft=nfft, fs=fs)
+
+     def get_band_specs(self):
+         below1k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split1k, bandwidth_hz=50
+         )
+         below2k = self.get_band_specs_with_bandwidth(
+             start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100
+         )
+         below4k = self.get_band_specs_with_bandwidth(
+             start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250
+         )
+         below8k = self.get_band_specs_with_bandwidth(
+             start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
+         )
+         above16k = [(self.split16k, self.max_index)]
+
+         return below1k + below2k + below4k + below8k + below16k + above16k
+
+
+ class PerceptualBandsplitSpecification(BandsplitSpecification):
+     def __init__(
+         self,
+         nfft: int,
+         fs: int,
+         fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
+         n_bands: int,
+         f_min: float = 0.0,
+         f_max: float = None,
+     ) -> None:
+         super().__init__(nfft=nfft, fs=fs)
+         self.n_bands = n_bands
+         if f_max is None:
+             f_max = fs / 2
+
+         self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)
+
+         weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True)
+         normalized_mel_fb = self.filterbank / weight_per_bin
+
+         freq_weights = []
+         band_specs = []
+         for i in range(self.n_bands):
+             active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
+             if isinstance(active_bins, int):
+                 active_bins = (active_bins, active_bins)
+             if len(active_bins) == 0:
+                 continue
+             start_index = active_bins[0]
+             end_index = active_bins[-1] + 1
+             band_specs.append((start_index, end_index))
+             freq_weights.append(normalized_mel_fb[i, start_index:end_index])
+
+         self.freq_weights = freq_weights
+         self.band_specs = band_specs
+
+     def get_band_specs(self):
+         return self.band_specs
+
+     def get_freq_weights(self):
+         return self.freq_weights
+
+     def save_to_file(self, dir_path: str) -> None:
+         os.makedirs(dir_path, exist_ok=True)
+
+         import pickle
+
+         with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
+             pickle.dump(
+                 {
+                     "band_specs": self.band_specs,
+                     "freq_weights": self.freq_weights,
+                     "filterbank": self.filterbank,
+                 },
+                 f,
+             )
+
+
+ def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
+     fb = taF.melscale_fbanks(
+         n_mels=n_bands,
+         sample_rate=fs,
+         f_min=f_min,
+         f_max=f_max,
+         n_freqs=n_freqs,
+     ).T
+
+     fb[0, 0] = 1.0
+
+     return fb
+
+
+ class MelBandsplitSpecification(PerceptualBandsplitSpecification):
+     def __init__(
+         self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+     ) -> None:
+         super().__init__(
+             fbank_fn=mel_filterbank,
+             nfft=nfft,
+             fs=fs,
+             n_bands=n_bands,
+             f_min=f_min,
+             f_max=f_max,
+         )
+
+
+ def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
+     nfft = 2 * (n_freqs - 1)
+     df = fs / nfft
+     f_max = f_max or fs / 2
+     f_min = f_min or 0
+     f_min = fs / nfft
+
+     n_octaves = np.log2(f_max / f_min)
+     n_octaves_per_band = n_octaves / n_bands
+     bandwidth_mult = np.power(2.0, n_octaves_per_band)
+
+     low_midi = max(0, hz_to_midi(f_min))
+     high_midi = hz_to_midi(f_max)
+     midi_points = np.linspace(low_midi, high_midi, n_bands)
+     hz_pts = midi_to_hz(midi_points)
+
+     low_pts = hz_pts / bandwidth_mult
+     high_pts = hz_pts * bandwidth_mult
+
+     low_bins = np.floor(low_pts / df).astype(int)
+     high_bins = np.ceil(high_pts / df).astype(int)
+
+     fb = np.zeros((n_bands, n_freqs))
+
+     for i in range(n_bands):
+         fb[i, low_bins[i] : high_bins[i] + 1] = 1.0
+
+     fb[0, : low_bins[0]] = 1.0
+     fb[-1, high_bins[-1] + 1 :] = 1.0
+
+     return torch.as_tensor(fb)
+
+
+ class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
+     def __init__(
+         self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+     ) -> None:
+         super().__init__(
+             fbank_fn=musical_filterbank,
+             nfft=nfft,
+             fs=fs,
+             n_bands=n_bands,
+             f_min=f_min,
+             f_max=f_max,
+         )
+
+
+ if __name__ == "__main__":
+     import pandas as pd
+
+     band_defs = []
+
+     for bands in [VocalBandsplitSpecification]:
+         band_name = bands.__name__.replace("BandsplitSpecification", "")
+
+         mbs = bands(nfft=2048, fs=44100).get_band_specs()
+
+         for i, (f_min, f_max) in enumerate(mbs):
+             band_defs.append(
+                 {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
+             )
+
+     df = pd.DataFrame(band_defs)
+     df.to_csv("vox7bands.csv", index=False)
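
For reference, a minimal usage sketch of the band-split utilities above (not part of the commit; the import path follows this repository's tree, and nfft=2048 / fs=44100 are arbitrary example values). It builds the default seven-way vocal split and runs the three consistency checks over the half-open (start_bin, end_bin) pairs:

    from mvsepless.models.bandit.core.model.bsrnn.utils import (
        VocalBandsplitSpecification,
        check_nonzero_bandwidth,
        check_no_overlap,
        check_no_gap,
    )

    spec = VocalBandsplitSpecification(nfft=2048, fs=44100, version="7")
    band_specs = spec.get_band_specs()  # e.g. [(0, 5), (5, 10), ...] up to nfft // 2 + 1

    check_nonzero_bandwidth(band_specs)  # every band spans at least one bin
    check_no_overlap(band_specs)  # adjacent bands may share an endpoint, not bins
    check_no_gap(band_specs)  # the bands tile the whole spectrum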
mvsepless/models/bs_roformer/__init__.py CHANGED
@@ -1,6 +1,6 @@
- from .bs_roformer import BSRoformer
- from .bs_roformer_sw import BSRoformer_SW
- from .bs_roformer_fno import BSRoformer_FNO
- from .bs_roformer_hyperace import BSRoformerHyperACE
- from .bs_roformer_hyperace2 import BSRoformerHyperACE_2
- from .mel_band_roformer import MelBandRoformer
 
+ from .bs_roformer import BSRoformer
+ from .bs_roformer_sw import BSRoformer_SW
+ from .bs_roformer_fno import BSRoformer_FNO
+ from .bs_roformer_hyperace import BSRoformerHyperACE
+ from .bs_roformer_hyperace2 import BSRoformerHyperACE_2
+ from .mel_band_roformer import MelBandRoformer
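
Both sides of this hunk are textually identical (the change is whitespace-only); the package root simply re-exports the Roformer variants, so downstream code can import them directly, e.g.:

    from mvsepless.models.bs_roformer import BSRoformer, MelBandRoformer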
mvsepless/models/bs_roformer/attend.py CHANGED
@@ -1,126 +1,120 @@
- from functools import wraps
- from packaging import version
- from collections import namedtuple
-
- import os
- import torch
- from torch import nn, einsum
- import torch.nn.functional as F
-
- from einops import rearrange, reduce
-
- # constants
-
- FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
-
- # helpers
-
- def exists(val):
-     return val is not None
-
- def default(v, d):
-     return v if exists(v) else d
-
- def once(fn):
-     called = False
-     @wraps(fn)
-     def inner(x):
-         nonlocal called
-         if called:
-             return
-         called = True
-         return fn(x)
-     return inner
-
- print_once = once(print)
-
- # main class
-
- class Attend(nn.Module):
-     def __init__(
-         self,
-         dropout = 0.,
-         flash = False,
-         scale = None
-     ):
-         super().__init__()
-         self.scale = scale
-         self.dropout = dropout
-         self.attn_dropout = nn.Dropout(dropout)
-
-         self.flash = flash
-         assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
-
-         # determine efficient attention configs for cuda and cpu
-
-         self.cpu_config = FlashAttentionConfig(True, True, True)
-         self.cuda_config = None
-
-         if not torch.cuda.is_available() or not flash:
-             return
-
-         device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
-         device_version = version.parse(f'{device_properties.major}.{device_properties.minor}')
-
-         if device_version >= version.parse('8.0'):
-             if os.name == 'nt':
-                 print_once('Windows OS detected, using math or mem efficient attention if input tensor is on cuda')
-                 self.cuda_config = FlashAttentionConfig(False, True, True)
-             else:
-                 print_once('GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda')
-                 self.cuda_config = FlashAttentionConfig(True, False, False)
-         else:
-             print_once('GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda')
-             self.cuda_config = FlashAttentionConfig(False, True, True)
-
-     def flash_attn(self, q, k, v):
-         _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
-
-         if exists(self.scale):
-             default_scale = q.shape[-1] ** -0.5
-             q = q * (self.scale / default_scale)
-
-         # Check if there is a compatible device for flash attention
-
-         config = self.cuda_config if is_cuda else self.cpu_config
-
-         # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
-
-         with torch.backends.cuda.sdp_kernel(**config._asdict()):
-             out = F.scaled_dot_product_attention(
-                 q, k, v,
-                 dropout_p = self.dropout if self.training else 0.
-             )
-
-         return out
-
-     def forward(self, q, k, v):
-         """
-         einstein notation
-         b - batch
-         h - heads
-         n, i, j - sequence length (base sequence length, source, target)
-         d - feature dimension
-         """
-
-         q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
-
-         scale = default(self.scale, q.shape[-1] ** -0.5)
-
-         if self.flash:
-             return self.flash_attn(q, k, v)
-
-         # similarity
-
-         sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
-
-         # attention
-
-         attn = sim.softmax(dim=-1)
-         attn = self.attn_dropout(attn)
-
-         # aggregate values
-
-         out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
-
-         return out
 
+ from functools import wraps
+ from packaging import version
+ from collections import namedtuple
+
+ import os
+ import torch
+ from torch import nn, einsum
+ import torch.nn.functional as F
+
+ from einops import rearrange, reduce
+
+
+ FlashAttentionConfig = namedtuple(
+     "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]
+ )
+
+
+ def exists(val):
+     return val is not None
+
+
+ def default(v, d):
+     return v if exists(v) else d
+
+
+ def once(fn):
+     called = False
+
+     @wraps(fn)
+     def inner(x):
+         nonlocal called
+         if called:
+             return
+         called = True
+         return fn(x)
+
+     return inner
+
+
+ print_once = once(print)
+
+
+ class Attend(nn.Module):
+     def __init__(self, dropout=0.0, flash=False, scale=None):
+         super().__init__()
+         self.scale = scale
+         self.dropout = dropout
+         self.attn_dropout = nn.Dropout(dropout)
+
+         self.flash = flash
+         assert not (
+             flash and version.parse(torch.__version__) < version.parse("2.0.0")
+         ), "in order to use flash attention, you must be using pytorch 2.0 or above"
+
+         self.cpu_config = FlashAttentionConfig(True, True, True)
+         self.cuda_config = None
+
+         if not torch.cuda.is_available() or not flash:
+             return
+
+         device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
+         device_version = version.parse(
+             f"{device_properties.major}.{device_properties.minor}"
+         )
+
+         if device_version >= version.parse("8.0"):
+             if os.name == "nt":
+                 print_once(
+                     "Windows OS detected, using math or mem efficient attention if input tensor is on cuda"
+                 )
+                 self.cuda_config = FlashAttentionConfig(False, True, True)
+             else:
+                 print_once(
+                     "GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda"
+                 )
+                 self.cuda_config = FlashAttentionConfig(True, False, False)
+         else:
+             print_once(
+                 "GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda"
+             )
+             self.cuda_config = FlashAttentionConfig(False, True, True)
+
+     def flash_attn(self, q, k, v):
+         _, heads, q_len, _, k_len, is_cuda, device = (
+             *q.shape,
+             k.shape[-2],
+             q.is_cuda,
+             q.device,
+         )
+
+         if exists(self.scale):
+             default_scale = q.shape[-1] ** -0.5
+             q = q * (self.scale / default_scale)
+
+         config = self.cuda_config if is_cuda else self.cpu_config
+
+         with torch.backends.cuda.sdp_kernel(**config._asdict()):
+             out = F.scaled_dot_product_attention(
+                 q, k, v, dropout_p=self.dropout if self.training else 0.0
+             )
+
+         return out
+
+     def forward(self, q, k, v):
+         q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
+
+         scale = default(self.scale, q.shape[-1] ** -0.5)
+
+         if self.flash:
+             return self.flash_attn(q, k, v)
+
+         sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
+
+         attn = sim.softmax(dim=-1)
+         attn = self.attn_dropout(attn)
+
+         out = einsum("b h i j, b h j d -> b h i d", attn, v)
+
+         return out
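
To make the expected tensor layout explicit, a minimal sketch exercising the non-flash path (not part of the commit; shapes follow the (batch, heads, length, head_dim) convention of the einsum above):

    import torch

    attend = Attend(dropout=0.0, flash=False)
    q = torch.randn(2, 8, 128, 64)  # (batch, heads, query_len, head_dim)
    k = torch.randn(2, 8, 128, 64)
    v = torch.randn(2, 8, 128, 64)

    out = attend(q, k, v)  # same shape as q: (2, 8, 128, 64)
    assert out.shape == q.shape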