noblebarkrr committed
Commit 4a26913 · verified · 1 Parent(s): e6fcb9f

Upload 101 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. mvsepless/__init__.py +1579 -0
  2. mvsepless/__main__.py +67 -0
  3. mvsepless/audio.py +781 -0
  4. mvsepless/downloader.py +92 -0
  5. mvsepless/ensemble.py +224 -0
  6. mvsepless/infer.py +623 -0
  7. mvsepless/infer_utils.py +382 -0
  8. mvsepless/model_manager.py +540 -0
  9. mvsepless/models.json +0 -0
  10. mvsepless/models/bandit/core/__init__.py +691 -0
  11. mvsepless/models/bandit/core/data/__init__.py +2 -0
  12. mvsepless/models/bandit/core/data/_types.py +17 -0
  13. mvsepless/models/bandit/core/data/augmentation.py +102 -0
  14. mvsepless/models/bandit/core/data/augmented.py +34 -0
  15. mvsepless/models/bandit/core/data/base.py +60 -0
  16. mvsepless/models/bandit/core/data/dnr/__init__.py +0 -0
  17. mvsepless/models/bandit/core/data/dnr/datamodule.py +68 -0
  18. mvsepless/models/bandit/core/data/dnr/dataset.py +366 -0
  19. mvsepless/models/bandit/core/data/dnr/preprocess.py +51 -0
  20. mvsepless/models/bandit/core/data/musdb/__init__.py +0 -0
  21. mvsepless/models/bandit/core/data/musdb/datamodule.py +75 -0
  22. mvsepless/models/bandit/core/data/musdb/dataset.py +273 -0
  23. mvsepless/models/bandit/core/data/musdb/preprocess.py +226 -0
  24. mvsepless/models/bandit/core/data/musdb/validation.yaml +15 -0
  25. mvsepless/models/bandit/core/loss/__init__.py +8 -0
  26. mvsepless/models/bandit/core/loss/_complex.py +27 -0
  27. mvsepless/models/bandit/core/loss/_multistem.py +43 -0
  28. mvsepless/models/bandit/core/loss/_timefreq.py +95 -0
  29. mvsepless/models/bandit/core/loss/snr.py +139 -0
  30. mvsepless/models/bandit/core/metrics/__init__.py +9 -0
  31. mvsepless/models/bandit/core/metrics/_squim.py +443 -0
  32. mvsepless/models/bandit/core/metrics/snr.py +127 -0
  33. mvsepless/models/bandit/core/model/__init__.py +3 -0
  34. mvsepless/models/bandit/core/model/_spectral.py +54 -0
  35. mvsepless/models/bandit/core/model/bsrnn/__init__.py +23 -0
  36. mvsepless/models/bandit/core/model/bsrnn/bandsplit.py +135 -0
  37. mvsepless/models/bandit/core/model/bsrnn/core.py +651 -0
  38. mvsepless/models/bandit/core/model/bsrnn/maskestim.py +351 -0
  39. mvsepless/models/bandit/core/model/bsrnn/tfmodel.py +320 -0
  40. mvsepless/models/bandit/core/model/bsrnn/utils.py +525 -0
  41. mvsepless/models/bandit/core/model/bsrnn/wrapper.py +829 -0
  42. mvsepless/models/bandit/core/utils/__init__.py +0 -0
  43. mvsepless/models/bandit/core/utils/audio.py +412 -0
  44. mvsepless/models/bandit/model_from_config.py +26 -0
  45. mvsepless/models/bandit_v2/bandit.py +363 -0
  46. mvsepless/models/bandit_v2/bandsplit.py +130 -0
  47. mvsepless/models/bandit_v2/film.py +23 -0
  48. mvsepless/models/bandit_v2/maskestim.py +281 -0
  49. mvsepless/models/bandit_v2/tfmodel.py +145 -0
  50. mvsepless/models/bandit_v2/utils.py +523 -0
mvsepless/__init__.py ADDED
@@ -0,0 +1,1579 @@
+ import os
+ import sys
+ import shutil
+ import logging
+ import zipfile
+ import importlib.util
+ from pathlib import Path
+ logging.basicConfig(level=logging.WARNING)
+
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+ os.chdir(script_dir)
+
+ if not __package__:
+     from model_manager import MvseplessModelManager
+     from audio import Audio, Inverter
+     from namer import Namer
+     from vbach_infer import vbach_inference, model_manager as VbachModel
+     from downloader import dw_file, dw_yt_dlp
+     from ensemble import ensemble_audio_files
+ else:
+     from .model_manager import MvseplessModelManager
+     from .audio import Audio, Inverter
+     from .namer import Namer
+     from .vbach_infer import vbach_inference, model_manager as VbachModel
+     from .downloader import dw_file, dw_yt_dlp
+     from .ensemble import ensemble_audio_files
+ from typing import Literal
+ import gradio as gr
+ import pandas as pd
+ import subprocess
+ import json
+ import threading
+ import queue
+ import time
+ import argparse
+ from datetime import datetime
+ import tempfile
+ import ast
+
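+ # Shared base class: class-level singletons for the audio helpers, the file
+ # namer, and the model managers that every UI/processing class below reuses.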
+ class MVSEPLESS:
+     audio = Audio()
+     inverter = Inverter()
+     namer = Namer()
+     model_manager = MvseplessModelManager()
+     vbach_model_manager = VbachModel
+
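+ # Separator launches `python -m infer` in a subprocess and translates the
+ # JSON status lines it prints (reading/processing/writing/done/error) into
+ # Gradio progress updates.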
+ class Separator(MVSEPLESS):
+
+     class OutputReader:
+         def __init__(self, debug=False):
+             self.debug = debug
+
+         def parse_json_line(self, line):
+             try:
+                 return json.loads(line)
+             except json.JSONDecodeError:
+                 return None
+
+         def reaction_line(self, line, progress, add_text):
+             _add_text = ""
+             if add_text:
+                 _add_text = f"| {add_text}"
+
+             data = self.parse_json_line(line)
+             if data is None:
+                 return None
+             elif "reading" in data:
+                 progress(0.05, desc=f"Чтение файла {_add_text}")
+                 print("Чтение файла")
+                 return None
+             elif "processing" in data:
+                 progress_a = data["processing"]
+                 processed = progress_a.get("processed", 0)
+                 total = progress_a.get("total", 1)
+                 # guard against division by zero
+                 if total > 0:
+                     progress_ratio = min(0.89, 0.05 + (processed / total * 0.85))  # leave headroom for the writing stage
+                     percent = int((processed / total) * 100)
+                     progress(progress_ratio, desc=f"Обработано: {percent}% {_add_text}")
+                     print(f"\rОбработано: {percent}%", end="")
+                 return None
+             elif "writing" in data:
+                 progress(0.9, desc="Запись результатов")
+                 print(f"\rЗапись в файл {data['writing']}", end="")
+                 return None
+             elif "done" in data:
+                 progress(1.0, desc=f"Завершено {_add_text}")
+                 print("\rЗавершено", end="\n")
+                 return data["done"]
+             elif "error" in data:
+                 raise Exception(data["error"])
+
+         def read_stream_to_queue(self, stream, queue_obj, stream_name):
+             """Read a subprocess output stream and push its lines into a queue."""
+             try:
+                 for line in iter(stream.readline, ''):
+                     line = line.strip()
+                     if line:
+                         if self.debug:
+                             print(f"[{stream_name}] {line}")  # debug output
+                         queue_obj.put(line)
+                 stream.close()
+             except Exception as e:
+                 print(f"Error reading {stream_name}: {e}")
+
+     output_reader = OutputReader()
+
+     def separator_model_loader(self, model_type: str, model_name: str, mdx_denoise: bool, vr_aggr: int, progress) -> tuple[int, str, str]:
+
+         if model_type in [
+             "mel_band_roformer",
+             "bs_roformer",
+             "mdx23c",
+             "mdxnet",
+             "vr",
+             "scnet",
+             "htdemucs",
+             "bandit",
+             "bandit_v2",
+         ]:
+             info = self.model_manager.models_info[model_type].get(model_name, None)
+             if not info:
+                 raise ValueError(f"Модель {model_name} не найдена для типа {model_type}")
+
+             id = self.model_manager.get_id(model_type, model_name)
+             conf, ckpt = self.model_manager.download_model(
+                 self.model_manager.models_cache_dir,
+                 model_name,
+                 model_type,
+                 info["checkpoint_url"],
+                 info["config_url"],
+             )
+             if model_type != "htdemucs":
+                 self.model_manager.conf_editor(conf, mdx_denoise, vr_aggr, model_type)
+
+             return id, conf, ckpt
+
+         else:
+             raise ValueError("Неподдерживаемый тип модели")
+
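+     # Run a single separation in a child process; reader threads drain
+     # stdout/stderr into queues so a full pipe buffer can never block the
+     # child, and each line is parsed as a JSON progress message.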
+     def separator_base(
+         self,
+         input_file: str,
+         output_dir: str,
+         model_type: Literal[
+             "mel_band_roformer",
+             "bs_roformer",
+             "mdx23c",
+             "mdxnet",
+             "scnet",
+             "htdemucs",
+             "bandit",
+             "bandit_v2",
+         ] = "mel_band_roformer",
+         model_name: str = "Mel-Band-Roformer_Vocals_kimberley_jensen",
+         ext_inst: bool = True,
+         output_format: Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] = "mp3",
+         output_bitrate: str = "320k",
+         template: str = "NAME_(STEM)_MODEL",
+         selected_stems: list = None,
+         ckpt: str = None,
+         conf: str = None,
+         id: int = None,
+         progress: any = gr.Progress(track_tqdm=True),
+         add_text_progress: str = ""
+     ) -> list[tuple[str, str]]:
+
+         if model_type in [
+             "mel_band_roformer",
+             "bs_roformer",
+             "mdx23c",
+             "mdxnet",
+             "vr",
+             "scnet",
+             "htdemucs",
+             "bandit",
+             "bandit_v2",
+         ]:
+
+             cmd = [
+                 sys.executable,
+                 "-m",
+                 "infer",
+                 "--input", input_file,
+                 "--store_dir", output_dir,
+                 "--model_type", model_type,
+                 "--model_name", model_name,
+                 "--model_id", str(id),
+                 "--config_path", conf,
+                 "--start_check_point", ckpt,
+                 "--output_format", output_format,
+                 "--output_bitrate", str(output_bitrate),
+                 "--template", template
+             ]
+             if ext_inst:
+                 cmd.append("--extract_instrumental")
+             if selected_stems:
+                 cmd.append("--selected_instruments")
+                 cmd.extend(selected_stems)
+
+             try:
+                 process = subprocess.Popen(
+                     cmd,
+                     stdout=subprocess.PIPE,
+                     stderr=subprocess.PIPE,
+                     text=True,
+                     bufsize=1,
+                     universal_newlines=True,
+                     encoding='utf-8',
+                     errors='replace'
+                 )
+
+                 # queues for stdout and stderr
+                 stdout_queue = queue.Queue()
+                 stderr_queue = queue.Queue()
+
+                 # reader threads for stdout and stderr
+                 stdout_thread = threading.Thread(
+                     target=self.output_reader.read_stream_to_queue,
+                     args=(process.stdout, stdout_queue, "stdout")
+                 )
+                 stderr_thread = threading.Thread(
+                     target=self.output_reader.read_stream_to_queue,
+                     args=(process.stderr, stderr_queue, "stderr")
+                 )
+
+                 stdout_thread.daemon = True
+                 stderr_thread.daemon = True
+
+                 stdout_thread.start()
+                 stderr_thread.start()
+
+                 results = {'output': None, 'error': None}
+                 process_completed = False
+
+                 # main message-processing loop
+                 while not process_completed:
+                     # check whether the subprocess has finished
+                     if process.poll() is not None:
+                         process_completed = True
+
+                     # handle messages from stdout
+                     try:
+                         stdout_line = stdout_queue.get_nowait()
+                         result = self.output_reader.reaction_line(stdout_line, progress, add_text_progress)
+                         if result is not None:
+                             results['output'] = result
+                             break
+                     except queue.Empty:
+                         pass
+
+                     # handle messages from stderr
+                     try:
+                         stderr_line = stderr_queue.get_nowait()
+                         result = self.output_reader.reaction_line(stderr_line, progress, add_text_progress)
+                         if result is not None:
+                             results['output'] = result
+                             break
+                     except queue.Empty:
+                         pass
+
+                     # if the process is still running, wait briefly before the next poll
+                     if not process_completed:
+                         time.sleep(0.1)
+
+                 # drain any messages left in the queues after the process exits
+                 for _ in range(10):  # poll a few times in case of late output
+                     try:
+                         stdout_line = stdout_queue.get_nowait()
+                         result = self.output_reader.reaction_line(stdout_line, progress, add_text_progress)
+                         if result is not None:
+                             results['output'] = result
+                             break
+                     except queue.Empty:
+                         pass
+
+                     try:
+                         stderr_line = stderr_queue.get_nowait()
+                         result = self.output_reader.reaction_line(stderr_line, progress, add_text_progress)
+                         if result is not None:
+                             results['output'] = result
+                             break
+                     except queue.Empty:
+                         pass
+
+                     time.sleep(0.1)
+
+                 # check the results
+                 if results.get('error'):
+                     raise Exception(results['error'])
+
+                 if results.get('output'):
+                     return results['output']
+
+                 # the subprocess exited with an error
+                 if process.returncode != 0:
+                     # collect the last error messages
+                     error_messages = []
+                     try:
+                         while True:
+                             error_msg = stderr_queue.get_nowait()
+                             error_messages.append(error_msg)
+                     except queue.Empty:
+                         pass
+
+                     error_text = "\n".join(error_messages[-5:])  # last 5 messages
+                     raise Exception(f"Процесс завершился с ошибкой. Код возврата: {process.returncode}. Сообщения об ошибках:\n{error_text}")
+
+             except Exception as e:
+                 raise e
+             finally:
+                 # make sure the subprocess is terminated
+                 try:
+                     if process.poll() is None:
+                         process.terminate()
+                         process.wait(timeout=5)
+                 except:
+                     try:
+                         process.kill()
+                     except:
+                         pass
+         else:
+             raise ValueError("Неподдерживаемый тип модели")
+
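+     # Public entry point: accepts one path or a list of paths and returns
+     # [(stem, path), ...] for a single file, or a list of
+     # [basename, [(stem, path), ...]] pairs for a batch.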
+     def separate(
+         self,
+         input: str | list = None,
+         output_dir: str = None,
+         model_type: Literal[
+             "mel_band_roformer",
+             "bs_roformer",
+             "mdx23c",
+             "mdxnet",
+             "vr",
+             "scnet",
+             "htdemucs",
+             "bandit",
+             "bandit_v2",
+         ] = "mel_band_roformer",
+         model_name: str = "Mel-Band-Roformer_Vocals_kimberley_jensen",
+         ext_inst: bool = True,
+         output_format: Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] = "mp3",
+         output_bitrate: str = "320k",
+         template: str = "NAME_(STEM)_MODEL",
+         selected_stems: list = None,
+         add_settings: dict = {"mdx_denoise": False, "vr_aggr": 5, "add_single_sep_text_progress": None},
+         progress: any = gr.Progress(track_tqdm=True)
+     ) -> list[tuple[str, str]] | list[list]:
+
+         progress(0, desc="Начало обработки")
+
+         # parameter validation
+         if output_format not in self.audio.output_formats:
+             output_format = "flac"
+
+         if output_dir is None:
+             output_dir = os.getcwd()
+
+         if output_dir:
+             output_dir = os.path.abspath(output_dir)
+
+         if selected_stems is None:
+             selected_stems = []
+
+         if not input:
+             raise ValueError("Входной файл не указан")
+
+         if not template:
+             template = "mvsepless_NAME_(STEM)"
+         elif "STEM" not in template:
+             template = template + "_STEM_"
+
+         os.makedirs(output_dir, exist_ok=True)
+
+         mdx_denoise = add_settings.get("mdx_denoise", False)
+
+         vr_aggr = add_settings.get("vr_aggr", 5)
+
+         add_progress_text_custom = add_settings.get("add_single_sep_text_progress", "")
+
+         id, conf, ckpt = self.separator_model_loader(model_type, model_name, mdx_denoise, vr_aggr, progress)
+
+         if isinstance(input, str):
+             if not os.path.exists(input):
+                 raise ValueError(f"Входной файл не найден: {input}")
+
+             if not self.audio.check(input):
+                 raise ValueError("Входной файл не содержит аудио")
+
+             basename = os.path.splitext(os.path.basename(input))[0]
+             seped = self.separator_base(input_file=input,
+                                         output_dir=output_dir,
+                                         model_type=model_type,
+                                         model_name=model_name,
+                                         ext_inst=ext_inst,
+                                         output_format=output_format,
+                                         output_bitrate=output_bitrate,
+                                         template=template,
+                                         selected_stems=selected_stems,
+                                         ckpt=ckpt,
+                                         conf=conf,
+                                         id=id,
+                                         progress=progress,
+                                         add_text_progress=add_progress_text_custom)
+             return seped
+
+         elif isinstance(input, list):
+             results = []
+             for i, f in enumerate(input, 1):
+                 print(f"Файл {i} из {len(input)}: {f}")
+                 if os.path.exists(f):
+                     if self.audio.check(f):
+                         basename = os.path.splitext(os.path.basename(f))[0]
+                         seped = self.separator_base(input_file=f,
+                                                     output_dir=output_dir,
+                                                     model_type=model_type,
+                                                     model_name=model_name,
+                                                     ext_inst=ext_inst,
+                                                     output_format=output_format,
+                                                     output_bitrate=output_bitrate,
+                                                     template=template,
+                                                     selected_stems=selected_stems,
+                                                     ckpt=ckpt,
+                                                     conf=conf,
+                                                     id=id,
+                                                     progress=progress,
+                                                     add_text_progress=f"({i} из {len(input)}) ")
+                         results.append([basename, seped])
+             return results
+
+
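+     # Separation tab of the Gradio UI; output_base_dir and
+     # output_temp_dir_check are components created by the caller.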
+     def UI(self, output_base_dir, output_temp_dir_check):
+         default_mts = self.model_manager.get_mt()
+         default_mt = self.model_manager.get_mt()[0]
+         default_mns = self.model_manager.get_mn(default_mt)
+         default_mn = default_mns[0]
+         default_stems = self.model_manager.get_stems(default_mt, default_mn)
+         default_tgt_inst = self.model_manager.get_tgt_inst(default_mt, default_mn)
+         with gr.Blocks():
+             with gr.Row():
+                 with gr.Column():
+                     with gr.Group(visible=False) as add_inputs:
+                         input_path = gr.Textbox(label="Путь к входному файлу", interactive=True)
+                         add_inputs_btn = gr.Button("Добавить файл", variant="primary")
+                     with gr.Group(visible=False) as add_inputs_from_url:
+                         input_url = gr.Textbox(label="URL входного файла", interactive=True)
+                         with gr.Row():
+                             inputs_url_format = gr.Dropdown(label="Формат входного файла", interactive=True,
+                                                             choices=self.audio.output_formats,
+                                                             value="mp3", filterable=False)
+                             inputs_url_bitrate = gr.Slider(label="Битрейт входного файла", minimum=64, maximum=512, step=32, value=320, interactive=True)
+                         with gr.Row():
+                             inputs_url_cookie = gr.UploadButton(label="Файл cookie (необязательно)", interactive=True, type="filepath", file_count="single", file_types=[".txt", ".cookies"], variant="secondary")
+                             add_inputs_url_btn = gr.Button("Добавить файл", variant="primary")
+                     with gr.Row(visible=True) as add_buttons_row:
+                         add_path_btn = gr.Button("Добавить файл по пути", variant="secondary")
+                         add_url_btn = gr.Button("Добавить файл по URL", variant="secondary")
+                     with gr.Group():
+                         input_audio = gr.File(label="Входные аудио", interactive=True, type="filepath", file_count="multiple", file_types=[f".{of}" for of in self.audio.input_formats])
+                         sep_state = gr.Textbox(label="Состояние разделения", interactive=False, value="", visible=False)
+                         status = gr.Textbox(container=False, lines=3, interactive=False, max_lines=3)
+                     input_preview_check = gr.Checkbox(label="Показать плееры для входных аудио", interactive=True, value=False)
+                     @gr.render(inputs=[input_preview_check, input_audio])
+                     def show_input_players(preview, audios):
+                         if preview:
+                             if audios:
+                                 with gr.Group():
+                                     for file in audios:
+                                         gr.Audio(label=os.path.splitext(os.path.basename(file))[0], value=file, interactive=False, show_download_button=False, type="filepath")
+                 with gr.Column():
+                     with gr.Group():
+                         with gr.Row():
+                             model_type = gr.Dropdown(label="Тип модели", interactive=True, filterable=False,
+                                                      choices=default_mts,
+                                                      value=default_mt)
+                             model_name = gr.Dropdown(label="Имя модели", interactive=True, filterable=False,
+                                                      choices=default_mns, value=default_mn)
+                         with gr.Group():
+                             extract_instrumental = gr.Checkbox(label="Извлечь инструментал", interactive=True, value=True)
+                             selected_stems = gr.CheckboxGroup(label="Выбранные стемы для разделения", interactive=False,
+                                                               choices=default_stems, value=[])
+
+                     with gr.Accordion(label="Дополнительные настройки", open=False):
+                         vr_aggr_slider = gr.Slider(label="Сила подавления для VR моделей", minimum=-100, maximum=100, value=5, step=1)
+                         mdx_denoise_check = gr.Checkbox(label="Включить шумоподавление для MDX-NET моделей (это повышает потребление памяти в два раза)", value=False)
+                     with gr.Row():
+                         output_format = gr.Dropdown(label="Формат выходного файла", interactive=True,
+                                                     choices=self.audio.output_formats,
+                                                     value="mp3", filterable=False)
+                         output_bitrate = gr.Slider(label="Битрейт выходного файла", minimum=64, maximum=512, step=32, value=320, interactive=True)
+
+                     template = gr.Textbox(label="Шаблон именования выходных файлов", interactive=True, value="NAME (STEM) MODEL", info="Используйте ключи: \nNAME - имя входного файла без расширения, \nSTEM - имя стема, \nMODEL - имя модели разделения")
+                     separate_btn = gr.Button("Разделить", variant="primary")
+
+             @gr.render(inputs=[sep_state], triggers=[sep_state.change])
+             def players(state):
+                 def create_archive_advanced(file_list, archive_name="archive.zip"):
+                     """Create a ZIP archive of the results, with per-file error handling."""
+                     try:
+                         print("Генерация ZIP-архива с результатами разделения...")
+                         with zipfile.ZipFile(archive_name, 'w', zipfile.ZIP_DEFLATED) as zipf:
+                             successful_files = 0
+
+                             for basename, stems in file_list:
+                                 for stem_name, stem_path in stems:
+                                     try:
+                                         if os.path.exists(stem_path) and os.path.isfile(stem_path):
+                                             basename_ = os.path.basename(stem_path)
+                                             zipf.write(stem_path, basename_)
+                                             successful_files += 1
+                                             print(f"✓ Добавлен: {stem_path} -> {basename}")
+                                         else:
+                                             print(f"✗ Файл не найден или не является файлом: {stem_path}")
+
+                                     except Exception as e:
+                                         print(f"✗ Ошибка при добавлении {stem_path}: {e}")
+
+                         print(f"\nАрхив создан: {archive_name}")
+                         print(f"Успешно добавлено файлов: {successful_files}")
+                         return os.path.abspath(archive_name)
+
+                     except Exception as e:
+                         print(f"Ошибка при создании архива: {e}")
+                 if state != "":
+                     state_loaded = ast.literal_eval(state)
+                     archive_stems = create_archive_advanced(state_loaded, os.path.join(tempfile.gettempdir(), f"mvsepless_output_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip"))
+                     for basename, stems in state_loaded:
+                         with gr.Group():
+                             gr.Markdown(f"<h4><center>{basename}</center></h4>")
+                             for stem_name, stem_path in stems:
+                                 with gr.Row(equal_height=True):
+                                     output_stem = gr.Audio(value=stem_path, label=stem_name, type="filepath", interactive=False, show_download_button=True, scale=15)
+                                     reuse_btn = gr.Button("Использовать снова", variant="secondary")
+                                 @reuse_btn.click(
+                                     inputs=[output_stem, input_audio],
+                                     outputs=input_audio
+                                 )
+                                 def reuse_fn(stem_audio, input_a):
+                                     if input_a is None:
+                                         input_a = []
+                                     if isinstance(input_a, str):
+                                         input_a = [input_a]
+                                     if os.path.exists(stem_audio):
+                                         if self.audio.check(stem_audio):
+                                             input_a.append(stem_audio)
+                                     return input_a
+
+                     gr.DownloadButton(label="Скачать как ZIP", value=archive_stems, interactive=True)
+
+             @add_inputs_btn.click(
+                 inputs=[input_path, input_audio],
+                 outputs=[add_inputs, input_audio, add_buttons_row])
+             def add_inputs_fn(input_p, input_a):
+                 if input_p and os.path.exists(input_p):
+                     if input_a is None:
+                         input_a = []
+                     if isinstance(input_a, str):
+                         input_a = [input_a]
+                     if self.audio.check(input_p):
+                         input_a.append(input_p)
+                         return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+                 return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+
+             @add_inputs_url_btn.click(
+                 inputs=[input_url, input_audio, inputs_url_format, inputs_url_bitrate, inputs_url_cookie],
+                 outputs=[add_inputs_from_url, input_audio, add_buttons_row])
+             def add_inputs_from_url_fn(input_u, input_a, fmt, br, cookie):
+                 if input_u:
+                     if input_a is None:
+                         input_a = []
+                     if isinstance(input_a, str):
+                         input_a = [input_a]
+                     downloaded_file = dw_yt_dlp(
+                         url=input_u,
+                         output_format=fmt,
+                         output_bitrate=str(int(br)),
+                         cookie=cookie
+                     )
+                     if downloaded_file and os.path.exists(downloaded_file):
+                         if self.audio.check(downloaded_file):
+                             input_a.append(downloaded_file)
+                             return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+                 return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+
+             add_path_btn.click(
+                 lambda: (gr.update(visible=True), gr.update(visible=False)),
+                 outputs=[add_inputs, add_buttons_row])
+
+             add_url_btn.click(
+                 lambda: (gr.update(visible=True), gr.update(visible=False)),
+                 outputs=[add_inputs_from_url, add_buttons_row])
+
+             model_type.change(lambda x: gr.update(choices=self.model_manager.get_mn(x), value=self.model_manager.get_mn(x)[0]),
+                               inputs=model_type, outputs=model_name)
+             model_name.change(lambda mt, mn: (gr.update(choices=self.model_manager.get_stems(mt, mn), value=[], interactive=False if self.model_manager.get_tgt_inst(mt, mn) else True), gr.update(value=True if self.model_manager.get_tgt_inst(mt, mn) else False)), inputs=[model_type, model_name], outputs=[selected_stems, extract_instrumental])
+             output_format.change(lambda x: gr.update(visible=False if x in ["wav", "flac", "aiff"] else True), inputs=output_format, outputs=output_bitrate)
+             inputs_url_format.change(lambda x: gr.update(visible=False if x in ["wav", "flac", "aiff"] else True), inputs=inputs_url_format, outputs=inputs_url_bitrate)
+             @separate_btn.click(
+                 inputs=[
+                     input_audio, model_type, model_name,
+                     extract_instrumental, output_format, output_bitrate,
+                     template, selected_stems, output_base_dir, output_temp_dir_check, mdx_denoise_check, vr_aggr_slider
+                 ],
+                 outputs=[sep_state, status],
+                 show_progress="full"
+             )
+             def wrap(i, mt, mn, ei, of, ob, t, stems, o_dir, temp_save, mdx_denoise, vr_aggr, progress=gr.Progress(track_tqdm=True)):
+                 if o_dir.strip() != "" and not temp_save:
+                     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+                     o = os.path.join(o_dir, f"mvsepless_outputs_{timestamp}")
+                     os.makedirs(o, exist_ok=True)
+                 else:
+                     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+                     o = tempfile.mkdtemp(prefix=f"mvsepless_outputs_{timestamp}_")
+                     os.makedirs(o, exist_ok=True)
+                 results = self.separate(i, o, mt, mn, ei, of, ob, t, stems, add_settings={"mdx_denoise": mdx_denoise, "vr_aggr": int(vr_aggr)}, progress=progress)
+                 return str(results), ""
+
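+ # AutoEnsembless separates one input with several models and blends the
+ # selected stems; the nested ModelManager holds the recipe as
+ # (model_type, model_name, primary_stem, invert_stem, weight) tuples and
+ # saves/loads it as a JSON preset.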
+ class AutoEnsembless(Separator):
+
+     class ModelManager(MVSEPLESS):
+         def __init__(self):
+             self.data = []
+             self.ensemble_methods = ("min_fft", "max_fft", "avg_fft", "median_fft")
+             self.ensemble_invert_methods_map = {"min_fft": "max_fft", "max_fft": "min_fft", "avg_fft": "avg_fft", "median_fft": "median_fft"}
+             self.dir_presets = os.path.join(tempfile.gettempdir(), "presets")
+             os.makedirs(self.dir_presets, exist_ok=True)
+
+         def save(self, name):
+             if not name:
+                 name = "ensembless_preset"
+             filepath = os.path.join(self.dir_presets, f"{self.namer.short(self.namer.sanitize(name), length=50)}.json")
+             with open(filepath, "w") as f:
+                 json.dump(self.data, f, indent=4, ensure_ascii=False)
+             return filepath
+
+         def load(self, filepath):
+             with open(filepath, "r") as f:
+                 ensemble_data_temp = json.load(f)
+             self.data = []
+             for (mt, mn, s_stem, i_stem, weight) in ensemble_data_temp:
+                 if {mt, mn} not in [{model[0], model[1]} for model in self.data]:
+                     self.data.append((mt, mn, s_stem, i_stem, weight))
+
+         def add(self, mt, mn, s_stem, i_stem, weight):
+             if {mt, mn} not in [{model[0], model[1]} for model in self.data]:
+                 if s_stem and i_stem:
+                     self.data.append((mt, mn, s_stem, i_stem, weight))
+
+         def replace(self, mt, mn, s_stem, i_stem, weight, index=1):
+             if self.data:
+                 len_data = len(self.data)
+                 if index >= 1:
+                     if index <= len_data:
+                         self.data[index - 1] = (mt, mn, s_stem, i_stem, weight)
+                 elif index == 0:
+                     self.data[0] = (mt, mn, s_stem, i_stem, weight)
+
+         def remove(self, index=1):
+             if self.data:
+                 len_data = len(self.data)
+                 if index >= 1:
+                     if index <= len_data:
+                         del self.data[index - 1]
+                 elif index == 0:
+                     del self.data[0]
+
+         def clear(self):
+             self.data = []
+
+         def get_df(self):
+             if not self.data:
+                 columns = ["#", "Имя модели", "Основной стем", "Инверсия", "Вес"]
+                 return pd.DataFrame(columns=columns)
+
+             data = []
+             for i, model in enumerate(self.data):
+                 data.append(
+                     [
+                         f"{i+1}",
+                         model[1],
+                         model[2],
+                         model[3],
+                         model[4],
+                     ]
+                 )
+             columns = ["#", "Имя модели", "Основной стем", "Инверсия", "Вес"]
+             return pd.DataFrame(data, columns=columns)
+
+     def UI(self, output_base_dir, output_temp_dir_check):
+         ensemble_model_manager = self.ModelManager()
+         def get_stems(mt, mn):
+             stems = []
+             for stem in self.model_manager.get_stems(mt, mn):
+                 stems.append(stem)
+
+             if not self.model_manager.get_tgt_inst(mt, mn):
+                 if set(stems) == {"bass", "drums", "other", "vocals"} or set(stems) == {"bass", "drums", "other", "vocals", "piano", "guitar"}:
+                     stems.append("instrumental +")
+                     stems.append("instrumental -")
+
+             return stems
+
+         def get_invert_stems(mt, mn, s_stem):
+             orig_stems = []
+             stems = []
+             for stem in self.model_manager.get_stems(mt, mn):
+                 orig_stems.append(stem)
+
+             for stem in orig_stems:
+                 if stem != s_stem:
+                     stems.append(stem)
+
+             if not self.model_manager.get_tgt_inst(mt, mn):
+                 if len(orig_stems) > 2:
+                     if s_stem not in ["instrumental +", "instrumental -"]:
+                         stems.append("inverted +")
+                         stems.append("inverted -")
+
+             return stems
+
+         default_model = {
+             "mt": self.model_manager.get_mt(),
+             "mn": self.model_manager.get_mn(self.model_manager.get_mt()[0]),
+             "stem": get_stems(
+                 self.model_manager.get_mt()[0],
+                 self.model_manager.get_mn(self.model_manager.get_mt()[0])[0],
+             ),
+             "invert_stem": get_invert_stems(
+                 self.model_manager.get_mt()[0],
+                 self.model_manager.get_mn(self.model_manager.get_mt()[0])[0],
+                 "vocals",
+             ),
+             "weight": 1,
+         }
+
+         gr.Markdown("<h3>Пресет</h3>")
+         with gr.Group():
+             with gr.Row(equal_height=True):
+                 export_preset_name = gr.Textbox(
+                     label="Имя пресета",
+                     interactive=True,
+                     value="ensembless_preset", scale=9
+                 )
+                 export_btn = gr.DownloadButton("Экспорт", variant="secondary", scale=3, interactive=True)
+                 import_btn = gr.UploadButton(
+                     "Импорт", file_types=[".json"], file_count="single", scale=3, interactive=True
+                 )
+         gr.Markdown("<h3>Ансамбль</h3>")
+         with gr.Row():
+             with gr.Column(scale=3):  # model-adding controls
+                 model_type = gr.Dropdown(label="Тип модели", choices=default_model["mt"], value=default_model["mt"][0], interactive=True, filterable=False)
+                 model_name = gr.Dropdown(label="Имя модели", choices=default_model["mn"], value=default_model["mn"][0], interactive=True, filterable=False)
+                 primary_stem = gr.Dropdown(label="Основной стем", choices=default_model["stem"], value=default_model["stem"][0], interactive=True, filterable=False)
+                 secondary_stem = gr.Dropdown(label="Инверсия", choices=default_model["invert_stem"], value=default_model["invert_stem"][0], interactive=True, filterable=False)
+                 weight = gr.Slider(label="Вес", minimum=0, maximum=10, step=0.01, value=1, interactive=True)
+                 @model_type.change(
+                     inputs=[model_type],
+                     outputs=[model_name]
+                 )
+                 def update_model_names(mt):
+                     model_names = self.model_manager.get_mn(mt)
+                     new_mn = model_names[0] if model_names else ""
+
+                     return gr.update(choices=model_names, value=new_mn)
+                 @model_name.change(
+                     inputs=[model_type, model_name],
+                     outputs=[primary_stem, secondary_stem]
+                 )
+                 def update_stems_after_model_change(mt, mn):
+                     stems = get_stems(mt, mn)
+                     invert_stems = get_invert_stems(mt, mn, stems[0]) if stems else []
+
+                     new_s_stem = stems[0] if stems else ""
+                     new_i_stem = invert_stems[0] if invert_stems else ""
+
+                     return (
+                         gr.update(choices=stems, value=new_s_stem),
+                         gr.update(choices=invert_stems, value=new_i_stem)
+                     )
+                 @primary_stem.change(
+                     inputs=[model_type, model_name, primary_stem],
+                     outputs=[secondary_stem]
+                 )
+                 def update_invert_stems(mt, mn, s_stem):
+                     stems = get_invert_stems(mt, mn, s_stem)
+                     new_i_stem = stems[0] if stems else ""
+                     return gr.update(choices=stems, value=new_i_stem)
+
+                 model_add_button = gr.Button("Добавить", interactive=True)
+             with gr.Column(scale=10):
+                 df = gr.DataFrame(
+                     value=ensemble_model_manager.get_df(),
+                     headers=["#", "Имя модели", "Основной стем", "Инверсия", "Вес"],
+                     datatype=["number", "str", "str", "str", "number"],
+                     interactive=False
+                 )
+
+                 with gr.Group():
+                     with gr.Row(equal_height=True):
+                         with gr.Column():
+                             model_index = gr.Number(label="Индекс модели", value=1, interactive=True)
+                             model_clear_btn = gr.Button("Очистить", variant="stop", interactive=True)
+                         with gr.Column():
+                             model_replace_btn = gr.Button("Заменить", variant="primary", interactive=True)
+                             model_delete_btn = gr.Button("Удалить", variant="stop", interactive=True)
+
+         @model_add_button.click(
+             inputs=[model_type, model_name, primary_stem, secondary_stem, weight],
+             outputs=df
+         )
+         def add_model_to_auto_ensemble(mt, mn, s_stem, i_stem, weight):
+             ensemble_model_manager.add(mt, mn, s_stem, i_stem, weight)
+             return ensemble_model_manager.get_df()
+
+         @model_replace_btn.click(
+             inputs=[model_type, model_name, primary_stem, secondary_stem, weight, model_index],
+             outputs=df
+         )
+         def replace_model_to_auto_ensemble(mt, mn, s_stem, i_stem, weight, index):
+             ensemble_model_manager.replace(mt, mn, s_stem, i_stem, weight, index)
+             return ensemble_model_manager.get_df()
+
+         @model_delete_btn.click(
+             inputs=[model_index],
+             outputs=df
+         )
+         def delete_model_to_auto_ensemble(index):
+             ensemble_model_manager.remove(index)
+             return ensemble_model_manager.get_df()
+
+         @model_clear_btn.click(
+             outputs=df
+         )
+         def clear_model_to_auto_ensemble():
+             ensemble_model_manager.clear()
+             return ensemble_model_manager.get_df()
+
+         gr.on(fn=ensemble_model_manager.get_df, outputs=df)
+
+         df.change(
+             fn=ensemble_model_manager.save,
+             inputs=export_preset_name,
+             outputs=export_btn
+         )
+
+         export_preset_name.change(
+             fn=ensemble_model_manager.save,
+             inputs=export_preset_name,
+             outputs=export_btn
+         )
+
+         @import_btn.upload(
+             inputs=import_btn,
+             outputs=df
+         )
+         def load_ensemble_preset(filepath):
+             ensemble_model_manager.load(filepath)
+             return ensemble_model_manager.get_df()
+         with gr.Row():
+             with gr.Column():
+                 gr.Markdown("<h3>Входное аудио</h3>")
+                 with gr.Group():
+                     with gr.Group(visible=False) as add_inputs:
+                         input_path = gr.Textbox(label="Путь к входному файлу", interactive=True)
+                         add_inputs_btn = gr.Button("Загрузить файл", variant="primary")
+                     with gr.Group(visible=False) as add_inputs_from_url:
+                         input_url = gr.Textbox(label="URL входного файла", interactive=True)
+                         with gr.Row():
+                             inputs_url_format = gr.Dropdown(label="Формат входного файла", interactive=True,
+                                                             choices=self.audio.output_formats,
+                                                             value="mp3", filterable=False)
+                             inputs_url_bitrate = gr.Slider(label="Битрейт входного файла", minimum=64, maximum=512, step=32, value=320, interactive=True)
+                         with gr.Row():
+                             inputs_url_cookie = gr.UploadButton(label="Файл cookie (необязательно)", interactive=True, type="filepath", file_count="single", file_types=[".txt", ".cookies"], variant="secondary")
+                             add_inputs_url_btn = gr.Button("Загрузить файл", variant="primary")
+                     with gr.Row(visible=True) as add_buttons_row:
+                         add_path_btn = gr.Button("Загрузить файл по пути", variant="secondary")
+                         add_url_btn = gr.Button("Загрузить файл по URL", variant="secondary")
+                     with gr.Group():
+                         input_audio = gr.File(label="Входное аудио", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.audio.input_formats])
+             with gr.Column():
+                 gr.Markdown("<h3>Настройки</h3>")
+                 with gr.Group():
+                     method = gr.Dropdown(label="Алгоритм склеивания", choices=["min_fft", "max_fft", "avg_fft", "median_fft"], value="avg_fft", filterable=False)
+                     invert_ensemble = gr.Checkbox(label="Инверсия ансамбля", interactive=True, value=False)
+                     output_format = gr.Dropdown(label="Формат выходного файла", interactive=True,
+                                                 choices=self.audio.output_formats,
+                                                 value="mp3", filterable=False)
+                 run_btn = gr.Button("Создать ансамбль", variant="primary", interactive=True)
+
+         with gr.Row():
+             with gr.Column():
+                 gr.Markdown("<h3>Результаты</h3>")
+                 output_audio = gr.Audio(label="Результат", type="filepath", interactive=False, show_download_button=True)
+                 output_audio_wav = gr.Textbox(label="Результат в WAV", interactive=False, visible=False)
+                 with gr.Group():
+                     invert_method = gr.Radio(
+                         choices=["waveform", "spectrogram"],
+                         label="Метод создания инверсии",
+                         value="waveform",
+                     )
+                     invert_btn = gr.Button("Инвертировать")
+                     output_inverted_audio = gr.Audio(label="Инверсия", type="filepath", interactive=False, show_download_button=True)
+                 @invert_btn.click(inputs=[input_audio, output_audio_wav, invert_method, output_format], outputs=[output_inverted_audio])
+                 def invert_result_ensemble(input_file, output_file, method, out_format):
+                     if input_file and output_file:
+                         o_dir = os.path.dirname(output_file)
+                         basename = os.path.splitext(os.path.basename(input_file))[0]
+                         output_path = os.path.join(o_dir, f"ensembless_{self.namer.short(basename, length=50)}_{method}_invert.{out_format}")
+                         inverted = self.inverter.process_audio(audio1_path=input_file, audio2_path=output_file, out_format=out_format, method=method, output_path=output_path)
+                         return inverted
+                     else:
+                         return None
+
+             with gr.Column():
+                 gr.Markdown("<h3>Исходники ансамбля (WAV)</h3>")
+                 output_source_files = gr.Files(type="filepath", interactive=False, show_label=False)
+                 output_source_preview_check = gr.Checkbox(label="Показать плееры для исходников ансамбля", interactive=True, value=False)
+                 @gr.render(inputs=[output_source_preview_check, output_source_files])
+                 def show_output_auto_ensemble_players(preview, audios):
+                     if preview:
+                         if audios:
+                             with gr.Group():
+                                 for file in audios:
+                                     gr.Audio(label=os.path.splitext(os.path.basename(file))[0], value=file, interactive=False, show_download_button=False, type="filepath")
+
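+         # One separation pass per recipe entry; the collected primary stems
+         # are blended with ensemble_audio_files(), and optionally the
+         # complementary stems are blended with inverted weights.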
+         @run_btn.click(
+             inputs=[input_audio, method, output_format, invert_ensemble, output_base_dir, output_temp_dir_check],
+             outputs=[output_audio, output_audio_wav, output_inverted_audio, output_source_files]
+         )
+         def auto_ensemble_run(input_file, method, out_format, invert_ensemble, o_dir, temp_save, progress=gr.Progress(track_tqdm=True)):
+             ensemble_state = ensemble_model_manager.data
+             invert_methods_map = ensemble_model_manager.ensemble_invert_methods_map
+             if not input_file:
+                 return None, None, None, []
+             if not os.path.exists(input_file):
+                 return None, None, None, []
+             if not self.audio.check(input_file):
+                 return None, None, None, []
+             if o_dir.strip() != "" and not temp_save:
+                 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+                 o = os.path.join(o_dir, f"ensembless_outputs_{timestamp}")
+                 os.makedirs(o, exist_ok=True)
+             else:
+                 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+                 o = tempfile.mkdtemp(prefix=f"ensembless_outputs_{timestamp}_")
+                 os.makedirs(o, exist_ok=True)
+
+             basename = os.path.splitext(os.path.basename(input_file))[0]
+             def invert_weights(weights):
+                 total_weight = sum(weights)
+                 return [total_weight - w for w in weights]
+             # print(json.dumps(ensemble_state, indent=4, ensure_ascii=False))
+             success_separations = []  # list of (model_name, s_stem_path, i_stem_path, weight) tuples
+             ensemble_sources_list = []  # flat list of stem file paths
+             if ensemble_state:
+                 total_ensemble_models = len(ensemble_state)
+                 for i, model in enumerate(ensemble_state, start=1):
+
+                     ens_mt = model[0]
+                     ens_mn = model[1]
+                     ens_s_stem = model[2]
+                     ens_i_stem = model[3]
+                     weight = model[4]
+
+                     s_stem = None  # path to the primary stem
+                     i_stem = None  # path to the inverted stem
+
+                     try:
+                         result_seped_auto_ensemble = self.separate(input=input_file, output_dir=os.path.join(o, ens_mn), model_type=ens_mt, model_name=ens_mn, ext_inst=True, template="NAME - MODEL - STEM", output_format="wav", add_settings={"add_single_sep_text_progress": f"{i} из {total_ensemble_models}"}, progress=progress)
+                         if result_seped_auto_ensemble:
+                             for stem, path in result_seped_auto_ensemble:
+                                 ensemble_sources_list.append(path)
+                                 if stem == ens_s_stem:
+                                     s_stem = path
+                                 elif stem == ens_i_stem:
+                                     i_stem = path
+
+                         if invert_ensemble:
+                             if not i_stem:
+                                 result_seped_auto_ensemble_invert = self.separate(input=input_file, output_dir=os.path.join(o, f"{ens_mn}_invert"), model_type=ens_mt, model_name=ens_mn, ext_inst=True, template="NAME - MODEL - STEM", output_format="wav", selected_stems=[ens_s_stem], add_settings={"add_single_sep_text_progress": f"{i} из {total_ensemble_models} (инверт.)"}, progress=progress)
+                                 if result_seped_auto_ensemble_invert:
+                                     for stem, path in result_seped_auto_ensemble_invert:
+                                         if stem == ens_i_stem:
+                                             i_stem = path
+                                             ensemble_sources_list.append(path)
+
+                     except Exception as e:
+                         print(f"\nПроизошла ошибка при разделении: {e}")
+                         progress(0, desc="Произошла ошибка при разделении, модель пропускается...")
+                         continue
+                     finally:
+                         if s_stem:
+                             success_separations.append((ens_mn, s_stem, i_stem, weight))
+
+             ensemble_sources_stems = []
+             ensemble_sources_invert_stems = []
+             weights = []
+
+             for out_mn, out_s_stem, out_i_stem, out_weight in success_separations:
+                 ensemble_sources_stems.append(out_s_stem)
+                 ensemble_sources_invert_stems.append(out_i_stem)
+                 weights.append(out_weight)
+
+
+             auto_ensemble_invout_file = None
+             auto_ensemble_invout_file_wav = None
+
+             auto_ensemble_output_name = f"ensembless_{self.namer.short(basename, length=50)}_{len(ensemble_sources_stems)}_{method}"
+             auto_ensemble_inverted_output_name = f"ensembless_{self.namer.short(basename, length=50)}_{len(ensemble_sources_stems)}_{invert_methods_map[method]}_invert"
+             auto_ensemble_out_file, auto_ensemble_out_file_wav = ensemble_audio_files(files=ensemble_sources_stems, weights=weights, output=os.path.join(o, auto_ensemble_output_name), ensemble_type=method, out_format=out_format, add_wav=True)
+
+             if invert_ensemble:
+                 auto_ensemble_invout_file, auto_ensemble_invout_file_wav = ensemble_audio_files(files=ensemble_sources_invert_stems, weights=invert_weights(weights), output=os.path.join(o, auto_ensemble_inverted_output_name), ensemble_type=invert_methods_map[method], out_format=out_format, add_wav=True)
+
+             return auto_ensemble_out_file, auto_ensemble_out_file_wav, auto_ensemble_invout_file, ensemble_sources_list
+
+         @add_inputs_btn.click(
+             inputs=[input_path, input_audio],
+             outputs=[add_inputs, input_audio, add_buttons_row])
+         def add_inputs_fn(input_p, input_a):
+             if input_p and os.path.exists(input_p):
+                 if self.audio.check(input_p):
+                     input_a = input_p
+                     return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+             return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+
+         @add_inputs_url_btn.click(
+             inputs=[input_url, input_audio, inputs_url_format, inputs_url_bitrate, inputs_url_cookie],
+             outputs=[add_inputs_from_url, input_audio, add_buttons_row])
+         def add_inputs_from_url_fn(input_u, input_a, fmt, br, cookie):
+             if input_u:
+                 downloaded_file = dw_yt_dlp(
+                     url=input_u,
+                     output_format=fmt,
+                     output_bitrate=str(int(br)),
+                     cookie=cookie
+                 )
+                 if downloaded_file and os.path.exists(downloaded_file):
+                     if self.audio.check(downloaded_file):
+                         input_a = downloaded_file
+                         return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+             return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+
+         add_path_btn.click(
+             lambda: (gr.update(visible=True), gr.update(visible=False)),
+             outputs=[add_inputs, add_buttons_row])
+
+         add_url_btn.click(
+             lambda: (gr.update(visible=True), gr.update(visible=False)),
+             outputs=[add_inputs_from_url, add_buttons_row])
+
+         inputs_url_format.change(lambda x: gr.update(visible=False if x in ["wav", "flac", "aiff"] else True), inputs=inputs_url_format, outputs=inputs_url_bitrate)
+
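+ # ManualEnsembless blends user-supplied, already-separated files; it only
+ # validates that the inputs exist, contain audio, and share one sample rate.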
+ class ManualEnsembless(MVSEPLESS):
+     def UI(self, output_base_dir, output_temp_dir_check):
+         with gr.Row():
+             with gr.Column():
+                 with gr.Group(visible=False) as add_ensemble_inputs:
+                     input_ensemble_path = gr.Textbox(label="Путь к входному файлу", interactive=True)
+                     add_ensemble_inputs_btn = gr.Button("Добавить файл", variant="primary")
+                 add_ensemble_path_btn = gr.Button("Добавить файл по пути", variant="secondary")
+                 input_ensemble_files = gr.File(label="Входное аудио", interactive=True, type="filepath", file_count="multiple", file_types=[f".{of}" for of in self.audio.input_formats])
+                 input_ensemble_preview_check = gr.Checkbox(label="Показать плееры для входных аудио", interactive=True, value=False)
+                 @gr.render(inputs=[input_ensemble_preview_check, input_ensemble_files])
+                 def show_input_ensemble_players(preview, audios):
+                     if preview:
+                         if audios:
+                             with gr.Group():
+                                 for file in audios:
+                                     gr.Audio(label=os.path.splitext(os.path.basename(file))[0], value=file, interactive=False, show_download_button=False, type="filepath")
+
+             with gr.Column():
+                 @gr.render(inputs=[input_ensemble_files])
+                 def input_ensemble_files_fn(input_files):
+                     check_ensemble_files_status = f"""Анализ входных файлов
+                     ---"""
+                     hz_ = []
+                     err_list = []
+                     if input_files:
+                         for file in input_files:
+                             basename = os.path.splitext(os.path.basename(file))[0]
+                             if os.path.exists(file):
+                                 if self.audio.check(file):
+                                     info = self.audio.get_info(file)
+                                     hz = info[0].get("sample_rate")
+                                     check_ensemble_files_status += f"\n{basename} - {hz} hz"
+                                     hz_.append(hz)
+                                 else:
+                                     check_ensemble_files_status += f"\n{basename} - Нет аудио"
+                                     err_list.append(file)
+                             else:
+                                 check_ensemble_files_status += f"\n{basename} - Файл не найден"
+                                 err_list.append(file)
+
+                     check_ensemble_files_result = f"Действительных файлов: {len(hz_)}"
+
+                     all_same = True
+
+                     common_rate = None
+
+                     for hz_hz in hz_:
+                         if common_rate is None:
+                             common_rate = hz_hz
+                         elif common_rate != hz_hz:
+                             all_same = False
+
+                     if hz_ and len(hz_) > 1:
+                         check_ensemble_files_result += "\nВсе действительные файлы имеют одинаковую частоту дискретизации" if all_same else "\nОшибка! Все действительные файлы имеют РАЗНУЮ частоту дискретизации"
+                     else:
+                         check_ensemble_files_result += "\nДля создания ансамбля нужно загрузить, как минимум - 2 файла, содержащие аудио"
+
+
+                     check_ensemble_files_status += f"\n \n{check_ensemble_files_result}"
+
+                     gr.Textbox(container=False, lines=len(check_ensemble_files_status.split("\n")), interactive=False, value=check_ensemble_files_status)
+
+                 weights = gr.Textbox(label="Веса", value="1.0,1.0")
+
+                 method = gr.Dropdown(label="Алгоритм склеивания", choices=["min_fft", "max_fft", "avg_fft", "median_fft"], value="avg_fft", filterable=False)
+
+                 output_format = gr.Dropdown(label="Формат выходного файла", interactive=True,
+                                             choices=self.audio.output_formats,
+                                             value="mp3", filterable=False)
+
+                 output_manual_ensemble_filename = gr.Textbox(label="Имя выходного файла", value="ensemble", interactive=True)
+
+                 make_manual_ensemble_btn = gr.Button(value="Создать ансамбль", variant="primary")
+
+                 manual_ensemble_output_audio = gr.Audio(label="Результат", type="filepath", interactive=False, show_download_button=True)
+
+         @make_manual_ensemble_btn.click(
+             inputs=[input_ensemble_files, method, output_format, output_base_dir, output_temp_dir_check, output_manual_ensemble_filename, weights], outputs=manual_ensemble_output_audio
+         )
+         def make_manual_ensemble_fn(input_files_list, method, out_format, o_dir, temp_save, o_filename, weights: str):
+             if o_dir.strip() != "" and not temp_save:
+                 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+                 o = os.path.join(o_dir, f"ensembless_outputs_{timestamp}")
+                 os.makedirs(o, exist_ok=True)
+             else:
+                 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+                 o = tempfile.mkdtemp(prefix=f"ensembless_outputs_{timestamp}_")
+                 os.makedirs(o, exist_ok=True)
+
+             o_filename = self.namer.sanitize(o_filename)
+             o_filename = self.namer.short(o_filename)
+
+             output_file = ensemble_audio_files(files=input_files_list, output=os.path.join(o, o_filename), weights=[float(x) for x in weights.split(",")], ensemble_type=method, out_format=out_format)
+             return output_file
+
+         add_ensemble_path_btn.click(
+             lambda: (gr.update(visible=True), gr.update(visible=False)),
+             outputs=[add_ensemble_inputs, add_ensemble_path_btn])
+
+         @add_ensemble_inputs_btn.click(
+             inputs=[input_ensemble_path, input_ensemble_files],
+             outputs=[add_ensemble_inputs, input_ensemble_files, add_ensemble_path_btn])
+         def add_ensemble_inputs_fn(input_p, input_a):
+             if input_p and os.path.exists(input_p):
+                 if input_a is None:
+                     input_a = []
+                 if isinstance(input_a, str):
+                     input_a = [input_a]
+                 if self.audio.check(input_p):
+                     input_a.append(input_p)
+                     return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+             return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+
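+ # Inverter_UI subtracts a stem from the original mix (waveform or
+ # spectrogram subtraction) and returns the difference.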
+ class Inverter_UI(MVSEPLESS):
+     def UI(self):
+         with gr.Group():
+             with gr.Row():
+                 original_audio = gr.File(label="Оригинал", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.audio.input_formats])
+                 stem_audio = gr.File(label="Стем, который будет вычтен из оригинала", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.audio.input_formats])
+             with gr.Group():
+                 output_format = gr.Dropdown(label="Формат выходного файла", interactive=True,
+                                             choices=self.audio.output_formats,
+                                             value="mp3", filterable=False)
+                 method = gr.Radio(
+                     choices=["waveform", "spectrogram"],
+                     label="Метод вычитания",
+                     value="waveform",
+                 )
+             btn = gr.Button("Вычесть")
+             output_audio = gr.Audio(label="Инверсия", type="filepath", interactive=False, show_download_button=True)
+         @btn.click(inputs=[original_audio, stem_audio, method, output_format], outputs=[output_audio])
+         def invert_result_ensemble(input_file, output_file, method, out_format):
+             if input_file and output_file:
+                 o_dir = tempfile.mkdtemp(suffix="_inverter")
+                 basename = os.path.splitext(os.path.basename(input_file))[0]
+                 output_path = os.path.join(o_dir, f"inverter_{self.namer.short(basename, length=50)}_{method}.{out_format}")
+                 inverted = self.inverter.process_audio(audio1_path=input_file, audio2_path=output_file, out_format=out_format, method=method, output_path=output_path)
+                 return inverted
+             else:
+                 return None
+
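+ # Vbach is the voice-conversion tab (pitch shift, f0 extraction method,
+ # index rate, consonant protection); the class attributes are the allowed
+ # (min, max) ranges for the inference sliders.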
+ class Vbach(MVSEPLESS):
+     pitch_methods = ("rmvpe+", "fcpe", "mangio-crepe")
+     hop_length_values = (8, 512)
+     index_rates_values = (0, 1)
+     filter_radius_values = (0, 7)
+     protect_values = (0, 0.5)
+     rms_values = (0, 1)
+     f0_min_values = (50, 3000)
+     f0_max_values = (300, 6000)
+
+     def UI(self):
+         with gr.Tab("Инференс"):
+             with gr.Row():
+                 with gr.Column():
+                     with gr.Group():
+                         input_audio = gr.File(label="Входные аудио", interactive=True, type="filepath", file_count="multiple", file_types=[f".{of}" for of in self.audio.input_formats])
+                         converted_state = gr.Textbox(label="Состояние разделения", interactive=False, value="", visible=False)
+                         status = gr.Textbox(container=False, lines=3, interactive=False, max_lines=3)
+                     input_preview_check = gr.Checkbox(label="Показать плееры для входных аудио", interactive=True, value=False)
+                     @gr.render(inputs=[input_preview_check, input_audio])
+                     def show_input_players(preview, audios):
+                         if preview:
+                             if audios:
+                                 with gr.Group():
+                                     for file in audios:
+                                         gr.Audio(label=os.path.splitext(os.path.basename(file))[0], value=file, interactive=False, show_download_button=False, type="filepath")
+
+                 with gr.Column():
+                     with gr.Group():
+                         model_name = gr.Dropdown(label="Имя модели", interactive=True)
+                         model_list_refresh_btn = gr.Button("Обновить", variant="secondary", interactive=True)
+                         @model_list_refresh_btn.click(
+                             outputs=[model_name]
+                         )
+                         def refresh_list_voice_models():
+                             models = []
+                             models = self.vbach_model_manager.parse_voice_models()
+                             first_model = None
+                             if len(models) > 0:
+                                 first_model = models[0]
+                             return gr.update(choices=models, value=first_model)
+                     with gr.Group():
+                         pitch_method = gr.Radio(label="Метод извлечения высоты тона", choices=self.pitch_methods, value=self.pitch_methods[0], interactive=True)
+                         pitch = gr.Slider(label="Высота тона", minimum=-48, maximum=48, step=0.5, value=0, interactive=True)
+                         hop_length = gr.Slider(label="Длина шага", info="Длина шага влияет на точность передачи высоты тона\nЧем меньше длина шага - тем точнее будет передана высота тона", minimum=self.hop_length_values[0], maximum=self.hop_length_values[1], step=8, value=128, interactive=True, visible=False)
+                         @pitch_method.change(
+                             inputs=[pitch_method],
+                             outputs=[hop_length]
+                         )
+                         def show_mangio_crepe_hop_length(pitch_method):
+                             return gr.update(visible=True if pitch_method in ["mangio-crepe"] else False)
+                         stereo_mode = gr.Radio(
+                             choices=["mono", "left/right", "sim/dif"],
+                             label="Стерео режим",
+                             info="mono - монофоническая обработка аудио, \nleft/right - обработка левого и правого каналов отдельно, \nsim/dif - обработка фантомного центра и стерео-базы, разделенной на левый и правый каналы",
+                             value="mono",
+                             interactive=True
+                         )
+                     with gr.Accordion(label="Дополнительные настройки", open=False):
+                         with gr.Group():
+                             with gr.Row():
+                                 index_rate = gr.Slider(label="Влияние индекса", info="Чем ниже значение, тем больше голос похож на исходный; чем выше, тем ближе к модели", minimum=self.index_rates_values[0], maximum=self.index_rates_values[1], step=0.05, value=0, interactive=True)
+                                 filter_radius = gr.Slider(label="Радиус фильтра", info="Сглаживает результаты извлечения тона\nМожет снизить дыхание и шумы на выходе", minimum=self.filter_radius_values[0], maximum=self.filter_radius_values[1], step=1, value=3, interactive=True)
+                             with gr.Row():
+                                 rms = gr.Slider(label="Соотношение огибающих громкости", info="Значение 0 - огибающая громкости как у входного аудио, 1 - как у выходного сигнала", minimum=self.rms_values[0], maximum=self.rms_values[1], step=0.05, value=0.25, interactive=True)
+                                 protect = gr.Slider(label="Защита согласных", info="Предотвращает роботизацию дыхания и согласных (Может влиять на четкость речи)\nЗначение 0.5 - выключает защиту, 0 - максимальная защита", minimum=self.protect_values[0], maximum=self.protect_values[1], step=0.05, value=0.35, interactive=True)
+                         with gr.Group():
+                             with gr.Row():
+                                 f0_min = gr.Slider(label="Нижний предел диапазона определения высоты тона", minimum=self.f0_min_values[0], maximum=self.f0_min_values[1], step=10, value=50, interactive=True)
+                                 f0_max = gr.Slider(label="Верхний предел диапазона определения высоты тона", minimum=self.f0_max_values[0], maximum=self.f0_max_values[1], step=10, value=1100, interactive=True)
+
+                     with gr.Group():
+                         output_name = gr.Textbox(label="Имя выходного файла", interactive=True, value="NAME - MODEL - F0METHOD - PITCH")
+                         format_output_name_check = gr.Checkbox(label="Форматировать имя", info="Используйте ключи: \nNAME - имя входного файла без расширения, \nPITCH - высота тона, \nF0METHOD - метод извлечения высоты тона, \nMODEL - имя голосовой модели", value=True, interactive=True)
+                         output_format = gr.Dropdown(label="Формат выходного файла", interactive=True, choices=self.audio.output_formats, value=self.audio.output_formats[0], filterable=False)
1279
+ convert_btn = gr.Button("Преобразовать", variant="primary", interactive=True)
1280
+
1281
+
1282
+ @convert_btn.click(
1283
+ inputs=[
1284
+ input_audio,
1285
+ model_name,
1286
+ pitch_method,
1287
+ pitch,
1288
+ hop_length,
1289
+ index_rate,
1290
+ filter_radius,
1291
+ rms,
1292
+ protect,
1293
+ f0_min,
1294
+ f0_max,
1295
+ output_name,
1296
+ format_output_name_check,
1297
+ output_format,
1298
+ stereo_mode
1299
+ ], outputs=[converted_state, status]
1300
+ )
1301
+ def vbach_convert_batch(ifl, mn, pm, p, hl, ir, fr, rms, pr, f0min, f0max, on, fn, of, sm):
1302
+ output_converted_files = []
1303
+ progress = gr.Progress()
1304
+ if ifl:
1305
+ for i, file in enumerate(ifl, start=1):
1306
+ try:
1307
+ print(f"Файл {i} из {len(ifl)}: {file}")
1308
+ progress(progress=(i / len(ifl)), desc=f"Файл {i} из {len(ifl)}")
1309
+ out_conv = vbach_inference(input_file=file, model_name=mn, output_dir=tempfile.mkdtemp(), output_name=on, format_name=True if len(ifl) > 1 else fn, output_format=of, pitch=p, method_pitch=pm, output_bitrate=320, add_params={ "index_rate": ir,"filter_radius": fr,"protect": pr,"rms": rms,"mangio_crepe_hop_length": hl,"f0_min": f0min,"f0_max": f0max,"stereo_mode": sm })
1310
+ output_converted_files.append(out_conv)
1311
+ except Exception as e:
1312
+ print(e)
1313
+ return str(output_converted_files), None
1314
+
1315
+ @gr.render(inputs=[converted_state])
1316
+ def show_players_converted(state):
1317
+ if state != "":
1318
+ output_converted_files = ast.literal_eval(state)
1319
+ if output_converted_files:
1320
+ with gr.Group():
1321
+ for conv_file in output_converted_files:
1322
+ basename = os.path.splitext(os.path.basename(conv_file))[0]
1323
+ gr.Audio(
1324
+ label=basename,
1325
+ value=conv_file,
1326
+ type="filepath",
1327
+ interactive=False,
1328
+ show_download_button=True,
1329
+ )
1330
+
1331
+ with gr.TabItem("Менеджер"):
1332
+ with gr.TabItem("Загрузить по ссылке"):
1333
+ with gr.TabItem("Через zip файл"):
1334
+ with gr.Row():
1335
+ with gr.Column(variant="panel"):
1336
+ url_zip = gr.Text(label="Ссылка на zip файл")
1337
+ with gr.Group():
1338
+ url_zip_model_name = gr.Text(
1339
+ label="Имя модели",
1340
+ )
1341
+ url_zip_download_btn = gr.Button("Загрузить", variant="primary")
1342
+
1343
+ url_zip_output = gr.Text(label="Статус", interactive=False, lines=5)
1344
+ url_zip_download_btn.click(
1345
+ (lambda x, y: self.vbach_model_manager.install_model_zip(x, self.namer.short(self.namer.sanitize(y), length=40), "url")),
1346
+ inputs=[url_zip, url_zip_model_name],
1347
+ outputs=url_zip_output,
1348
+ )
1349
+
1350
+ with gr.TabItem("Через отдельные файлы"):
1351
+ with gr.Row():
1352
+ with gr.Column(variant="panel"):
1353
+ url_pth = gr.Text(label="Ссылка на *.pth файл")
1354
+ url_index = gr.Text(label="Ссылка на *.index файл (необязательно)")
1355
+ with gr.Group():
1356
+ url_file_model_name = gr.Text(
1357
+ label="Имя модели",
1358
+ )
1359
+ url_file_download_btn = gr.Button("Загрузить", variant="primary")
1360
+
1361
+ url_file_output = gr.Text(label="Статус", interactive=False, lines=5)
1362
+ url_file_download_btn.click(
1363
+ (lambda x, y, z: self.vbach_model_manager.install_model_files(x, y, self.namer.short(self.namer.sanitize(z), length=40), "url")),
1364
+ inputs=[url_index, url_pth, url_file_model_name],
1365
+ outputs=url_file_output,
1366
+ )
1367
+
1368
+ with gr.Tab("Загрузить с устройства"):
1369
+ with gr.Tab("Через zip файл"):
1370
+ with gr.Row():
1371
+ with gr.Column():
1372
+ local_zip = gr.File(
1373
+ label="zip файл", file_types=[".zip"], file_count="single"
1374
+ )
1375
+ with gr.Column(variant="panel"):
1376
+ with gr.Group():
1377
+ local_zip_model_name = gr.Text(
1378
+ label="Имя модели",
1379
+ )
1380
+ local_zip_upload_btn = gr.Button("Загрузить", variant="primary")
1381
+
1382
+ local_zip_output = gr.Text(label="Статус", interactive=False, lines=5)
1383
+ local_zip_upload_btn.click(
1384
+ (lambda x, y: self.vbach_model_manager.install_model_zip(x, self.namer.short(self.namer.sanitize(y), length=40), "local")),
1385
+ inputs=[local_zip, local_zip_model_name],
1386
+ outputs=local_zip_output,
1387
+ )
1388
+
1389
+ with gr.TabItem("Через отдельные файлы"):
1390
+ with gr.Group():
1391
+ with gr.Row():
1392
+ local_pth = gr.File(
1393
+ label="*.pth файл", file_types=[".pth"], file_count="single"
1394
+ )
1395
+ local_index = gr.File(
1396
+ label="*.index файл (необязательно)", file_types=[".index"], file_count="single"
1397
+ )
1398
+ with gr.Column(variant="panel"):
1399
+ with gr.Group():
1400
+ local_file_model_name = gr.Text(
1401
+ label="Имя модели",
1402
+ )
1403
+ local_file_upload_btn = gr.Button("Загрузить", variant="primary")
1404
+
1405
+ local_file_output = gr.Text(
1406
+ label="Статус", interactive=False
1407
+ )
1408
+ local_file_upload_btn.click(
1409
+ (lambda x, y, z: self.vbach_model_manager.install_model_files(x, y, self.namer.short(self.namer.sanitize(z), length=40), "local")),
1410
+ inputs=[local_index, local_pth, local_file_model_name],
1411
+ outputs=local_file_output,
1412
+ )
1413
+
1414
+ with gr.TabItem("Удалить модель"):
1415
+ with gr.Column(variant="panel"):
1416
+ with gr.Group():
1417
+ delete_model_name = gr.Dropdown(
1418
+ label="Имя модели",
1419
+ choices=self.vbach_model_manager.parse_voice_models(),
1420
+ interactive=True,
1421
+ filterable=False
1422
+ )
1423
+ delete_refresh_btn = gr.Button("Обновить")
1424
+ @delete_refresh_btn.click(inputs=None, outputs=delete_model_name)
1425
+ def refresh_list_voice_models():
1426
+ models = []
1427
+ models = self.vbach_model_manager.parse_voice_models()
1428
+ first_model = None
1429
+ if len(models) > 0:
1430
+ first_model = models[0]
1431
+ return gr.update(choices=models, value=first_model)
1432
+
1433
+ delete_output = gr.Text(
1434
+ label="Статус", interactive=False, lines=5
1435
+ )
1436
+ delete_btn = gr.Button("Удалить")
1437
+ delete_btn.click(
1438
+ fn=self.vbach_model_manager.del_voice_model,
1439
+ inputs=delete_model_name,
1440
+ outputs=delete_output
1441
+ )
1442
+
1443
+ @gr.on(fn="decorator", inputs=None, outputs=[delete_model_name, model_name])
1444
+ def refresh_list_voice_models():
1445
+ models = []
1446
+ models = self.vbach_model_manager.parse_voice_models()
1447
+ first_model = None
1448
+ if len(models) > 0:
1449
+ first_model = models[0]
1450
+ return gr.update(choices=models, value=first_model), gr.update(choices=models, value=first_model)
1451
+
1452
+
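For reference, the batch callback above reduces to one `vbach_inference` call per file. A minimal sketch of invoking it directly, assuming `vbach_inference` is available exactly as used in the callback (the model name and paths here are placeholders):

```python
out_path = vbach_inference(
    input_file="vocals.wav",
    model_name="MyVoice",            # any model known to the voice-model manager
    output_dir="converted",
    output_name="NAME - MODEL - F0METHOD - PITCH",
    format_name=True,
    output_format="mp3",
    pitch=0,
    method_pitch="rmvpe+",
    output_bitrate=320,
    add_params={"index_rate": 0, "filter_radius": 3, "protect": 0.35,
                "rms": 0.25, "mangio_crepe_hop_length": 128,
                "f0_min": 50, "f0_max": 1100, "stereo_mode": "mono"},
)
```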
+class PluginManager(Separator):
+    plugins_dir = os.path.join(script_dir, "plugins")
+    os.makedirs(plugins_dir, exist_ok=True)
+
+    def restart_after_install_plugin(self):
+        subprocess.Popen([sys.executable] + sys.argv)
+        os._exit(0)
+
+    def parse_plugins(self):
+        for plugin_file in os.listdir(self.plugins_dir):
+            # Skip non-Python files and __init__.py
+            if not plugin_file.endswith('.py') or plugin_file == '__init__.py':
+                continue
+
+            # Module name without the extension
+            plugin_module_name = os.path.splitext(plugin_file)[0]
+
+            try:
+                # Pick the import path depending on how the project is laid out
+                if __package__:
+                    # Inside a package: use a package-relative import
+                    plugin_module = importlib.import_module(f".plugins.{plugin_module_name}", package=__package__)
+                else:
+                    # Not in a package: try the alternatives
+                    try:
+                        # First try an absolute import
+                        plugin_module = importlib.import_module(f"plugins.{plugin_module_name}")
+                    except ImportError:
+                        # Fall back to loading the module straight from the file
+                        plugin_path = os.path.join(self.plugins_dir, plugin_file)
+                        spec = importlib.util.spec_from_file_location(plugin_module_name, plugin_path)
+                        plugin_module = importlib.util.module_from_spec(spec)
+                        spec.loader.exec_module(plugin_module)
+
+                # Grab the Plugin class from the module
+                plugin_class = getattr(plugin_module, 'Plugin')
+
+                # Instantiate the plugin
+                plugin_instance = plugin_class()
+
+                # Build the plugin's UI inside its own tab
+                with gr.Tab(plugin_instance.name):
+                    plugin_instance.UI()
+
+            except Exception as e:
+                print(f"Ошибка загрузки плагина {plugin_module_name}: {e}")
+                continue
+
+    def UI(self):
+        with gr.Tab("Установка"):
+            upload_plugins_files = gr.File(label="Загрузить плагины", file_types=[".py"], file_count="multiple", interactive=True)
+            install_plugins_btn = gr.Button("Установить", interactive=True)
+
+            @install_plugins_btn.click(
+                inputs=[upload_plugins_files]
+            )
+            def upload_plugin_list(files):
+                if not files:
+                    return
+                for file in files:
+                    # gr.File may hand back plain path strings or tempfile wrappers
+                    path = file if isinstance(file, str) else file.name
+                    try:
+                        # Copy only .py files
+                        if path.endswith('.py'):
+                            shutil.copy(
+                                path, os.path.join(self.plugins_dir, os.path.basename(path).replace(" ", "_"))
+                            )
+                    except Exception as e:
+                        print(f"Ошибка копирования файла {path}: {e}")
+                time.sleep(2)
+                self.restart_after_install_plugin()
+
+        self.parse_plugins()
+
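The loader above only requires that each file in `plugins/` export a `Plugin` class with a `name` attribute and a `UI()` method. A minimal sketch of such a plugin file (the class name and attribute come from `parse_plugins`; the body is illustrative):

```python
# plugins/hello_plugin.py -- hypothetical example plugin
import gradio as gr

class Plugin:
    name = "Hello"  # tab title used by PluginManager.parse_plugins

    def UI(self):
        # Components created here are rendered inside the plugin's tab
        text = gr.Textbox(label="Name")
        out = gr.Textbox(label="Greeting")
        btn = gr.Button("Greet")
        btn.click(lambda n: f"Hello, {n}!", inputs=text, outputs=out)
```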
+def mvsepless_app(theme):
+    css = None
+    with gr.Blocks(theme=theme, css=css, title="Разделение музыки и вокала") as app:
+
+        output_base_dir_state = gr.State(value=os.path.join(os.getcwd(), "outputs"))
+        output_temp_dir_state = gr.State(value=False)
+
+        with gr.Tab("Инференс"):
+            Separator().UI(output_base_dir_state, output_temp_dir_state)
+
+        with gr.Tab("Ансамбль"):
+            with gr.Tab("Авто-ансамбль"):
+                AutoEnsembless().UI(output_base_dir_state, output_temp_dir_state)
+
+            with gr.Tab("Ручной ансамбль"):
+                ManualEnsembless().UI(output_base_dir_state, output_temp_dir_state)
+
+        with gr.Tab("Вычитание"):
+            Inverter_UI().UI()
+
+        with gr.Tab("Преобразование"):
+            Vbach().UI()
+
+        with gr.Tab("Плагины"):
+            PluginManager().UI()
+
+        with gr.Tab("Настройки"):
+            with gr.Column():
+                output_base_dir_ui = gr.Textbox(
+                    label="Базовый каталог для выходных файлов",
+                    interactive=True,
+                    value=os.path.join(os.getcwd(), "outputs"),
+                    lines=5
+                )
+                output_temp_dir_check_ui = gr.Checkbox(
+                    label="Использовать временный каталог для выходных файлов",
+                    interactive=True,
+                    value=False
+                )
+
+            # Mirror the settings widgets into the shared state objects
+            output_base_dir_ui.change(
+                lambda x: x,
+                inputs=[output_base_dir_ui],
+                outputs=[output_base_dir_state]
+            )
+            output_temp_dir_check_ui.change(
+                lambda x: x,
+                inputs=[output_temp_dir_check_ui],
+                outputs=[output_temp_dir_state]
+            )
+
+    return app
+
mvsepless/__main__.py ADDED
@@ -0,0 +1,67 @@
+import argparse
+import gradio as gr
+import os
+
+if not __package__:
+    from __init__ import mvsepless_app, Separator
+else:
+    from .__init__ import mvsepless_app, Separator
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="MVSepless")
+    subparsers = parser.add_subparsers(dest="command")
+    app_parser = subparsers.add_parser("app", help="Приложение MVSepless")
+    app_parser.add_argument("--port", type=int, default=None, help="Порт для запуска сервера Gradio.")
+    app_parser.add_argument("--share", action="store_true", help="Создать публичную ссылку для приложения Gradio.")
+    cli_parser = subparsers.add_parser("cli", help="CLI MVSepless Lite")
+    cli_parser.add_argument("--input", type=str, required=True, help="Входной аудиофайл или каталог.")
+    cli_parser.add_argument("--output_dir", type=str, default=None, help="Каталог для выходных файлов.")
+    cli_parser.add_argument("--model_type", type=str, default="mel_band_roformer", help="Тип модели разделения.")
+    cli_parser.add_argument("--model_name", type=str, default="Mel-Band-Roformer_Vocals_kimberley_jensen", help="Имя модели разделения.")
+    cli_parser.add_argument("--ext_inst", action="store_true", help="Извлечь инструментал.")
+    cli_parser.add_argument("--output_format", type=str, default="mp3", choices=Separator.audio.output_formats, help="Формат выходного файла.")
+    cli_parser.add_argument("--output_bitrate", type=str, default="320k", help="Битрейт выходного файла.")
+    cli_parser.add_argument("--template", type=str, default="NAME (STEM) MODEL", help="Шаблон именования выходных файлов.")
+    cli_parser.add_argument("--selected_stems", type=str, nargs='*', default=None, help="Выбранные стемы для разделения.")
+    args = parser.parse_args()
+
+    if args.command == "app":
+        theme = gr.themes.Citrus(
+            primary_hue="teal",
+            secondary_hue="blue",
+            neutral_hue="blue",
+            spacing_size="sm",
+            font=[
+                gr.themes.GoogleFont("Montserrat"),
+                "ui-sans-serif",
+                "system-ui",
+                "sans-serif",
+            ],
+        )
+        mvsepless_lite_app = mvsepless_app(theme=theme)
+        mvsepless_lite_app.launch(server_name="0.0.0.0", server_port=args.port, share=args.share, allowed_paths=["/"], debug=True)
+    elif args.command == "cli":
+        input_data = args.input
+        if os.path.isdir(input_data):
+            list_valid_files = []
+            for file in os.listdir(args.input):
+                if os.path.isfile(os.path.join(args.input, file)):
+                    if Separator.audio.check(os.path.join(args.input, file)):
+                        list_valid_files.append(os.path.join(args.input, file))
+
+            input_files = list_valid_files
+        else:
+            input_files = input_data
+
+        results = Separator().separate(
+            input=input_files,
+            output_dir=args.output_dir,
+            model_type=args.model_type,
+            model_name=args.model_name,
+            ext_inst=args.ext_inst,
+            output_format=args.output_format,
+            output_bitrate=args.output_bitrate,
+            template=args.template,
+            selected_stems=args.selected_stems
+        )
+        print("Разделение завершено.")
mvsepless/audio.py ADDED
@@ -0,0 +1,781 @@
+import os
+import sys
+import json
+import subprocess
+import tempfile
+from pathlib import Path
+from typing import Literal
+from collections.abc import Callable
+
+import numpy as np
+from numpy.typing import DTypeLike
+import librosa
+
+if not __package__:
+    from namer import Namer
+else:
+    from .namer import Namer
+
+
+class NotInputFileSpecified(Exception): pass
+class NotOutputFileSpecified(Exception): pass
+class NotSupportedDataType(Exception): pass
+class ErrorDecode(Exception): pass
+class ErrorEncode(Exception): pass
+class NotSupportedFormat(Exception): pass
+class SampleRateError(Exception): pass
+class FileIsNotAudio(Exception): pass
+
+
+class Audio(Namer):
+    def __init__(self):
+        """
+        Reads and writes audio files directly through ffmpeg.
+
+        Supported data types: int16, int32, float32, float64
+        """
+        super().__init__()
+        self.ffmpeg_path = os.environ.get("MVSEPLESS_FFMPEG", "ffmpeg")
+        self.ffprobe_path = os.environ.get("MVSEPLESS_FFPROBE", "ffprobe")
+        self.output_formats = ("mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff")
+        self.input_formats = ("mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff", "mp4", "mkv", "webm", "avi", "mov", "ts")
+        self.supported_dtypes = ("int16", "int32", "float32", "float64")
+        self.dtypes_dict = {
+            "int16": "s16le",
+            "int32": "s32le",
+            "float32": "f32le",
+            "float64": "f64le",
+            np.int16: "s16le",
+            np.int32: "s32le",
+            np.float32: "f32le",
+            np.float64: "f64le",
+        }
+        self.bitrate_limit = {
+            "mp3": {"min": 8, "max": 320},
+            "aac": {"min": 8, "max": 512},
+            "m4a": {"min": 8, "max": 512},
+            "ac3": {"min": 32, "max": 640},
+            "ogg": {"min": 64, "max": 500},
+            "opus": {"min": 6, "max": 512},
+        }
+        self.sample_rates = {
+            "mp3": {"supported": (44100, 48000, 32000, 22050, 24000, 16000, 11025, 12000, 8000)},
+            "opus": {"supported": (48000, 24000, 16000, 12000, 8000)},
+            "m4a": {"supported": (96000, 88200, 64000, 48000, 44100, 32000, 24000, 22050, 16000, 12000, 11025, 8000, 7350)},
+            "aac": {"supported": (96000, 88200, 64000, 48000, 44100, 32000, 24000, 22050, 16000, 12000, 11025, 8000, 7350)},
+            "ac3": {"supported": (48000, 44100, 32000)},
+            "ogg": {"min": 6, "max": 192000},
+            "wav": {"min": 0, "max": float("inf")},
+            "aiff": {"min": 0, "max": float("inf")},
+            "flac": {"min": 0, "max": 192000},
+        }
+        self.check_ffmpeg()
+        self.check_ffprobe()
+
+    def check_ffmpeg(self):
+        """Checks that ffmpeg is installed."""
+        try:
+            subprocess.check_output([self.ffmpeg_path, "-version"], text=True)
+        except FileNotFoundError:
+            if "PYTEST_CURRENT_TEST" not in os.environ:
+                raise FileNotFoundError("FFMPEG не установлен. Укажите путь к установленному FFMPEG через переменную окружения MVSEPLESS_FFMPEG")
+
+    def check_ffprobe(self):
+        """Checks that ffprobe is installed."""
+        try:
+            subprocess.check_output([self.ffprobe_path, "-version"], text=True)
+        except FileNotFoundError:
+            if "PYTEST_CURRENT_TEST" not in os.environ:
+                raise FileNotFoundError("FFPROBE не установлен. Укажите путь к установленному FFPROBE через переменную окружения MVSEPLESS_FFPROBE")
+
+    def fit_sr(
+        self,
+        f: str | Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] = "mp3",
+        sr: int = 44100
+    ) -> int:
+        """
+        Clamps the output sample rate to a value the target format supports.
+
+        Parameters:
+            f: output format
+            sr: sample rate (integer)
+        Returns:
+            sr: corrected sample rate
+        """
+        format_info = self.sample_rates.get(f.lower())
+
+        if not format_info:
+            return None  # unknown format
+
+        if "supported" in format_info:
+            # Formats with an explicit list of supported rates
+            supported_rates = format_info["supported"]
+            if sr in supported_rates:
+                return sr
+
+            # Otherwise pick the closest supported rate
+            return min(supported_rates, key=lambda x: abs(x - sr))
+
+        elif "min" in format_info and "max" in format_info:
+            # Formats with a range: clip to the bounds
+            return min(max(sr, format_info["min"]), format_info["max"])
+
+        return None
+
+    def fit_br(
+        self,
+        f: str | Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] = "mp3",
+        br: int = 320
+    ) -> int:
+        """
+        Clamps the output bitrate to the limits of the target format.
+
+        Parameters:
+            f: output format
+            br: bitrate in kbit/s (integer)
+        Returns:
+            br: corrected bitrate
+        """
+        if f not in self.bitrate_limit:
+            raise NotSupportedFormat(f"Формат {f} не поддерживается")
+
+        limits = self.bitrate_limit[f]
+        return min(max(br, limits["min"]), limits["max"])
+
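A quick illustration of the clamping behaviour, using the tables defined in `__init__` above (return values follow directly from the code):

```python
a = Audio()
a.fit_sr(f="opus", sr=44100)  # -> 48000, the closest rate Opus supports
a.fit_sr(f="mp3", sr=44100)   # -> 44100, already in the supported list
a.fit_br(f="mp3", br=1000)    # -> 320, clipped to the MP3 maximum
a.fit_br(f="opus", br=2)      # -> 6, clipped to the Opus minimum
```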
+    def get_info(
+        self,
+        i: str | os.PathLike | Callable | None = None,
+    ) -> dict[int, dict[str, int | float]]:
+        """
+        Queries ffprobe for information about the audio streams of a file.
+
+        Parameters:
+            i: path to the input file
+        Returns:
+            audio_info: dictionary describing the audio streams:
+
+                {stream index: {
+                    "sample_rate": sample rate (int),
+                    "duration": stream duration in seconds (float),
+                }}
+        """
+        audio_info = {}
+        if not i:
+            raise NotInputFileSpecified("Не указан путь к файлу")
+        if isinstance(i, Path):
+            i = str(i)
+        if not os.path.exists(i):
+            raise FileExistsError("Указанного файла не существует")
+
+        cmd = [self.ffprobe_path, "-i", i, "-v", "quiet", "-hide_banner",
+               "-show_entries", "stream=index,sample_rate,duration", "-select_streams", "a", "-of", "json"]
+
+        process = subprocess.Popen(
+            cmd,
+            stdin=subprocess.PIPE,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+        )
+
+        stdout, stderr = process.communicate()
+
+        if process.returncode != 0:
+            print(f"STDERR: {stderr.decode('utf-8')}")
+            print(f"STDOUT: {stdout.decode('utf-8')}")
+
+        json_output = json.loads(stdout)
+        for a, stream in enumerate(json_output["streams"]):
+            audio_info[a] = {
+                "sample_rate": int(stream.get("sample_rate", 0)),
+                "duration": float(stream.get("duration", 0)),
+            }
+
+        return audio_info
+
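For a typical stereo music file the returned structure looks like this (values are illustrative):

```python
info = Audio().get_info(i="song.mp3")
# -> {0: {"sample_rate": 44100, "duration": 187.4}}
# check() below only verifies that stream 0 reports sample_rate > 0
```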
+    def check(
+        self,
+        i: str | os.PathLike | Callable | None = None
+    ) -> bool:
+        """
+        Checks whether a file is an audio/video file that ffmpeg can read.
+
+        Parameters:
+            i: path to the input file
+        Returns:
+            bool: True if the file contains at least one decodable audio stream
+        """
+        if not i:
+            raise NotInputFileSpecified("Не указан путь к файлу")
+        if isinstance(i, Path):
+            i = str(i)
+        if not os.path.exists(i):
+            raise FileExistsError("Указанного файла не существует")
+
+        info = self.get_info(i=i)
+        if not info:
+            return False
+        return info[0].get("sample_rate", 0) > 0
+
+    def read(
+        self,
+        i: str | os.PathLike | Callable | None = None,
+        sr: int | None = None,
+        mono: bool = False,
+        dtype: DTypeLike = np.float32,
+        s: int = 0
+    ) -> tuple[np.ndarray, int, float]:
+        """
+        Reads an audio file into a numpy array directly through ffmpeg.
+        Serves as a replacement for soundfile.read() and librosa.load().
+
+        Parameters:
+            i: path to the input file
+            sr: target sample rate (defaults to the file's own rate)
+            mono: downmix to mono (off by default)
+            dtype: data type (int16, int32, float32, float64; float32 by default)
+            s: audio stream index (0 by default)
+        Returns:
+            audio_array: array with the audio data
+            sr: sample rate of the array
+            duration: audio duration (number of samples / sample rate)
+        """
+        output_format = self.dtypes_dict.get(dtype, None)
+        if not output_format:
+            raise NotSupportedDataType(f"Этот тип данных не поддерживается {dtype}")
+        if not i:
+            raise NotInputFileSpecified("Не указан путь к файлу")
+        if isinstance(i, Path):
+            i = str(i)
+        if not os.path.exists(i):
+            raise FileExistsError("Указанного файла не существует")
+
+        audio_info = self.get_info(i=i)
+        if audio_info.get(s, False):
+            stream = s
+        elif audio_info:
+            stream = 0
+        else:
+            raise FileIsNotAudio("Во входном файле нет аудио потоков")
+
+        sample_rate_input = audio_info[stream]["sample_rate"]
+        if sample_rate_input == 0:
+            raise FileIsNotAudio("Во входном файле нет аудио потоков")
+
+        cmd = [
+            self.ffmpeg_path,
+            "-i", i,
+            "-map", f"0:a:{stream}", "-vn",
+            "-f", output_format,
+            "-ac", "1" if mono else "2",
+        ]
+
+        if sr:
+            cmd.extend(["-ar", str(sr)])
+        else:
+            sr = sample_rate_input
+
+        cmd.append("pipe:1")
+
+        process = subprocess.Popen(
+            cmd,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            bufsize=10**8
+        )
+
+        try:
+            raw_audio, stderr = process.communicate(timeout=300)
+
+            if process.returncode != 0:
+                raise ErrorDecode(f"FFmpeg error: {stderr.decode()}")
+
+        except subprocess.TimeoutExpired:
+            process.kill()
+            raise ErrorDecode("FFmpeg timeout при чтении файла")
+
+        audio_array = np.frombuffer(raw_audio, dtype=dtype)
+
+        channels = 1 if mono else 2
+        audio_array = audio_array.reshape((-1, channels)).T
+        if audio_array.ndim > 1 and channels == 1:
+            audio_array = np.mean(audio_array, axis=tuple(range(audio_array.ndim - 1)))
+
+        duration = float(audio_array.shape[-1]) / sr
+
+        print(f"Частота дискретизации: {sr}")
+
+        return audio_array, sr, duration
+
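The shape conventions of `read()` at a glance (file name is a placeholder):

```python
a = Audio()
wav, sr, dur = a.read(i="song.flac")              # stereo: wav.shape == (2, n_samples)
mono, sr, dur = a.read(i="song.flac", mono=True)  # mono: mono.shape == (n_samples,)
```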
+    def write(
+        self,
+        o: str | os.PathLike | Callable | None = None,
+        array: np.ndarray = np.array([], dtype=np.float32),
+        sr: int = 44100,
+        of: str | Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] | None = None,
+        br: str | int | None = None
+    ) -> str:
+        """
+        Writes a numpy array with audio data to a file directly through ffmpeg.
+        Serves as a replacement for soundfile.write().
+
+        Parameters:
+            o: path to the output file
+            array: array with audio data (int16, int32, float32, float64)
+            sr: sample rate of the array
+            of: output format (mp3 by default)
+            br: bitrate for lossy codecs
+        Returns:
+            o: path to the output file
+        """
+        if not isinstance(array, np.ndarray):
+            raise ValueError("Вход должен быть numpy-массивом")
+
+        if len(array.shape) == 1:
+            array = array.reshape(-1, 1)
+        elif len(array.shape) == 2:
+            if array.shape[0] == 2:
+                array = array.T  # accept (2, n) as well as (n, 2)
+        else:
+            raise ValueError("numpy-массив должен быть либо одномерным, либо двухмерным")
+
+        if array.dtype == np.int16:
+            input_format = "s16le"
+        elif array.dtype == np.int32:
+            input_format = "s32le"
+        elif array.dtype == np.float32:
+            input_format = "f32le"
+        elif array.dtype == np.float64:
+            input_format = "f64le"
+        else:
+            raise NotSupportedDataType(f"Этот тип данных не поддерживается {array.dtype}")
+
+        if array.shape[1] in (1, 2):
+            audio_bytes = array.tobytes()
+            channels = array.shape[1]
+        else:
+            raise ValueError("numpy-массив должен содержать 1 или 2 канала")
+
+        if not o:
+            raise NotOutputFileSpecified("Не указан путь к выходному файлу")
+        if isinstance(o, Path):
+            o = str(o)
+        output_dir = os.path.dirname(o)
+        if output_dir != "":
+            os.makedirs(output_dir, exist_ok=True)
+        output_ext = os.path.splitext(os.path.basename(o))[1]
+        if output_ext == "":
+            o += f".{of}" if of else ".mp3"
+        elif output_ext == ".":
+            o += f"{of}" if of else "mp3"
+
+        if of:
+            if of not in self.output_formats:
+                raise NotSupportedFormat(f"Неподдерживаемый формат: {of}")
+            root, output_ext = os.path.splitext(o)
+            if output_ext != f".{of}":
+                # root already contains the directory part, so no extra join here
+                o = f"{root}.{of}"
+        else:
+            of = os.path.splitext(o)[1].strip(".")
+            if of not in self.output_formats:
+                raise NotSupportedFormat(f"Неподдерживаемый формат: {of}")
+
+        if not sr:
+            raise SampleRateError("Не указана частота дискретизации")
+        if isinstance(sr, float):
+            sr = int(sr)
+        if not isinstance(sr, int):
+            raise SampleRateError(f"Частота дискретизации должна быть числом\n\nЗначение: {sr}\nТип: {type(sr)}")
+        sample_rate_fixed = self.fit_sr(f=of, sr=sr)
+
+        bitrate_fixed = "320k"
+        if of not in ["wav", "flac", "aiff"]:
+            if not br:
+                bitrate_fixed = self.fit_br(f=of, br=320)
+            elif isinstance(br, (int, float)):
+                bitrate_fixed = self.fit_br(f=of, br=int(br))
+            elif isinstance(br, str):
+                bitrate_fixed = self.fit_br(f=of, br=int(br.strip("k").strip("K")))
+            else:
+                bitrate_fixed = self.fit_br(f=of, br=320)
+
+        format_settings = {
+            "wav": ["-c:a", "pcm_f32le", "-sample_fmt", "flt"],
+            "aiff": ["-c:a", "pcm_f32le", "-sample_fmt", "flt"],
+            "flac": ["-c:a", "flac", "-compression_level", "12", "-sample_fmt", "s32"],
+            "mp3": ["-c:a", "libmp3lame", "-b:a", f"{bitrate_fixed}k"],
+            "ogg": ["-c:a", "libvorbis", "-b:a", f"{bitrate_fixed}k"],
+            "opus": ["-c:a", "libopus", "-b:a", f"{bitrate_fixed}k"],
+            "m4a": ["-c:a", "aac", "-b:a", f"{bitrate_fixed}k"],
+            "aac": ["-c:a", "aac", "-b:a", f"{bitrate_fixed}k"],
+            "ac3": ["-c:a", "ac3", "-b:a", f"{bitrate_fixed}k"],
+        }
+
+        cmd = [
+            self.ffmpeg_path,
+            "-y",
+            "-f", input_format,
+            "-ar", str(sr),
+            "-ac", str(channels),
+            "-i", "pipe:0",
+            "-ac", str(channels),
+        ]
+
+        cmd.extend(["-ar", str(sample_rate_fixed)])
+        cmd.extend(format_settings[of])
+        o_dir, o_base = os.path.split(o)
+        o_base_n, o_base_ext = os.path.splitext(o_base)
+        o_base_n = self.sanitize(o_base_n)
+        o_base_n = self.short(o_base_n)
+        o = os.path.join(o_dir, f"{o_base_n}{o_base_ext}")
+        o = self.iter(o)
+        cmd.append(o)
+
+        process = subprocess.Popen(
+            cmd,
+            stdin=subprocess.PIPE,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+        )
+
+        try:
+            stdout, stderr = process.communicate(input=audio_bytes, timeout=300)
+        except subprocess.TimeoutExpired:
+            process.kill()
+            raise ErrorEncode("FFmpeg timeout: операция заняла слишком много времени")
+
+        if process.returncode != 0:
+            raise ErrorEncode(f"FFmpeg завершился с ошибкой (код: {process.returncode})")
+
+        return os.path.abspath(o)
+
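A minimal read-process-write roundtrip through the two methods above (paths are placeholders; as noted in the code, `write()` accepts either an `(n, 2)` or a `(2, n)` array):

```python
a = Audio()
wav, sr, _ = a.read(i="input.flac")   # (2, n) float32
wav = wav * 0.5                       # e.g. attenuate by ~6 dB
out = a.write(o="quiet.mp3", array=wav.T, sr=sr, of="mp3", br=320)
print(out)  # absolute path; the name may be sanitized/shortened by Namer
```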
+class Inverter(Audio):
+    def __init__(self):
+        super().__init__()
+        self.test = "test"
+        self.w_types = [
+            "boxcar",          # rectangular window
+            "triang",          # triangular window
+            "blackman",        # Blackman window
+            "hamming",         # Hamming window
+            "hann",            # Hann window
+            "bartlett",        # Bartlett window
+            "flattop",         # flat-top window
+            "parzen",          # Parzen window
+            "bohman",          # Bohman window
+            "blackmanharris",  # Blackman-Harris window
+            "nuttall",         # Nuttall window
+            "barthann",        # Bartlett-Hann window
+            "cosine",          # cosine window
+            "exponential",     # exponential window
+            "tukey",           # Tukey window
+            "taylor",          # Taylor window
+            "lanczos",         # Lanczos window
+        ]
+
+    def load_audio(self, filepath):
+        """Loads an audio file via Audio.read() (i.e. through ffmpeg)"""
+        try:
+            y, sr, _ = self.read(i=filepath, sr=None, mono=False)
+            return y, sr
+        except Exception as e:
+            print(f"Ошибка загрузки аудио: {e}")
+            return None, None
+
+    def process_channel(self, y1_ch, y2_ch, sr, method, w_size=2048, overlap=2, w_type="hann"):
+        """Processes a single audio channel"""
+        HOP_LENGTH = w_size // overlap
+        if method == "waveform":
+            return y1_ch - y2_ch
+
+        elif method == "spectrogram":
+            # Compute the spectrograms
+            S1 = librosa.stft(
+                y1_ch, n_fft=w_size, hop_length=HOP_LENGTH, win_length=w_size
+            )
+            S2 = librosa.stft(
+                y2_ch, n_fft=w_size, hop_length=HOP_LENGTH, win_length=w_size
+            )
+
+            # Magnitude spectrograms
+            mag1 = np.abs(S1)
+            mag2 = np.abs(S2)
+
+            # Spectral subtraction
+            mag_result = np.maximum(mag1 - mag2, 0)
+
+            # Keep the phase information of the original signal
+            phase = np.angle(S1)
+
+            # Combine the subtracted magnitude with the original phase
+            S_result = mag_result * np.exp(1j * phase)
+
+            # Inverse transform
+            return librosa.istft(
+                S_result,
+                n_fft=w_size,
+                hop_length=HOP_LENGTH,
+                win_length=w_size,
+                length=len(y1_ch),
+            )
+
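In `spectrogram` mode the per-channel estimate is classic magnitude spectral subtraction that reuses the phase of the original signal:

$$\hat{S}(f,t) = \max\bigl(|S_1(f,t)| - |S_2(f,t)|,\ 0\bigr)\, e^{\,j\angle S_1(f,t)}$$

Clipping at zero prevents negative magnitudes, and keeping the phase of $S_1$ makes the method more forgiving of small misalignments than plain waveform subtraction.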
+    def process_audio(self, audio1_path, audio2_path, out_format, method, output_path="./inverted.mp3", w_size=2048, overlap=2, w_type="hann"):
+        # Load both audio files
+        y1, sr1 = self.load_audio(audio1_path)
+        y2, sr2 = self.load_audio(audio2_path)
+
+        if sr1 is None or sr2 is None:
+            raise Exception("Произошла ошибка при чтении файлов")
+
+        # Number of channels in each signal
+        channels1 = 1 if y1.ndim == 1 else y1.shape[0]
+        channels2 = 1 if y2.ndim == 1 else y2.shape[0]
+
+        # Convert to (samples, channels) layout
+        if channels1 > 1:
+            y1 = y1.T  # (channels, samples) -> (samples, channels)
+        else:
+            y1 = y1.reshape(-1, 1)
+
+        if channels2 > 1:
+            y2 = y2.T  # (channels, samples) -> (samples, channels)
+        else:
+            y2 = y2.reshape(-1, 1)
+
+        if sr1 != sr2:
+            if channels2 > 1:
+                # Resample each channel separately
+                y2_resampled_list = []
+                for c in range(channels2):
+                    channel_resampled = librosa.resample(
+                        y2[:, c], orig_sr=sr2, target_sr=sr1
+                    )
+                    y2_resampled_list.append(channel_resampled)
+
+                # Shortest channel length after resampling
+                min_channel_length = min(len(ch) for ch in y2_resampled_list)
+
+                # Trim all channels to the same length and stack them
+                y2_resampled = np.zeros((min_channel_length, channels2), dtype=np.float32)
+                for c, channel in enumerate(y2_resampled_list):
+                    y2_resampled[:, c] = channel[:min_channel_length]
+
+                y2 = y2_resampled
+            else:
+                y2 = librosa.resample(y2[:, 0], orig_sr=sr2, target_sr=sr1)
+                y2 = y2.reshape(-1, 1)
+            sr2 = sr1
+
+        # Trim both signals to the same length
+        min_len = min(len(y1), len(y2))
+        y1 = y1[:min_len]
+        y2 = y2[:min_len]
+
+        # Process each channel separately
+        result_channels = []
+
+        # If the main signal is mono and the subtracted one is stereo, downmix the latter
+        if channels1 == 1 and channels2 > 1:
+            y2 = y2.mean(axis=1, keepdims=True)
+            channels2 = 1
+
+        for c in range(channels1):
+            # Channel of the main signal
+            y1_ch = y1[:, c]
+
+            # Matching channel of the subtracted signal
+            if channels2 == 1:
+                y2_ch = y2[:, 0]
+            else:
+                # If the subtracted signal has more channels, use the corresponding one
+                y2_ch = y2[:, min(c, channels2 - 1)]
+
+            # Process the channel pair
+            result_ch = self.process_channel(y1_ch, y2_ch, sr1, method, w_size=w_size, overlap=overlap, w_type=w_type)
+            result_channels.append(result_ch)
+
+        # Stack the channels back together
+        if len(result_channels) > 1:
+            result = np.column_stack(result_channels)
+        else:
+            result = np.array(result_channels[0])
+
+        # Normalization (clipping prevention)
+        if result.ndim > 1:
+            # Normalize each channel of a multichannel result separately
+            for c in range(result.shape[1]):
+                channel = result[:, c]
+                max_val = np.max(np.abs(channel))
+                if max_val > 0:
+                    result[:, c] = channel * 0.9 / max_val
+        else:
+            max_val = np.max(np.abs(result))
+            if max_val > 0:
+                result = result * 0.9 / max_val
+
+        inverted = self.write(o=output_path, array=result.T, sr=sr1, of=out_format, br="320k")
+        return inverted
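A minimal sketch of using the class outside the Gradio UI (paths are placeholders):

```python
inv = Inverter()
out = inv.process_audio(
    audio1_path="mix.flac",      # original
    audio2_path="vocals.flac",   # stem to subtract
    out_format="wav",
    method="spectrogram",        # or "waveform" for sample-wise subtraction
    output_path="instrumental.wav",
)
```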
mvsepless/downloader.py ADDED
@@ -0,0 +1,92 @@
+import os
+import yt_dlp
+import time
+from tqdm import tqdm
+import urllib.request
+
+DOWNLOAD_DIR = os.environ.get(
+    "MVSEPLESS_DOWNLOAD_DIR", os.path.join(os.getcwd(), "downloaded")
+)
+
+def dw_file(url_model: str, local_path: str, retries: int = 30):
+    dir_name = os.path.dirname(local_path)
+    if dir_name != "":
+        os.makedirs(dir_name, exist_ok=True)
+
+    class TqdmUpTo(tqdm):
+        def update_to(self, b=1, bsize=1, tsize=None):
+            if tsize is not None:
+                self.total = tsize
+            self.update(b * bsize - self.n)
+
+    for attempt in range(retries):
+        try:
+            with TqdmUpTo(
+                unit="B",
+                unit_scale=True,
+                unit_divisor=1024,
+                miniters=1,
+                desc=os.path.basename(local_path),
+            ) as t:
+                urllib.request.urlretrieve(
+                    url_model, local_path, reporthook=t.update_to
+                )
+            # Getting here means the download succeeded; stop retrying
+            break
+        except Exception as e:
+            print(f"Попытка {attempt + 1}/{retries} не удалась. Ошибка: {e}")
+            if attempt < retries - 1:  # not the last attempt yet
+                print("Повторная попытка...")
+                time.sleep(2)  # short pause before retrying
+            else:
+                print("Все попытки загрузки завершились неудачно")
+                raise  # re-raise the last exception
+
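Typical use, assuming a reachable URL (the URL and path here are placeholders):

```python
# Download a model checkpoint with up to 30 retries and a tqdm progress bar
dw_file(
    url_model="https://example.com/models/vocals.ckpt",
    local_path="models/vocals.ckpt",
)
```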
+def dw_yt_dlp(
+    url,
+    output_dir=None,
+    cookie=None,
+    output_format="mp3",
+    output_bitrate="320",
+    title=None,
+):
+    # Output filename template
+    outtmpl = "%(title)s.%(ext)s" if title is None else f"{title}.%(ext)s"
+
+    ydl_opts = {
+        "format": "bestaudio/best",
+        "outtmpl": os.path.join(DOWNLOAD_DIR if not output_dir else output_dir, outtmpl),
+        "postprocessors": [
+            {
+                "key": "FFmpegExtractAudio",
+                "preferredcodec": output_format,
+                "preferredquality": output_bitrate,
+            }
+        ],
+        "noplaylist": True,   # download a single video, not a whole playlist
+        "quiet": True,        # suppress console output
+        "no_warnings": True,  # hide warnings
+    }
+
+    # Attach cookies if provided
+    if cookie and os.path.exists(cookie):
+        ydl_opts["cookiefile"] = cookie
+
+    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+        try:
+            info = ydl.extract_info(url, download=True)
+            if "_type" in info and info["_type"] == "playlist":
+                # For playlists take the first entry
+                entry = info["entries"][0]
+                filename = ydl.prepare_filename(entry)
+            else:
+                # Single video
+                filename = ydl.prepare_filename(info)
+
+            # Swap the original extension for the chosen audio format
+            base, _ = os.path.splitext(filename)
+            audio_file = base + f".{output_format}"
+
+            # prepare_filename already includes the output directory from outtmpl
+            return audio_file
+        except Exception as e:
+            return None
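And for `dw_yt_dlp` (the URL is a placeholder; errors are swallowed and surface as `None`):

```python
path = dw_yt_dlp("https://www.youtube.com/watch?v=XXXXXXXXXXX", output_format="mp3")
if path is None:
    print("Download failed")
```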
mvsepless/ensemble.py ADDED
@@ -0,0 +1,224 @@
+# coding: utf-8
+__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
+
+import os
+import sys
+import librosa
+import tempfile
+import numpy as np
+import argparse
+if not __package__:
+    from audio import Audio
+    from namer import Namer
+else:
+    from .audio import Audio
+    from .namer import Namer
+
+audio = Audio()
+
+def stft(wave, nfft, hl):
+    wave_left = np.asfortranarray(wave[0])
+    wave_right = np.asfortranarray(wave[1])
+    spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
+    spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
+    spec = np.asfortranarray([spec_left, spec_right])
+    return spec
+
+
+def istft(spec, hl, length):
+    spec_left = np.asfortranarray(spec[0])
+    spec_right = np.asfortranarray(spec[1])
+    wave_left = librosa.istft(spec_left, hop_length=hl, length=length)
+    wave_right = librosa.istft(spec_right, hop_length=hl, length=length)
+    wave = np.asfortranarray([wave_left, wave_right])
+    return wave
+
+
+def absmax(a, *, axis):
+    dims = list(a.shape)
+    dims.pop(axis)
+    indices = np.ogrid[tuple(slice(0, d) for d in dims)]
+    argmax = np.abs(a).argmax(axis=axis)
+    # Convert indices to list before insertion
+    indices = list(indices)
+    indices.insert(axis % len(a.shape), argmax)
+    return a[tuple(indices)]
+
+
+def absmin(a, *, axis):
+    dims = list(a.shape)
+    dims.pop(axis)
+    indices = np.ogrid[tuple(slice(0, d) for d in dims)]
+    argmin = np.abs(a).argmin(axis=axis)
+    indices = list(indices)
+    indices.insert((len(a.shape) + axis) % len(a.shape), argmin)
+    return a[tuple(indices)]
+
+
+def lambda_max(arr, axis=None, key=None, keepdims=False):
+    idxs = np.argmax(key(arr), axis)
+    if axis is not None:
+        idxs = np.expand_dims(idxs, axis)
+        result = np.take_along_axis(arr, idxs, axis)
+        if not keepdims:
+            result = np.squeeze(result, axis=axis)
+        return result
+    else:
+        return arr.flatten()[idxs]
+
+
+def lambda_min(arr, axis=None, key=None, keepdims=False):
+    idxs = np.argmin(key(arr), axis)
+    if axis is not None:
+        idxs = np.expand_dims(idxs, axis)
+        result = np.take_along_axis(arr, idxs, axis)
+        if not keepdims:
+            result = np.squeeze(result, axis=axis)
+        return result
+    else:
+        return arr.flatten()[idxs]
+
+
+def average_waveforms(pred_track, weights, algorithm):
+    """
+    :param pred_track: shape = (num, channels, length)
+    :param weights: shape = (num, )
+    :param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft
+    :return: averaged waveform in shape (channels, length)
+    """
+
+    pred_track = np.array(pred_track)
+    final_length = pred_track.shape[-1]
+
+    mod_track = []
+    for i in range(pred_track.shape[0]):
+        if algorithm == "avg_wave":
+            mod_track.append(pred_track[i] * weights[i])
+        elif algorithm in ["median_wave", "min_wave", "max_wave"]:
+            mod_track.append(pred_track[i])
+        elif algorithm in ["avg_fft", "min_fft", "max_fft", "median_fft"]:
+            spec = stft(pred_track[i], nfft=2048, hl=1024)
+            if algorithm in ["avg_fft"]:
+                mod_track.append(spec * weights[i])
+            else:
+                mod_track.append(spec)
+    pred_track = np.array(mod_track)
+
+    if algorithm in ["avg_wave"]:
+        pred_track = pred_track.sum(axis=0)
+        pred_track /= np.array(weights).sum()
+    elif algorithm in ["median_wave"]:
+        pred_track = np.median(pred_track, axis=0)
+    elif algorithm in ["min_wave"]:
+        pred_track = lambda_min(pred_track, axis=0, key=np.abs)
+    elif algorithm in ["max_wave"]:
+        pred_track = lambda_max(pred_track, axis=0, key=np.abs)
+    elif algorithm in ["avg_fft"]:
+        pred_track = pred_track.sum(axis=0)
+        pred_track /= np.array(weights).sum()
+        pred_track = istft(pred_track, 1024, final_length)
+    elif algorithm in ["min_fft"]:
+        pred_track = lambda_min(pred_track, axis=0, key=np.abs)
+        pred_track = istft(pred_track, 1024, final_length)
+    elif algorithm in ["max_fft"]:
+        pred_track = absmax(pred_track, axis=0)
+        pred_track = istft(pred_track, 1024, final_length)
+    elif algorithm in ["median_fft"]:
+        pred_track = np.median(pred_track, axis=0)
+        pred_track = istft(pred_track, 1024, final_length)
+    return pred_track
+
+
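For the weighted-average modes the operation is, per sample (or per STFT bin in `avg_fft`):

$$\hat{x} = \frac{\sum_i w_i\, x_i}{\sum_i w_i}$$

while the min/max modes pick, at each position, the input value with the smallest or largest magnitude, and the median modes take the element-wise median.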
+def ensemble_audio_files(
+    files, output="res.wav", ensemble_type="avg_wave", weights=None, out_format="wav", add_wav=False
+) -> str | tuple[str, str]:
+    """
+    Main entry point for combining several audio files into one.
+
+    :param files: list of paths to the audio files
+    :param output: path (without extension) for the result
+    :param ensemble_type: combination method (avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft)
+    :param weights: list of per-file weights (None for equal weights)
+    :return: path to the result, plus the lossless WAV copy when add_wav is set
+    """
+    print("Алгоритм склеивания: {}".format(ensemble_type))
+    print("Количество входных файлов: {}".format(len(files)))
+    if weights is not None:
+        weights = np.array(weights)
+    else:
+        weights = np.ones(len(files))
+    print("Веса: {}".format(weights))
+    print("Имя выходного файла: {}".format(output))
+
+    data = []
+    sr = None
+    max_length = 0
+    max_channels = 0
+
+    # First pass: find the maximum length and channel count
+    for f in files:
+        if not os.path.isfile(f):
+            print("Не удается найти файл: {}. Проверьте пути.".format(f))
+            exit()
+        print("Читается файл: {}".format(f))
+        wav, current_sr, _ = audio.read(i=f, sr=None, mono=False)
+        if sr is None:
+            sr = current_sr
+        elif sr != current_sr:
+            print("Частота дискретизации на всех файлах должна быть одинаковой")
+            exit()
+
+        # Channel count of this file
+        if wav.ndim == 1:
+            channels = 1
+            length = len(wav)
+        else:
+            channels = wav.shape[0]
+            length = wav.shape[1]
+
+        max_length = max(max_length, length)
+        max_channels = max(max_channels, channels)
+        print("Форма сигнала: {} частота дискретизации: {}".format(wav.shape, sr))
+
+    # Second pass: normalize channel layout and align lengths
+    for f in files:
+        wav, current_sr, _ = audio.read(i=f, sr=None, mono=False)
+
+        # Channel handling
+        if wav.ndim == 1:
+            # mono -> stereo
+            wav = np.vstack([wav, wav])
+        elif wav.shape[0] == 1:
+            # single channel -> stereo
+            wav = np.vstack([wav[0], wav[0]])
+        elif wav.shape[0] > 2:
+            # more than 2 channels -> keep the first two
+            wav = wav[:2, :]
+
+        # Length alignment
+        if wav.shape[1] < max_length:
+            pad_width = ((0, 0), (0, max_length - wav.shape[1]))
+            wav = np.pad(wav, pad_width, mode="constant")
+        elif wav.shape[1] > max_length:
+            wav = wav[:, :max_length]
+
+        data.append(wav)
+
+    data = np.array(data)
+    res = average_waveforms(data, weights, ensemble_type)
+    print("Форма результата: {}".format(res.shape))
+
+    output_wav = f"{output}_orig.wav"
+    output = f"{output}.{out_format}"
+
+    output = audio.write(o=output, array=res.T, sr=sr, of=out_format, br="320k")
+    if add_wav:
+        output_wav = audio.write(o=output_wav, array=res.T, sr=sr, of="wav")
+        return output, output_wav
+
+    else:
+        return output
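A minimal call blending two stems with unequal weights (paths are placeholders; the extension is appended inside the function):

```python
out = ensemble_audio_files(
    ["vocals_model_a.wav", "vocals_model_b.wav"],
    output="vocals_ensemble",
    ensemble_type="avg_wave",
    weights=[2, 1],          # favour the first model 2:1
    out_format="flac",
)
```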
mvsepless/infer.py ADDED
@@ -0,0 +1,623 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import argparse
5
+ import time
6
+ from datetime import datetime
7
+ import gc
8
+ import glob
9
+ import yaml
10
+ import torch
11
+ import numpy as np
12
+ import soundfile as sf
13
+ import torch.nn as nn
14
+
15
+ from typing import Literal
16
+
17
+ from audio import Audio
18
+ from namer import Namer
19
+
20
+ namer = Namer()
21
+ audio = Audio()
22
+
23
+ from infer_utils import (
24
+ prefer_target_instrument,
25
+ demix,
26
+ get_model_from_config
27
+ )
28
+
29
+
30
+ def normalize_peak(audio, peak):
31
+ current_peak = np.max(np.abs(audio))
32
+ if current_peak == 0:
33
+ return audio # избегаем деления на ноль
34
+ scale_factor = peak / current_peak
35
+ return audio * scale_factor
36
+
37
+
38
+ gc.enable()
39
+
40
+
41
+ def cleanup_model(model):
42
+ try:
43
+ if isinstance(model, torch.nn.DataParallel):
44
+ model = model.module
45
+
46
+ model.to("cpu")
47
+
48
+ for name, param in list(model.named_parameters()):
49
+ del param
50
+ for name, buf in list(model.named_buffers()):
51
+ del buf
52
+
53
+ del model
54
+
55
+ if torch.cuda.is_available():
56
+ torch.cuda.empty_cache()
57
+ torch.cuda.ipc_collect()
58
+
59
+ gc.collect()
60
+ sys.stdout.write(json.dumps({"cleanup": "Модель выгружена из памяти"}, ensure_ascii=False) + '\n')
61
+ sys.stdout.flush()
62
+ except Exception as e:
63
+ sys.stdout.write(json.dumps({"error": f"Ошибка при выгрузке модели: {str(e)}"}, ensure_ascii=False) + '\n')
64
+ sys.stdout.flush()
65
+
66
+
67
+ def once_inference(
68
+ path: str = None,
69
+ model: any = None,
70
+ config: any = None,
71
+ device: any = None,
72
+ model_type: str = None,
73
+ extract_instrumental: bool = False,
74
+ detailed_pbar: bool = False,
75
+ output_format: Literal[
76
+ "mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "aiff"
77
+ ] = "mp3",
78
+ output_bitrate: str = "320k",
79
+ use_tta: bool = False,
80
+ verbose: bool = False,
81
+ model_name: str = None,
82
+ sample_rate: int = 44100,
83
+ instruments: list = [],
84
+ store_dir: str = None,
85
+ template: str = None,
86
+ selected_instruments: list = [],
87
+ model_id: int = 0,
88
+ ):
89
+ results = []
90
+ sys.stdout.write(json.dumps({"reading": path}, ensure_ascii=False) + '\n')
91
+ sys.stdout.flush()
92
+ sys.stdout.write(json.dumps({"selected_stems": selected_instruments}, ensure_ascii=False) + '\n')
93
+ sys.stdout.flush()
94
+ sys.stdout.write(json.dumps({"stems": list(instruments)}, ensure_ascii=False) + '\n')
95
+ sys.stdout.flush()
96
+
97
+ if config.training.target_instrument is not None:
98
+ sys.stdout.write(json.dumps({"target_instrument": config.training.target_instrument}, ensure_ascii=False) + '\n')
99
+ sys.stdout.flush()
100
+
101
+ try:
102
+ mix, sr, _ = audio.read(i=path, sr=sample_rate, mono=False)
103
+ except Exception as e:
104
+ error_msg = f"Не удалось прочитать аудио: {path}\nОшибка: {e}"
105
+ sys.stdout.write(json.dumps({"error": error_msg}, ensure_ascii=False) + '\n')
106
+ sys.stdout.flush()
107
+ return results
108
+
109
+ mix_orig = mix.copy()
110
+
111
+ mean = std = None
112
+ if config.inference.get("normalize", False):
113
+ mono = mix.mean(0)
114
+ mean = mono.mean()
115
+ std = mono.std()
116
+ mix = (mix - mean) / std
117
+
118
+ if use_tta:
119
+ track_proc_list = [mix.copy(), mix[::-1].copy(), -1.0 * mix.copy()]
120
+ else:
121
+ track_proc_list = [mix.copy()]
122
+ full_result = []
123
+ for m in track_proc_list:
124
+ try:
125
+ waveforms = demix(
126
+ config, model, m, device, pbar=detailed_pbar, model_type=model_type
127
+ )
128
+
129
+ full_result.append(waveforms)
130
+ except Exception as e:
131
+ sys.stdout.write(json.dumps({"error": f"Ошибка при демиксе: {e}"}, ensure_ascii=False) + '\n')
132
+ sys.stdout.flush()
133
+ del m
134
+ gc.collect()
135
+
136
+ if not full_result:
137
+ sys.stdout.write(json.dumps({"error": "Пустой результат демикса."}, ensure_ascii=False) + '\n')
138
+ sys.stdout.flush()
139
+ return results
140
+
141
+ waveforms = full_result[0]
142
+ for i in range(1, len(full_result)):
143
+ d = full_result[i]
144
+ for el in d:
145
+ if i == 2:
146
+ waveforms[el] += -1.0 * d[el]
147
+ elif i == 1:
148
+ waveforms[el] += d[el][::-1].copy()
149
+ else:
150
+ waveforms[el] += d[el]
151
+ for el in waveforms:
152
+ waveforms[el] /= len(full_result)
153
+
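The loop above undoes each test-time augmentation before averaging: index 1 is the channel-swapped pass (`mix` is (channels, samples), so `mix[::-1]` swaps channels) and index 2 is the polarity-inverted pass. A stripped-down sketch of the same fold-back, assuming all three TTA variants are present:

```python
import numpy as np

def fold_tta(outputs):
    # outputs[0]: plain pass; outputs[1]: channel-swapped pass;
    # outputs[2]: polarity-inverted pass. Each maps stem -> (channels, samples).
    acc = {stem: out.copy() for stem, out in outputs[0].items()}
    for stem in acc:
        acc[stem] += outputs[1][stem][::-1]    # undo the channel swap
        acc[stem] += -1.0 * outputs[2][stem]   # undo the polarity inversion
        acc[stem] /= len(outputs)
    return acc
```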
154
+ if (
155
+ extract_instrumental and config.training.target_instrument is not None
156
+ ):  # "Extract Instrumental" is enabled and the model defines a target instrument
157
+ second_stem = [
158
+ s
159
+ for s in config.training.instruments
160
+ if s != config.training.target_instrument
161
+ ]
162
+ if second_stem:
163
+ second_stem_key = second_stem[0]
164
+ if second_stem_key not in instruments:
165
+ instruments.append(second_stem_key)
166
+ waveforms[second_stem_key] = mix_orig - waveforms[instruments[0]]
167
+
168
+ elif (
169
+ extract_instrumental
170
+ and selected_instruments
171
+ and config.training.target_instrument is None
172
+ ):  # "Extract Instrumental" is enabled and stems are selected, so the "inverted -" and "inverted +" stems are created (no target instrument defined)
173
+
174
+
175
+ all_instruments = config.training.instruments
176
+ if len(all_instruments) > 2:
177
+
178
+ waveforms["inverted -"] = mix_orig.copy()
179
+ for instr in instruments:
180
+ if instr in waveforms:
181
+ waveforms["inverted -"] -= waveforms[
182
+ instr
183
+ ] # стем "inverted -": вычитание выбранного стема из оригинального сигнала (не всегда хорошо)
184
+
185
+ if "inverted -" not in instruments:
186
+ instruments.append("inverted -")
187
+
188
+ unselected_stems = [s for s in all_instruments if s not in selected_instruments]
189
+ if unselected_stems:
190
+ waveforms["inverted +"] = np.zeros_like(mix_orig)
191
+ for stem in unselected_stems:
192
+ if stem in waveforms:
193
+ waveforms["inverted +"] += waveforms[
194
+ stem
195
+ ] # стем "inverted +": сложение не выбранных инструментов в один стем
196
+ if "inverted +" not in instruments:
197
+ instruments.append("inverted +")
198
+
199
+ peak = np.max(np.abs(waveforms["inverted -"]))
200
+ waveforms["inverted +"] = normalize_peak(waveforms["inverted +"], peak)
201
+
202
+ elif (
203
+ extract_instrumental
204
+ and not selected_instruments
205
+ and config.training.target_instrument is None
206
+ and (
207
+ all(
208
+ instr in config.training.instruments
209
+ for instr in ["bass", "drums", "other", "vocals"]
210
+ )
211
+ or all(
212
+ instr in config.training.instruments
213
+ for instr in ["bass", "drums", "other", "vocals", "piano", "guitar"]
214
+ )
215
+ )
216
+ ):
217
+
218
+ waveforms["instrumental -"] = mix_orig.copy()
219
+ waveforms["instrumental -"] -= waveforms[
220
+ "vocals"
221
+ ] # стем "inverted -": вычитание выбранного стема из оригинального сигнала (не всегда хорошо)
222
+
223
+ if "instrumental -" not in instruments:
224
+ instruments.append("instrumental -")
225
+
226
+ all_instruments = config.training.instruments
227
+ non_vocal_stems = [s for s in all_instruments if s not in ["vocals"]]
228
+ if non_vocal_stems:
229
+ waveforms["instrumental +"] = np.zeros_like(mix_orig)
230
+ for stem in non_vocal_stems:
231
+ if stem in waveforms:
232
+ waveforms["instrumental +"] += waveforms[
233
+ stem
234
+ ] # стем "inverted +": сложение не выбранных инструментов в один стем
235
+ if "instrumental +" not in instruments:
236
+ instruments.append("instrumental +")
237
+
238
+ peak = np.max(np.abs(waveforms["instrumental -"]))
239
+ waveforms["instrumental +"] = normalize_peak(waveforms["instrumental +"], peak)
240
+
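To summarize the branches above: the "-" variants subtract stems from the original mix (retaining any separation bleed), while the "+" variants sum the remaining separated stems and are then peak-matched to the "-" result. A hedged sketch reusing the normalize_peak helper defined earlier in this file:

```python
import numpy as np

def complement_stems(mix_orig, waveforms, keep=("vocals",)):
    # "minus": mix with the kept stems subtracted (keeps residual bleed);
    # "plus": sum of the remaining stems, peak-matched to "minus".
    minus = mix_orig - sum(waveforms[s] for s in keep)
    plus = sum(v for k, v in waveforms.items() if k not in keep)
    return minus, normalize_peak(plus, np.max(np.abs(minus)))
```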
241
+ template = namer.sanitize(template)
242
+ template = namer.dedup_template(template, keys=["NAME", "MODEL", "STEM", "ID"])
243
+ template = namer.short(template, length=40)
244
+
245
+ for instr in instruments:
246
+ try:
247
+ estimates = waveforms[instr].T
248
+ if mean is not None and std is not None:
249
+ estimates = estimates * std + mean
250
+
251
+ file_name = os.path.splitext(os.path.basename(path))[0]
252
+ file_name_shorted = namer.short_input_name_template(template, STEM=instr, MODEL=model_name, ID=model_id, NAME=file_name)
253
+ custom_name = namer.template(
254
+ template, STEM=instr, MODEL=model_name, ID=model_id, NAME=file_name_shorted
255
+ )
256
+ output_path = os.path.join(store_dir, f"{custom_name}.{output_format}")
257
+
258
+ sys.stdout.write(json.dumps({"writing": output_path}, ensure_ascii=False) + '\n')
259
+ sys.stdout.flush()
260
+
261
+ output_path = audio.write(
262
+ o=output_path, array=estimates, sr=sr, of=output_format, br=output_bitrate
263
+ )  # write the stem to an audio file via the universal writer
264
+
265
+ results.append(
266
+ (instr, output_path)
267
+ )  # record the separation result: (stem name, file path)
268
+ del estimates
269
+ except Exception as e:
270
+ sys.stdout.write(json.dumps({"error": f"Ошибка при обработке {instr}: {e}"}, ensure_ascii=False) + '\n')
271
+ sys.stdout.flush()
272
+ gc.collect()
273
+
274
+ del mix, mix_orig, waveforms, full_result
275
+ gc.collect()
276
+
277
+ return results
278
+
279
+
280
+ def run_inference(
281
+ model: Any = None,
282
+ config: Any = None,
283
+ input_path: str = None,
284
+ store_dir: str = None,
285
+ device: Any = None,
286
+ model_type: str = None,
287
+ extract_instrumental: bool = False,
288
+ disable_detailed_pbar: bool = False,
289
+ output_format: Literal[
290
+ "mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "aiff"
291
+ ] = "mp3",
292
+ output_bitrate: str = "320k",
293
+ use_tta: bool = False,
294
+ verbose: bool = False,
295
+ model_name: str = None,
296
+ template: str = "NAME_STEM",
297
+ selected_instruments: list = [],
298
+ model_id: int = 0,
299
+ ):
300
+ start_time = time.time()
301
+ if model_type != "vr":
302
+ model.eval()
303
+ sample_rate = 44100
304
+ if "sample_rate" in config.audio:
305
+ sample_rate = config.audio["sample_rate"]
306
+
307
+ instruments = prefer_target_instrument(config)
308
+
309
+ if config.training.target_instrument is not None:
310
+ sys.stdout.write(json.dumps({"info": "Целевой инструмент найден в конфигурации модели. Выбранные стемы будут проигнорированы."}, ensure_ascii=False) + '\n')
311
+ sys.stdout.flush()
312
+ else:
313
+ if selected_instruments is not None and selected_instruments != []:
314
+ instruments = [
315
+ instr for instr in instruments if instr in selected_instruments
316
+ ]
317
+ if verbose:
318
+ sys.stdout.write(json.dumps({"selected_stems": instruments}, ensure_ascii=False) + '\n')
319
+ sys.stdout.flush()
320
+
321
+ os.makedirs(store_dir, exist_ok=True)
322
+
323
+ detailed_pbar = not disable_detailed_pbar
324
+
325
+ results = once_inference(
326
+ path=input_path,
327
+ model=model,
328
+ config=config,
329
+ device=device,
330
+ model_type=model_type,
331
+ extract_instrumental=extract_instrumental,
332
+ detailed_pbar=detailed_pbar,
333
+ output_format=output_format,
334
+ output_bitrate=output_bitrate,
335
+ use_tta=use_tta,
336
+ verbose=verbose,
337
+ model_name=model_name,
338
+ sample_rate=sample_rate,
339
+ instruments=instruments,
340
+ store_dir=store_dir,
341
+ template=template,
342
+ selected_instruments=selected_instruments,
343
+ model_id=model_id,
344
+ )
345
+
346
+ time.sleep(1)
347
+ time_taken = time.time() - start_time
348
+ sys.stdout.write(json.dumps({"time": f"{time_taken:.2f} сек."}, ensure_ascii=False) + '\n')
349
+ sys.stdout.flush()
350
+ sys.stdout.write(json.dumps({"done": results}, ensure_ascii=False) + '\n')
351
+ sys.stdout.flush()
352
+ return results
353
+
354
+
355
+ def load_model(model_type, config_path, start_check_point, device_ids, force_cpu=False):
356
+ device = "cpu"
357
+ if force_cpu:
358
+ device = "cpu"
359
+ elif torch.cuda.is_available():
360
+ sys.stdout.write(json.dumps({"info": "Разделение выполняется на ядрах CUDA. Для выполнения на процессоре установите force_cpu=True."}, ensure_ascii=False) + '\n')
361
+ sys.stdout.flush()
362
+ device = "cuda"
363
+
364
+ if device_ids is None:
365
+ device = "cuda:0"
366
+ elif isinstance(device_ids, (list, tuple)):
367
+ device = f"cuda:{device_ids[0]}" if device_ids else "cuda:0"
368
+ elif isinstance(device_ids, bool):
369
+ device = "cuda:0"
370
+ else:
371
+ device = f"cuda:{int(device_ids)}"
372
+ elif torch.backends.mps.is_available():
373
+ device = "mps"
374
+
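The branching above resolves device_ids to a torch device string; a compact sketch of the same mapping (not the exact implementation):

```python
import torch

def resolve_device(device_ids=None, force_cpu=False):
    if force_cpu:
        return "cpu"
    if torch.cuda.is_available():
        if device_ids is None or isinstance(device_ids, bool):
            return "cuda:0"
        if isinstance(device_ids, (list, tuple)):
            return f"cuda:{device_ids[0]}" if device_ids else "cuda:0"
        return f"cuda:{int(device_ids)}"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```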
375
+ sys.stdout.write(json.dumps({"device": device}, ensure_ascii=False) + '\n')
376
+ sys.stdout.flush()
377
+
378
+ model_load_start_time = time.time()
379
+ torch.backends.cudnn.benchmark = True
380
+
381
+ model, config = get_model_from_config(model_type, config_path)
382
+
383
+ if model_type == "vr":
384
+ model.load_checkpoint(start_check_point, device)
385
+ model.settings(enable_post_process=False,
386
+ post_process_threshold=config.inference.post_process_threshold,
387
+ batch_size=config.inference.batch_size,
388
+ window_size=config.inference.window_size,
389
+ high_end_process=config.inference.high_end_process,
390
+ primary_stem=config.training.instruments[0],
391
+ secondary_stem=config.training.instruments[1]
392
+ )
393
+ return model, config, device
394
+
395
+ elif model_type == "mdxnet":
396
+ if start_check_point != "":
397
+ sys.stdout.write(json.dumps({"checkpoint": start_check_point}) + '\n')
398
+ sys.stdout.flush()
399
+ model.init_onnx_session(start_check_point, device)
400
+
401
+ return model, config, device
402
+
403
+ else:
404
+ if start_check_point != "":
405
+ sys.stdout.write(json.dumps({"checkpoint": start_check_point}) + '\n')
406
+ sys.stdout.flush()
407
+
408
+ if model_type in ["htdemucs", "apollo"]:
409
+ state_dict = torch.load(
410
+ start_check_point, map_location=device, weights_only=False
411
+ )
412
+ if "state" in state_dict:
413
+ state_dict = state_dict["state"]
414
+ if "state_dict" in state_dict:
415
+ state_dict = state_dict["state_dict"]
416
+ else:
417
+ try:
418
+ state_dict = torch.load(
419
+ start_check_point, map_location=device, weights_only=True
420
+ )
421
+ except torch.serialization.pickle.UnpicklingError:
422
+ with torch.serialization.safe_globals([torch._C._nn.gelu]):
423
+ state_dict = torch.load(
424
+ start_check_point, map_location=device, weights_only=True
425
+ )
426
+ try:
427
+ model.load_state_dict(state_dict)
428
+ except RuntimeError:
429
+ model.load_state_dict(state_dict, strict=False)
430
+
431
+ sys.stdout.write(json.dumps({"stems": list(config.training.instruments)}, ensure_ascii=False) + '\n')
432
+ sys.stdout.flush()
433
+
434
+ if (
435
+ isinstance(device_ids, (list, tuple))
436
+ and len(device_ids) > 1
437
+ and not force_cpu
438
+ and torch.cuda.is_available()
439
+ ):
440
+ model = nn.DataParallel(model, device_ids=[int(d) for d in device_ids])
441
+
442
+ model = model.to(device)
443
+
444
+ load_time = time.time() - model_load_start_time
445
+
446
+ sys.stdout.write(json.dumps({"model_load_time": f"{load_time:.2f} сек."}, ensure_ascii=False) + '\n')
447
+ sys.stdout.flush()
448
+
449
+ return model, config, device
450
+
451
+
452
+ def mvsep_offline(
453
+ input_path,
454
+ store_dir,
455
+ model_type,
456
+ config_path,
457
+ start_check_point,
458
+ extract_instrumental,
459
+ output_format,
460
+ output_bitrate,
461
+ model_name,
462
+ template,
463
+ device_ids=None,
464
+ disable_detailed_pbar=False,
465
+ use_tta=False,
466
+ force_cpu=False,
467
+ verbose=False,
468
+ selected_instruments=None,
469
+ model_id=0,
470
+ ):
471
+ model, config, device = load_model(
472
+ model_type, config_path, start_check_point, device_ids, force_cpu
473
+ )
474
+
475
+ results = run_inference(
476
+ model=model,
477
+ config=config,
478
+ input_path=input_path,
479
+ store_dir=store_dir,
480
+ device=device,
481
+ model_type=model_type,
482
+ extract_instrumental=extract_instrumental,
483
+ disable_detailed_pbar=disable_detailed_pbar,
484
+ output_format=output_format,
485
+ output_bitrate=output_bitrate,
486
+ use_tta=use_tta,
487
+ verbose=verbose,
488
+ model_name=model_name,
489
+ template=template,
490
+ selected_instruments=selected_instruments,
491
+ model_id=model_id,
492
+ )
493
+
494
+ if model_type != "vr":
495
+ cleanup_model(model)
496
+ del config
497
+ gc.collect()
498
+ return results
499
+
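A minimal usage sketch of mvsep_offline; the file names below are placeholders, not files shipped with the package:

```python
results = mvsep_offline(
    input_path="song.wav",
    store_dir="separated",
    model_type="mel_band_roformer",
    config_path="model_config.yaml",
    start_check_point="model.ckpt",
    extract_instrumental=True,
    output_format="wav",
    output_bitrate="320k",
    model_name="roformer",
    template="NAME_STEM",
)
# results is a list of (stem_name, output_path) tuples.
```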
500
+
501
+ def parse_args():
502
+ parser = argparse.ArgumentParser(
503
+ description="Модифицированный Music-Source-Separation-Training для разделения аудио на источники"
504
+ )
505
+
506
+ # Required arguments
507
+ parser.add_argument("--input", type=str, help="Путь к входному файлу или папке")
508
+ parser.add_argument(
509
+ "--store_dir", type=str, required=True, help="Путь для сохранения результатов"
510
+ )
511
+
512
+ # Core model parameters
513
+ parser.add_argument(
514
+ "--model_type",
515
+ type=str,
516
+ default="htdemucs",
517
+ choices=[
518
+ "mel_band_roformer",
519
+ "bs_roformer",
520
+ "mdx23c",
521
+ "scnet",
522
+ "htdemucs",
523
+ "bandit",
524
+ "bandit_v2",
525
+ "mdxnet",
526
+ "vr"
527
+ ],
528
+ help="Тип модели (по умолчанию: htdemucs)",
529
+ )
530
+ parser.add_argument(
531
+ "--config_path",
532
+ type=str,
533
+ required=True,
534
+ help="Путь к конфигурационному файлу модели",
535
+ )
536
+ parser.add_argument(
537
+ "--start_check_point", type=str, required=True, help="Путь к чекпоинту модели"
538
+ )
539
+
540
+ # Output parameters
541
+ parser.add_argument(
542
+ "--output_format",
543
+ type=str,
544
+ default="wav",
545
+ choices=audio.output_formats,
546
+ help="Формат выходных файлов",
547
+ )
548
+ parser.add_argument(
549
+ "--output_bitrate", type=str, required=True, help="Битрейт выходного файла"
550
+ )
551
+
552
+ parser.add_argument(
553
+ "--selected_instruments",
554
+ nargs="+",
555
+ help="Список стемов для сохранения (например: vocals drums)",
556
+ )
557
+ parser.add_argument(
558
+ "--extract_instrumental",
559
+ action="store_true",
560
+ help="Извлечь инструментальную версию",
561
+ )
562
+ parser.add_argument(
563
+ "--template",
564
+ type=str,
565
+ default="NAME_STEM",
566
+ help="Шаблон для имен выходных файлов",
567
+ )
568
+ parser.add_argument(
569
+ "--model_name",
570
+ type=str,
571
+ default="model",
572
+ help="Имя модели для шаблона имен файлов",
573
+ )
574
+ parser.add_argument("-m_id", "--model_id", type=int, required=True, help="Model ID")
575
+ parser.add_argument(
576
+ "--device_ids", nargs="+", help="ID GPU устройств для использования"
577
+ )
578
+ parser.add_argument(
579
+ "--force_cpu", action="store_true", help="Принудительно использовать CPU"
580
+ )
581
+ parser.add_argument(
582
+ "--use_tta", action="store_true", help="Использовать тестовую аугментацию"
583
+ )
584
+ parser.add_argument(
585
+ "--disable_detailed_pbar",
586
+ action="store_true",
587
+ help="Отключить детальный прогресс-бар",
588
+ )
589
+ parser.add_argument("--verbose", action="store_true", help="Подробный вывод")
590
+
591
+ return parser.parse_args()
592
+
593
+
594
+ def main():
595
+ args = parse_args()
596
+
597
+ device_ids = None
598
+ if args.device_ids:
599
+ device_ids = [int(x) for x in args.device_ids]
600
+
601
+ results = mvsep_offline(
602
+ input_path=args.input,
603
+ store_dir=args.store_dir,
604
+ model_type=args.model_type,
605
+ config_path=args.config_path,
606
+ start_check_point=args.start_check_point,
607
+ extract_instrumental=args.extract_instrumental,
608
+ output_format=args.output_format,
609
+ output_bitrate=args.output_bitrate,
610
+ model_name=args.model_name,
611
+ template=args.template,
612
+ device_ids=device_ids,
613
+ disable_detailed_pbar=args.disable_detailed_pbar,
614
+ use_tta=args.use_tta,
615
+ force_cpu=args.force_cpu,
616
+ verbose=args.verbose,
617
+ selected_instruments=args.selected_instruments,
618
+ model_id=args.model_id,
619
+ )
620
+
621
+
622
+ if __name__ == "__main__":
623
+ main()
mvsepless/infer_utils.py ADDED
@@ -0,0 +1,382 @@
1
+ # coding: utf-8
2
+ __author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
3
+
4
+ import sys
5
+ import json
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import yaml
10
+ import librosa
11
+ import torch.nn.functional as F
12
+ from ml_collections import ConfigDict
13
+ from omegaconf import OmegaConf
14
+ from typing import Dict, List, Tuple, Any, Optional
15
+
16
+
17
+ def load_config(model_type: str, config_path: str) -> Any:
18
+ """
19
+ Load the configuration from the specified path based on the model type.
20
+ """
21
+ try:
22
+ with open(config_path, "r") as f:
23
+ if model_type == "htdemucs":
24
+ config = OmegaConf.load(config_path)
25
+ else:
26
+ config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
27
+ return config
28
+ except FileNotFoundError:
29
+ raise FileNotFoundError(f"Configuration file not found at {config_path}")
30
+ except Exception as e:
31
+ raise ValueError(f"Error loading configuration: {e}")
32
+
33
+
34
+ def get_model_from_config(model_type: str, config_path: str) -> Tuple:
35
+ """
36
+ Load the model specified by the model type and configuration file.
37
+ """
38
+ config = load_config(model_type, config_path)
39
+
40
+ if model_type == "mdx23c":
41
+ from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net
42
+
43
+ model = TFC_TDF_net(config)
44
+
45
+ elif model_type == "mdxnet":
46
+ from models.mdx_net import MDXNet
47
+
48
+ model = MDXNet(**dict(config.model))
49
+
50
+
51
+
52
+ elif model_type == "vr":
53
+ from models.vr_arch import VRNet
54
+ # stems for VR models are configured later via model.settings() from config.training
55
+ model = VRNet(**dict(config.model))
56
+
57
+ elif model_type == "htdemucs":
58
+ from models.demucs4ht import get_model
59
+
60
+ model = get_model(config)
61
+
62
+ elif model_type == "mel_band_roformer":
63
+ if hasattr(config, "windowed"): # Это не нарушает совместимость со обычными моделями на Mel-Band Roformer
64
+ from models.windowed_roformer.model import MelBandRoformerWSA
65
+
66
+ model = MelBandRoformerWSA(**dict(config.model))
67
+
68
+ else:
69
+ from models.bs_roformer import MelBandRoformer
70
+
71
+ model = MelBandRoformer(**dict(config.model))
72
+
73
+ elif model_type == "bs_roformer":
74
+ if hasattr(config.model, "use_shared_bias"):
75
+ from models.bs_roformer import BSRoformer_SW
76
+
77
+ model = BSRoformer_SW(**dict(config.model))
78
+ elif hasattr(config.model, "fno"):
79
+ from models.bs_roformer import BSRoformer_FNO
80
+
81
+ model = BSRoformer_FNO(**dict(config.model))
82
+ else:
83
+ from models.bs_roformer import BSRoformer
84
+
85
+ model = BSRoformer(**dict(config.model))
86
+
87
+ elif model_type == "bandit":
88
+ from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple
89
+
90
+ model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)
91
+
92
+ elif model_type == "bandit_v2":
93
+ from models.bandit_v2.bandit import Bandit
94
+
95
+ model = Bandit(**config.kwargs)
96
+ elif model_type == "scnet_unofficial":
97
+ from models.scnet_unofficial import SCNet
98
+
99
+ model = SCNet(**config.model)
100
+ elif model_type == "scnet":
101
+ from models.scnet import SCNet
102
+
103
+ model = SCNet(**config.model)
104
+ else:
105
+ raise ValueError(f"Unknown model type: {model_type}")
106
+
107
+ return model, config
108
+
109
+ def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor:
110
+ """
111
+ Generate a windowing array with a linear fade-in at the beginning and a fade-out at the end.
112
+ """
113
+ fadein = torch.linspace(0, 1, fade_size)
114
+ fadeout = torch.linspace(1, 0, fade_size)
115
+
116
+ window = torch.ones(window_size)
117
+ window[-fade_size:] = fadeout
118
+ window[:fade_size] = fadein
119
+ return window
120
+
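For example, a 10-sample window with 3-sample fades comes out as linear ramps on both ends:

```python
w = _getWindowingArray(10, 3)
# tensor([0.0, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 0.0])
```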
121
+ def demix_mdxnet(
122
+ config: Any,
123
+ model: Any,
124
+ mix: np.ndarray,
125
+ device: torch.device,
126
+ pbar: bool = False,
127
+ ) -> Dict[str, np.ndarray]:
128
+ """
129
+ MDX-Net specific demixing function with overlap support
130
+ """
131
+ mix_tensor = torch.tensor(mix, dtype=torch.float32)
132
+ inv_mix_tensor = torch.tensor(-mix, dtype=torch.float32)
133
+
134
+ num_overlap = config.inference.num_overlap
135
+ denoise = config.inference.denoise
136
+ stem_name = model.primary_stem
137
+ if denoise:
138
+ processed_wav = model.process_wave(mix_tensor, device, num_overlap, pbar=pbar)
139
+ inv_processed_wav = model.process_wave(inv_mix_tensor, device, num_overlap, pbar=pbar)
140
+ result = processed_wav.cpu().numpy()
141
+ inv_result = inv_processed_wav.cpu().numpy()
142
+ result_separation = (result - inv_result) * 0.5  # average the plain and inverted passes
143
+ else:
144
+ processed_wav = model.process_wave(mix_tensor, device, num_overlap, pbar=pbar)
145
+ result_separation = processed_wav.cpu().numpy()
146
+
147
+ result_separation = np.nan_to_num(result_separation, nan=0.0, posinf=0.0, neginf=0.0)
148
+
149
+ return {stem_name: result_separation}  # already on the CPU at this point
150
+
151
+ def demix_vr(
152
+ config: Any,
153
+ model: Any,
154
+ mix: np.ndarray,
155
+ device: torch.device,
156
+ pbar: bool = False,
157
+ ) -> Dict[str, np.ndarray]:
158
+ """
159
+ VR-specific demixing function that processes the entire audio at once
160
+ since VR architecture doesn't support chunk-based processing
161
+ """
162
+ # Delegate the full mix to the VR model, which handles its own preprocessing
163
+ return model.demix(mix, config.audio.sample_rate, device, config.inference.aggression)
164
+
165
+ def demix_demucs(config, model, mix, device, pbar=False):
166
+ mix = torch.tensor(mix, dtype=torch.float32)
167
+ chunk_size = config.training.samplerate * config.training.segment
168
+ num_instruments = len(config.training.instruments)
169
+ num_overlap = config.inference.num_overlap
170
+ step = chunk_size // num_overlap
171
+ fade_size = chunk_size // 10  # fade length for the window function
172
+ windowing_array = _getWindowingArray(chunk_size, fade_size)  # build the window
173
+
174
+ batch_size = config.inference.batch_size
175
+ use_amp = getattr(config.training, "use_amp", True)
176
+
177
+ with torch.cuda.amp.autocast(enabled=use_amp):
178
+ with torch.inference_mode():
179
+ req_shape = (num_instruments,) + mix.shape
180
+ result = torch.zeros(req_shape, dtype=torch.float32)
181
+ counter = torch.zeros(req_shape, dtype=torch.float32)
182
+
183
+ i = 0
184
+ batch_data = []
185
+ batch_locations = []
186
+
187
+ while i < mix.shape[1]:
188
+ part = mix[:, i : i + chunk_size].to(device)
189
+ chunk_len = part.shape[-1]
190
+ pad_mode = "reflect" if chunk_len > chunk_size // 2 else "constant"
191
+ part = nn.functional.pad(
192
+ part, (0, chunk_size - chunk_len), mode=pad_mode, value=0
193
+ )
194
+
195
+ batch_data.append(part)
196
+ batch_locations.append((i, chunk_len))
197
+ i += step
198
+
199
+ if len(batch_data) >= batch_size or i >= mix.shape[1]:
200
+ arr = torch.stack(batch_data, dim=0)
201
+ x = model(arr)
202
+
203
+ window = windowing_array.clone()
204
+ if i - step == 0:  # first chunk: no fade-in
205
+ window[:fade_size] = 1
206
+ elif i >= mix.shape[1]:  # last chunk: no fade-out
207
+ window[-fade_size:] = 1
208
+
209
+ for j, (start, seg_len) in enumerate(batch_locations):
210
+ result[..., start : start + seg_len] += (
211
+ x[j, ..., :seg_len].cpu() * window[..., :seg_len]
212
+ )
213
+ counter[..., start : start + seg_len] += window[..., :seg_len]
214
+
215
+ # Output progress
216
+ processed = min(i, mix.shape[1])
217
+ total = mix.shape[1]
218
+ sys.stdout.write(json.dumps({"processing": {"processed": processed, "total": total}}) + '\n')
219
+ sys.stdout.flush()
220
+
221
+ batch_data.clear()
222
+ batch_locations.clear()
223
+
224
+ estimated_sources = result / counter
225
+ estimated_sources = estimated_sources.cpu().numpy()
226
+ np.nan_to_num(estimated_sources, copy=False, nan=0.0)
227
+
228
+ if num_instruments <= 1:
229
+ return estimated_sources
230
+ else:
231
+ instruments = config.training.instruments
232
+ return {k: v for k, v in zip(instruments, estimated_sources)}
233
+
234
+ def demix_generic(
235
+ config: ConfigDict,
236
+ model: torch.nn.Module,
237
+ mix: torch.Tensor,
238
+ device: torch.device,
239
+ pbar: bool = False,
240
+ ) -> Dict[str, np.ndarray]:
241
+ """
242
+ Generic demixing function for models that support chunk-based processing
243
+ """
244
+ mix = torch.tensor(mix, dtype=torch.float32)
245
+ chunk_size = config.audio.chunk_size
246
+ instruments = prefer_target_instrument(config)
247
+ num_instruments = len(instruments)
248
+ num_overlap = config.inference.num_overlap
249
+
250
+ fade_size = chunk_size // 10
251
+ step = chunk_size // num_overlap
252
+ border = chunk_size - step
253
+ length_init = mix.shape[-1]
254
+ windowing_array = _getWindowingArray(chunk_size, fade_size)
255
+
256
+ # Add padding to handle edge artifacts
257
+ if length_init > 2 * border and border > 0:
258
+ mix = nn.functional.pad(mix, (border, border), mode="reflect")
259
+
260
+ batch_size = config.inference.batch_size
261
+ use_amp = getattr(config.training, "use_amp", True)
262
+
263
+ with torch.cuda.amp.autocast(enabled=use_amp):
264
+ with torch.inference_mode():
265
+ # Initialize result and counter tensors
266
+ req_shape = (num_instruments,) + mix.shape
267
+ result = torch.zeros(req_shape, dtype=torch.float32)
268
+ counter = torch.zeros(req_shape, dtype=torch.float32)
269
+
270
+ i = 0
271
+ batch_data = []
272
+ batch_locations = []
273
+
274
+ while i < mix.shape[1]:
275
+ # Extract chunk and apply padding if necessary
276
+ part = mix[:, i : i + chunk_size].to(device)
277
+ chunk_len = part.shape[-1]
278
+
279
+ pad_mode = "reflect" if chunk_len > chunk_size // 2 else "constant"
280
+ part = nn.functional.pad(
281
+ part, (0, chunk_size - chunk_len), mode=pad_mode, value=0
282
+ )
283
+
284
+ batch_data.append(part)
285
+ batch_locations.append((i, chunk_len))
286
+ i += step
287
+
288
+ # Process batch if it's full or the end is reached
289
+ if len(batch_data) >= batch_size or i >= mix.shape[1]:
290
+ arr = torch.stack(batch_data, dim=0)
291
+ x = model(arr)
292
+
293
+ window = windowing_array.clone()
294
+ if i - step == 0: # First audio chunk, no fadein
295
+ window[:fade_size] = 1
296
+ elif i >= mix.shape[1]: # Last audio chunk, no fadeout
297
+ window[-fade_size:] = 1
298
+
299
+ for j, (start, seg_len) in enumerate(batch_locations):
300
+ result[..., start : start + seg_len] += (
301
+ x[j, ..., :seg_len].cpu() * window[..., :seg_len]
302
+ )
303
+ counter[..., start : start + seg_len] += window[..., :seg_len]
304
+
305
+ # Output progress
306
+ processed = min(i, mix.shape[1])
307
+ total = mix.shape[1]
308
+ sys.stdout.write(json.dumps({"processing": {"processed": processed, "total": total}}, ensure_ascii=False) + '\n')
309
+ sys.stdout.flush()
310
+
311
+ batch_data.clear()
312
+ batch_locations.clear()
313
+
314
+ # Compute final estimated sources
315
+ estimated_sources = result / counter
316
+ estimated_sources = estimated_sources.cpu().numpy()
317
+ np.nan_to_num(estimated_sources, copy=False, nan=0.0)
318
+
319
+ # Remove padding
320
+ if length_init > 2 * border and border > 0:
321
+ estimated_sources = estimated_sources[..., border:-border]
322
+
323
+ # Return the result as a dictionary
324
+ return {k: v for k, v in zip(instruments, estimated_sources)}
325
+
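Both chunked demixers above implement the same overlap-add scheme: each windowed chunk output is accumulated into `result`, the window mass into `counter`, and the ratio normalizes the overlapping regions. A stripped-down 1-D sketch (`process` is assumed to return a tensor of the same length as its input):

```python
import torch

def overlap_add_1d(signal, process, chunk, step, window):
    out = torch.zeros_like(signal)
    norm = torch.zeros_like(signal)
    for start in range(0, signal.shape[-1], step):
        seg = signal[..., start:start + chunk]
        n = seg.shape[-1]
        out[..., start:start + n] += process(seg)[..., :n] * window[:n]
        norm[..., start:start + n] += window[:n]
    return out / norm.clamp_min(1e-8)  # normalize by the accumulated window mass
```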
326
+ def demix(
327
+ config: ConfigDict,
328
+ model: torch.nn.Module,
329
+ mix: np.ndarray,
330
+ device: torch.device,
331
+ model_type: str,
332
+ pbar: bool = False,
333
+ ) -> Dict[str, np.ndarray]:
334
+ """
335
+ Unified function for audio source separation with support for multiple processing modes.
336
+ """
337
+ # Handle different model types
338
+ if model_type == "vr":
339
+ return demix_vr(config, model, mix, device, pbar)
340
+ elif model_type == "mdxnet":
341
+ return demix_mdxnet(config, model, mix, device, pbar)
342
+ elif model_type == "htdemucs":
343
+ # HTDemucs uses its own processing
344
+ return demix_demucs(config, model, mix, device, pbar)
345
+ else:
346
+ # Generic processing for other models
347
+ return demix_generic(config, model, mix, device, pbar)
348
+
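A usage sketch of the dispatcher; `model`, `config` and the device here stand for the objects produced by the loaders in `__init__.py`:

```python
# mix is a (channels, samples) float array; the return value maps
# stem names to arrays of the same shape.
waveforms = demix(config, model, mix, torch.device("cuda:0"), model_type="bs_roformer")
```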
349
+
350
+ def prefer_target_instrument(config: ConfigDict) -> List[str]:
351
+ """
352
+ Return the list of target instruments based on the configuration.
353
+ If a specific target instrument is specified in the configuration,
354
+ it returns a list with that instrument. Otherwise, it returns the list of instruments.
355
+ """
356
+ if config.training.get("target_instrument"):
357
+ return [config.training.target_instrument]
358
+ else:
359
+ return config.training.instruments
360
+
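Example behavior, sketched with an ml_collections.ConfigDict:

```python
from ml_collections import ConfigDict

cfg = ConfigDict({"training": {"instruments": ["vocals", "other"],
                               "target_instrument": "vocals"}})
prefer_target_instrument(cfg)   # -> ["vocals"]

cfg2 = ConfigDict({"training": {"instruments": ["vocals", "other"],
                                "target_instrument": None}})
prefer_target_instrument(cfg2)  # -> ["vocals", "other"]
```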
361
+
362
+ def prefer_target_instrument_test(
363
+ config: ConfigDict, selected_instruments: Optional[List[str]] = None
364
+ ) -> List[str]:
365
+ """
366
+ Return the list of target instruments based on the configuration and selected instruments.
367
+ If selected_instruments is specified, returns the intersection with available instruments.
368
+ Otherwise, if a target instrument is specified, returns it, else returns all instruments.
369
+ """
370
+ available_instruments = config.training.instruments
371
+
372
+ if selected_instruments is not None:
373
+ # Return only selected instruments that are available
374
+ return [
375
+ instr for instr in selected_instruments if instr in available_instruments
376
+ ]
377
+ elif config.training.get("target_instrument"):
378
+ # Default behavior if no selection - return target instrument
379
+ return [config.training.target_instrument]
380
+ else:
381
+ # If no target and no selection, return all instruments
382
+ return available_instruments
mvsepless/model_manager.py ADDED
@@ -0,0 +1,540 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import yaml
5
+ from tabulate import tabulate
6
+ import shutil
7
+ from tqdm import tqdm
8
+ import urllib.request
9
+ import gdown
10
+ import requests
11
+ import zipfile
12
+ import tempfile
13
+ import secrets
14
+ import string
15
+ import argparse
16
+ from typing import Dict, Any
17
+ script_dir = os.path.dirname(os.path.abspath(__file__))
18
+ if not __package__:
19
+ from downloader import dw_file
20
+ else:
21
+ from .downloader import dw_file
22
+
23
+ def generate_secure_random(length=10):
24
+ """Генерирует криптографически безопасную случайную строку"""
25
+ characters = string.ascii_letters + string.digits
26
+ return ''.join(secrets.choice(characters) for _ in range(length))
27
+
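Example (the exact output varies per call):

```python
token = generate_secure_random()      # e.g. 'aZ3kQ9xLmP'
assert len(token) == 10 and token.isalnum()
```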
28
+ class MvseplessModelManager:
29
+ def __init__(
30
+ self,
31
+ models_info_path=os.path.join(script_dir, "models.json"),
32
+ cache_dir=os.path.join(script_dir, "mvsepless_models_cache"),
33
+ ):
34
+ self.models_cache_dir = cache_dir
35
+ self.models_info_path = models_info_path
36
+ with open(self.models_info_path, "r", encoding="utf-8") as f:
37
+ models_info = json.load(f)
38
+ self.models_info = models_info
39
+
40
+ def get_mt(self):
41
+ return list(self.models_info.keys())
42
+
43
+ def get_mn(self, model_type):
44
+ try:
45
+ mt = self.models_info.get(model_type, None)
46
+ if mt:
47
+ return list(self.models_info[model_type].keys())
48
+ return []
49
+ except (KeyError, TypeError):
50
+ return []
51
+
52
+ def get_stems(self, model_type, model_name):
53
+ try:
54
+ mt = self.models_info.get(model_type, None)
55
+ if mt:
56
+ mn = self.models_info[model_type].get(model_name, None)
57
+ if mn and "stems" in self.models_info[model_type][model_name]:
58
+ return self.models_info[model_type][model_name]["stems"]
59
+ return []
60
+ except (KeyError, TypeError):
61
+ return []
62
+
63
+ def get_id(self, model_type, model_name):
64
+ try:
65
+ mt = self.models_info.get(model_type, None)
66
+ if mt:
67
+ mn = self.models_info[model_type].get(model_name, None)
68
+ if mn and "id" in self.models_info[model_type][model_name]:
69
+ return self.models_info[model_type][model_name]["id"]
70
+ return 0
71
+ except (KeyError, TypeError):
72
+ return 0
73
+
74
+ def get_tgt_inst(self, model_type, model_name):
75
+ try:
76
+ mt = self.models_info.get(model_type, None)
77
+ if mt:
78
+ mn = self.models_info[model_type].get(model_name, None)
79
+ if mn and "target_instrument" in self.models_info[model_type][model_name]:
80
+ return self.models_info[model_type][model_name]["target_instrument"]
81
+ return None
82
+ except (KeyError, TypeError):
83
+ return None
84
+
85
+ def display_models_info(self, filter: str = None):
86
+ # Collect rows for the table
87
+ table_data = []
88
+ headers = [
89
+ "Тип модели",
90
+ "ID",
91
+ "Имя модели",
92
+ "Стемы",
93
+ "Целевой инструмент",
94
+ ]
95
+
96
+ for model_type, models in self.models_info.items():
97
+ for model_name, model_info in models.items():
98
+ try:
99
+ stems_list = model_info.get("stems", [])
100
+ id = model_info.get("id", "н/д")
101
+ # Применяем фильтр (регистронезависимо)
102
+ if filter:
103
+ filter_lower = filter.lower()
104
+ if not any(filter_lower == s.lower() for s in stems_list):
105
+ continue
106
+
107
+ # Prepare the table row
108
+ row = [
109
+ model_type,
110
+ id,
111
+ model_name,
112
+ ", ".join(stems_list) or "н/д",
113
+ model_info.get("target_instrument", "н/д"),
114
+ ]
115
+ table_data.append(row)
116
+ except (KeyError, TypeError, AttributeError) as e:
117
+ print(f"Ошибка при обработке модели {model_type}/{model_name}: {e}")
118
+ continue
119
+
120
+ # Print the result
121
+ if table_data:
122
+ print(tabulate(table_data, headers=headers, tablefmt="grid"))
123
+ else:
124
+ print("Нет моделей, которые содержат указанный стем")
125
+
126
+ def download_model(
127
+ self, model_paths, model_name, model_type, ckpt_url, conf_url
128
+ ):
129
+ model_dir = os.path.join(model_paths, model_type)
130
+ os.makedirs(model_dir, exist_ok=True)
131
+
132
+ config_path = None
133
+ checkpoint_path = None
134
+
135
+ if model_type == "mel_band_roformer":
136
+ config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
137
+ checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt")
138
+
139
+ elif model_type == "vr":
140
+ config_path = os.path.join(model_dir, f"{model_name}.yaml")
141
+ checkpoint_path = os.path.join(model_dir, f"{model_name}.pth")
142
+
143
+ elif model_type == "mdxnet":
144
+ config_path = os.path.join(model_dir, f"{model_name}.yaml")
145
+ checkpoint_path = os.path.join(model_dir, f"{model_name}.onnx")
146
+
147
+ elif model_type == "bs_roformer":
148
+ config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
149
+ checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt")
150
+
151
+ elif model_type == "mdx23c":
152
+ config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
153
+ checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt")
154
+
155
+ elif model_type == "scnet":
156
+ config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
157
+ checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt")
158
+
159
+ elif model_type == "bandit":
160
+ config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
161
+ checkpoint_path = os.path.join(model_dir, f"{model_name}.chpt")
162
+
163
+ elif model_type == "bandit_v2":
164
+ config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
165
+ checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt")
166
+
167
+ elif model_type == "htdemucs":
168
+ config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
169
+ checkpoint_path = os.path.join(model_dir, f"{model_name}.th")
170
+
171
+ else:
172
+ raise ValueError(
173
+ f"{self.I18N_helper.t('error_unsupported_model_type')}: {model_type}"
174
+ )
175
+
176
+ # Make sure both paths were set (just in case)
177
+ if config_path is None or checkpoint_path is None:
178
+ raise RuntimeError()
179
+
180
+ # If the files already exist, skip the download
181
+ if os.path.exists(checkpoint_path) and os.path.exists(config_path):
182
+ if os.path.getsize(checkpoint_path) == 0 or os.path.getsize(config_path) == 0:
183
+ for local_path, url_model in [
184
+ (checkpoint_path, ckpt_url),
185
+ (config_path, conf_url),
186
+ ]:
187
+ if os.path.getsize(local_path) == 0:  # re-download files that came down empty
188
+
189
+ dw_file(url_model, local_path)
192
+ else:
193
+ for local_path, url_model in [
194
+ (checkpoint_path, ckpt_url),
195
+ (config_path, conf_url),
196
+ ]:
197
+ if not os.path.exists(local_path):
198
+
199
+ dw_file(url_model, local_path)
200
+
201
+ return config_path, checkpoint_path
202
+
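The if/elif chain above reduces to a per-type table of (config suffix, checkpoint suffix); a hedged equivalent, keeping bandit's ".chpt" suffix exactly as in the code above:

```python
_SUFFIXES = {
    "mel_band_roformer": ("_config.yaml", ".ckpt"),
    "bs_roformer":       ("_config.yaml", ".ckpt"),
    "mdx23c":            ("_config.yaml", ".ckpt"),
    "scnet":             ("_config.yaml", ".ckpt"),
    "bandit":            ("_config.yaml", ".chpt"),
    "bandit_v2":         ("_config.yaml", ".ckpt"),
    "htdemucs":          ("_config.yaml", ".th"),
    "vr":                (".yaml", ".pth"),
    "mdxnet":            (".yaml", ".onnx"),
}
conf_sfx, ckpt_sfx = _SUFFIXES[model_type]
config_path = os.path.join(model_dir, f"{model_name}{conf_sfx}")
checkpoint_path = os.path.join(model_dir, f"{model_name}{ckpt_sfx}")
```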
203
+ def conf_editor(self, config_path, mdx_denoise, vr_aggr, model_type):
204
+
205
+ class IndentDumper(yaml.Dumper):
206
+ def increase_indent(self, flow=False, indentless=False):
207
+ return super(IndentDumper, self).increase_indent(flow, False)
208
+
209
+ def tuple_constructor(loader, node):
210
+ # Load the sequence of values from the YAML node
211
+ values = loader.construct_sequence(node)
212
+ # Return a tuple constructed from the sequence
213
+ return tuple(values)
214
+
215
+ # Register the constructor with PyYAML
216
+ yaml.SafeLoader.add_constructor(
217
+ "tag:yaml.org,2002:python/tuple", tuple_constructor
218
+ )
219
+
220
+ def conf_edit(config_path, mdx_denoise, vr_aggr, model_type):
221
+ with open(config_path, "r") as f:
222
+ data = yaml.load(f, Loader=yaml.SafeLoader)
223
+
224
+ # handle cases where 'use_amp' is missing from config:
225
+ if "use_amp" not in data.keys():
226
+ data["training"]["use_amp"] = True
227
+
228
+ if model_type != "vr":
229
+ if data["inference"]["num_overlap"] != 2:
230
+ data["inference"]["num_overlap"] = 2
231
+
232
+ if data["inference"]["batch_size"] != 1:
233
+ data["inference"]["batch_size"] = 1
234
+
235
+ if model_type == "mdxnet":
236
+ data["inference"]["denoise"] = mdx_denoise
237
+
238
+ elif model_type == "vr":
239
+ data["inference"]["aggression"] = vr_aggr
240
+
241
+ with open(config_path, "w") as f:
242
+ yaml.dump(
243
+ data,
244
+ f,
245
+ default_flow_style=False,
246
+ sort_keys=False,
247
+ Dumper=IndentDumper,
248
+ allow_unicode=True,
249
+ )
250
+
251
+ conf_edit(config_path, mdx_denoise, vr_aggr, model_type)
252
+
253
+ class VbachModelManager:
254
+ def __init__(self):
255
+ self.rmvpe_path = os.path.join(script_dir, "predictors", "rmvpe.pt")
256
+ self.fcpe_path = os.path.join(script_dir, "predictors", "fcpe.pt")
257
+ self.hubert_path = os.path.join(script_dir, "embedders", "hubert_base.pt")
258
+ self.requirements = [["https://huggingface.co/Politrees/RVC_resources/resolve/main/predictors/rmvpe.pt", self.rmvpe_path], ["https://huggingface.co/Politrees/RVC_resources/resolve/main/predictors/fcpe.pt", self.fcpe_path], ["https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/hubert_base.pt", self.hubert_path]]
259
+ self.voicemodels_dir = os.path.join(script_dir, "vbach_models_cache")
260
+ os.makedirs(self.voicemodels_dir, exist_ok=True)
261
+ self.voicemodels_info = os.path.join(self.voicemodels_dir, "vbach_models.json")
262
+ self.voicemodels: Dict[str, Dict[str, str]] = {}
263
+ self.download_requirements()
264
+ self.check_and_load()
266
+
267
+ def write_voicemodels_info(self):
268
+ with open(self.voicemodels_info, "w") as f:
269
+ json.dump(self.voicemodels, f, indent=4)
270
+
271
+ def load_voicemodels_info(self):
272
+ with open(self.voicemodels_info, "r") as f:
273
+ return json.load(f)
274
+
275
+ def add_voice_model(
276
+ self,
277
+ name,
278
+ pth_path,
279
+ index_path,
280
+ ):
281
+ self.voicemodels[name] = {"pth": pth_path, "index": index_path}
282
+ self.write_voicemodels_info()
283
+
284
+ def del_voice_model(
285
+ self, name
286
+ ):
287
+ if name in self.parse_voice_models():
288
+ pth = self.voicemodels[name].get("pth", None)
289
+ index = self.voicemodels[name].get("index", None)
290
+ if index:
291
+ os.remove(index)
292
+ if pth:
293
+ os.remove(pth)
294
+ del self.voicemodels[name]
295
+ self.write_voicemodels_info()
296
+ return f"Модель {name} удалена"
297
+ else:
298
+ return f"Модель не была удалена, как так её не существует"
299
+
300
+ def parse_voice_models(self):
301
+ list_models = list(self.voicemodels.keys())
302
+ return list_models
303
+
304
+ def parse_pth_and_index(self, name):
305
+ pth = self.voicemodels[name].get("pth", None)
306
+ index = self.voicemodels[name].get("index", None)
307
+ return pth, index
308
+
309
+ def check_and_load(self):
310
+ if os.path.exists(self.voicemodels_info):
311
+ self.voicemodels = self.load_voicemodels_info()
312
+ else:
313
+ self.write_voicemodels_info()
314
+
315
+ def clear_voicemodels_info(self):
316
+ self.voicemodels: Dict[str, Dict[str, str]] = {}
317
+ self.write_voicemodels_info()
318
+
319
+ def download_file(self, url_model, local_path):
320
+ dir_name = os.path.dirname(local_path)
321
+ if dir_name != "":
322
+ os.makedirs(dir_name, exist_ok=True)
323
+ class TqdmUpTo(tqdm):
324
+ def update_to(self, b=1, bsize=1, tsize=None):
325
+ if tsize is not None:
326
+ self.total = tsize
327
+ self.update(b * bsize - self.n)
328
+
329
+ with TqdmUpTo(
330
+ unit="B",
331
+ unit_scale=True,
332
+ unit_divisor=1024,
333
+ miniters=1,
334
+ desc=os.path.basename(local_path),
335
+ ) as t:
336
+ urllib.request.urlretrieve(
337
+ url_model, local_path, reporthook=t.update_to
338
+ )
339
+
340
+ def download_requirements(self):
341
+ for url, file in self.requirements:
342
+ if not os.path.exists(file):
343
+ self.download_file(url_model=url, local_path=file)
344
+
345
+ def download_voice_model_file(self, url, zip_name):
346
+ try:
347
+ if "drive.google.com" in url:
348
+ self.download_from_google_drive(url, zip_name)
349
+ elif "pixeldrain.com" in url:
350
+ self.download_from_pixeldrain(url, zip_name)
351
+ elif "disk.yandex.ru" in url or "yadi.sk" in url:
352
+ self.download_from_yandex(url, zip_name)
353
+ else:
354
+ self.download_file(url, zip_name)
355
+ except Exception as e:
356
+ print(e)
357
+
358
+ def download_from_google_drive(self, url, zip_name):
359
+ file_id = (
360
+ url.split("file/d/")[1].split("/")[0]
361
+ if "file/d/" in url
362
+ else url.split("id=")[1].split("&")[0]
363
+ )
364
+ gdown.download(id=file_id, output=str(zip_name), quiet=False)
365
+
366
+ def download_from_pixeldrain(self, url, zip_name):
367
+ file_id = url.split("pixeldrain.com/u/")[1]
368
+ response = requests.get(f"https://pixeldrain.com/api/file/{file_id}")
369
+ with open(zip_name, "wb") as f:
370
+ f.write(response.content)
371
+
372
+ def download_from_yandex(self, url, zip_name):
373
+ yandex_public_key = f"download?public_key={url}"
374
+ yandex_api_url = f"https://cloud-api.yandex.net/v1/disk/public/resources/{yandex_public_key}"
375
+ response = requests.get(yandex_api_url)
376
+ if response.status_code == 200:
377
+ download_link = response.json().get("href")
378
+ urllib.request.urlretrieve(download_link, zip_name)
379
+ else:
380
+ print(response.status_code)
381
+
382
+ def extract_zip(self, zip_name, model_name):
383
+ model_dir = os.path.join(self.voicemodels_dir, f"{model_name}_{generate_secure_random(17)}")
384
+ os.makedirs(model_dir, exist_ok=True)
385
+ try:
386
+ with zipfile.ZipFile(zip_name, "r") as zip_ref:
387
+ zip_ref.extractall(model_dir)
388
+ os.remove(zip_name)
389
+
390
+ added_voice_models = []
391
+
392
+ index_filepath, model_filepaths = None, []
393
+ for root, _, files in os.walk(model_dir):
394
+ for name in files:
395
+ file_path = os.path.join(root, name)
396
+ if name.endswith(".index") and os.stat(file_path).st_size > 1024 * 100:
397
+ index_filepath = file_path
398
+ if name.endswith(".pth") and os.stat(file_path).st_size > 1024 * 1024 * 20:
399
+ model_filepaths.append(file_path)
400
+
401
+ if len(model_filepaths) == 1:
402
+ self.add_voice_model(model_name, model_filepaths[0], index_filepath)
403
+ added_voice_models.append(model_name)
404
+ else:
405
+ for i, pth in enumerate(model_filepaths):
406
+ self.add_voice_model(f"{model_name}_{i + 1}", pth, index_filepath)
407
+ added_voice_models.append(f"{model_name}_{i + 1}")
408
+ list_models_str = '\n'.join(added_voice_models)
409
+ return f"Добавленные модели:\n{list_models_str}"
410
+ except Exception as e:
411
+ return f"Произошла ошибка при загрузке модели: {e}"
412
+
413
+ def install_model_zip(self, zip, model_name, mode="url"):
414
+ if model_name in self.parse_voice_models():
415
+ print("Эта модель уже есть в списке установленных моделей. Она будут перезаписана")
416
+ if mode == "url":
417
+ with tempfile.TemporaryDirectory(prefix="vbach_temp_model", ignore_cleanup_errors=True) as tmp:
418
+ zip_path = os.path.join(tmp, "model.zip")
419
+ self.download_voice_model_file(zip, zip_path)
420
+ status = self.extract_zip(zip_path, model_name)
421
+ if mode == "local":
422
+ status = self.extract_zip(zip, model_name)
423
+ return status
424
+
425
+ def install_model_files(self, index, pth, model_name, mode="url"):
426
+ if model_name in self.parse_voice_models():
427
+ print("Эта модель уже есть в списке установленных моделей. Она будут перезаписана")
428
+ model_dir = os.path.join(self.voicemodels_dir, f"{model_name}_{generate_secure_random(17)}")
429
+ os.makedirs(model_dir, exist_ok=True)
430
+ local_index_path = None
431
+ local_pth_path = None
432
+ try:
433
+ if mode == "url":
434
+ if index:
435
+ local_index_path = os.path.join(model_dir, "model.index")
436
+ self.download_voice_model_file(index, local_index_path)
437
+ if pth:
438
+ local_pth_path = os.path.join(model_dir, "model.pth")
439
+ self.download_voice_model_file(pth, local_pth_path)
440
+
441
+ if mode == "local":
442
+ if index:
443
+ if os.path.exists(index):
444
+ local_index_path = os.path.join(model_dir, os.path.basename(index))
445
+ shutil.copy(index, local_index_path)
446
+ if pth:
447
+ if os.path.exists(pth):
448
+ local_pth_path = os.path.join(model_dir, os.path.basename(pth))
449
+ shutil.copy(pth, local_pth_path)
450
+
451
+ self.add_voice_model(model_name, local_pth_path, local_index_path)
452
+ return f"Модель {model_name} добавлена"
453
+ except Exception as e:
454
+ return f"Произошла ошибка при загрузке модели: {e}"
455
+
456
+
457
+ if __name__ == "__main__":
458
+ parser = argparse.ArgumentParser(description="Менеджер моделей")
459
+ subparsers = parser.add_subparsers(title="subcommands", dest="command", required=True)
460
+
461
+ # Mvsepless subcommand
462
+ mvsepless_parser = subparsers.add_parser("mvsepless", help="Download MVSepLess separation models")
463
+ mvsepless_parser.add_argument("--model_type", required=True, help="Тип модели")
464
+ mvsepless_parser.add_argument("--model_name", required=True, help="Имя модели")
465
+
466
+ # Vbach subcommand
467
+ vbach_parser = subparsers.add_parser("vbach", help="Установка голосовых моделей в Vbach")
468
+ vbach_subparsers = vbach_parser.add_subparsers(title="vbach_commands", dest="vbach_command", required=True)
469
+
470
+ # Vbach install_local
471
+ install_local_parser = vbach_subparsers.add_parser("install_local", help="Install a voice model from local files")
472
+ install_local_parser.add_argument("--model_name", required=True, help="Voice model name")
473
+ install_local_parser.add_argument("--pth", required=True, help="Path to the *.pth file")
474
+ install_local_parser.add_argument("--index", required=False, help="Path to the *.index file")
475
+
476
+ # Vbach install_url_zip
477
+ install_url_zip_parser = vbach_subparsers.add_parser("install_url_zip", help="Install a voice model from a URL (zip archive)")
478
+ install_url_zip_parser.add_argument("--model_name", required=True, help="Voice model name")
479
+ install_url_zip_parser.add_argument("--url", required=True, help="URL of the *.zip file")
480
+
481
+ # Vbach install_url_files
482
+ install_url_files_parser = vbach_subparsers.add_parser("install_url_files", help="Install a voice model from URLs (separate files)")
483
+ install_url_files_parser.add_argument("--model_name", required=True, help="Voice model name")
484
+ install_url_files_parser.add_argument("--pth_url", required=True, help="URL of the *.pth file")
485
+ install_url_files_parser.add_argument("--index_url", required=False, help="URL of the *.index file")
486
+
487
+ # Vbach list
488
+ list_parser = vbach_subparsers.add_parser("list", help="List installed voice models")
489
+
490
+ args = parser.parse_args()
491
+
492
+ if args.command == "mvsepless":
493
+
494
+ _model_manager = MvseplessModelManager()
495
+ info = _model_manager.models_info[args.model_type].get(args.model_name, None)
496
+ if not info:
497
+ raise ValueError(f"Модель {args.model_name} не найдена для типа {args.model_type}")
498
+ conf, ckpt = _model_manager.download_model(
499
+ _model_manager.models_cache_dir,
500
+ args.model_name,
501
+ args.model_type,
502
+ info["checkpoint_url"],
503
+ info["config_url"],
504
+ )
505
+
506
+ elif args.command == "vbach":
507
+ model_manager = VbachModelManager()
508
+
509
+ if args.vbach_command == "install_local":
510
+ status = model_manager.install_model_files(
511
+ args.index, args.pth, args.model_name, mode="local"
512
+ )
513
+ print(status)
514
+
515
+ elif args.vbach_command == "install_url_zip":
516
+ status = model_manager.install_model_zip(
517
+ args.url, args.model_name, mode="url"
518
+ )
519
+ print(status)
520
+
521
+ elif args.vbach_command == "install_url_files":
522
+ status = model_manager.install_model_files(
523
+ args.index_url, args.pth_url, args.model_name, mode="url"
524
+ )
525
+ print(status)
526
+
527
+ elif args.vbach_command == "list":
528
+ models = model_manager.parse_voice_models()
529
+ if models:
530
+ print("Установленные модели:")
531
+ for model in models:
532
+ print(f" - {model}")
533
+ else:
534
+ print("Нет установленных моделей")
535
+
536
+
537
+
538
+
539
+
540
+
mvsepless/models.json ADDED
The diff for this file is too large to render. See raw diff
 
mvsepless/models/bandit/core/__init__.py ADDED
@@ -0,0 +1,691 @@
1
+ import os.path
2
+ from collections import defaultdict
3
+ from itertools import chain, combinations
4
+ from typing import Any, Dict, Iterator, Mapping, Optional, Tuple, Type, TypedDict
5
+
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ import torchaudio as ta
9
+ import torchmetrics as tm
10
+ from asteroid import losses as asteroid_losses
11
+
12
+ # from deepspeed.ops.adam import DeepSpeedCPUAdam
13
+ # from geoopt import optim as gooptim
14
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
15
+ from torch import nn, optim
16
+ from torch.optim import lr_scheduler
17
+ from torch.optim.lr_scheduler import LRScheduler
18
+
19
+ from . import loss, metrics as metrics_, model
20
+ from .data._types import BatchedDataDict
21
+ from .data.augmentation import BaseAugmentor, StemAugmentor
22
+ from .utils import audio as audio_
23
+ from .utils.audio import BaseFader
24
+
25
+ # from pandas.io.json._normalize import nested_to_record
26
+
27
+ ConfigDict = TypedDict("ConfigDict", {"name": str, "kwargs": Dict[str, Any]})
28
+
29
+
30
+ class SchedulerConfigDict(ConfigDict):
31
+ monitor: str
32
+
33
+
34
+ OptimizerSchedulerConfigDict = TypedDict(
35
+ "OptimizerSchedulerConfigDict",
36
+ {"optimizer": ConfigDict, "scheduler": SchedulerConfigDict},
37
+ total=False,
38
+ )
39
+
40
+
41
+ class LRSchedulerReturnDict(TypedDict, total=False):
42
+ scheduler: LRScheduler
43
+ monitor: str
44
+
45
+
46
+ class ConfigureOptimizerReturnDict(TypedDict, total=False):
47
+ optimizer: torch.optim.Optimizer
48
+ lr_scheduler: LRSchedulerReturnDict
49
+
50
+
51
+ OutputType = Dict[str, Any]
52
+ MetricsType = Dict[str, torch.Tensor]
53
+
54
+
55
+ def get_optimizer_class(name: str) -> Type[optim.Optimizer]:
56
+
57
+ if name == "DeepSpeedCPUAdam":
58
+ return DeepSpeedCPUAdam
59
+
60
+ for module in [optim, gooptim]:
61
+ if name in module.__dict__:
62
+ return module.__dict__[name]
63
+
64
+ raise NameError
65
+
66
+
67
+ def parse_optimizer_config(
68
+ config: OptimizerSchedulerConfigDict, parameters: Iterator[nn.Parameter]
69
+ ) -> ConfigureOptimizerReturnDict:
70
+ optim_class = get_optimizer_class(config["optimizer"]["name"])
71
+ optimizer = optim_class(parameters, **config["optimizer"]["kwargs"])
72
+
73
+ optim_dict: ConfigureOptimizerReturnDict = {
74
+ "optimizer": optimizer,
75
+ }
76
+
77
+ if "scheduler" in config:
78
+
79
+ lr_scheduler_class_ = config["scheduler"]["name"]
80
+ lr_scheduler_class = lr_scheduler.__dict__[lr_scheduler_class_]
81
+ lr_scheduler_dict: LRSchedulerReturnDict = {
82
+ "scheduler": lr_scheduler_class(optimizer, **config["scheduler"]["kwargs"])
83
+ }
84
+
85
+ if lr_scheduler_class_ == "ReduceLROnPlateau":
86
+ lr_scheduler_dict["monitor"] = config["scheduler"]["monitor"]
87
+
88
+ optim_dict["lr_scheduler"] = lr_scheduler_dict
89
+
90
+ return optim_dict
91
+
92
+
93
+ def parse_model_config(config: ConfigDict) -> Any:
94
+ name = config["name"]
95
+
96
+ for module in [model]:
97
+ if name in module.__dict__:
98
+ return module.__dict__[name](**config["kwargs"])
99
+
100
+ raise NameError
101
+
102
+
103
+ _LEGACY_LOSS_NAMES = ["HybridL1Loss"]
104
+
105
+
106
+ def _parse_legacy_loss_config(config: ConfigDict) -> nn.Module:
107
+ name = config["name"]
108
+
109
+ if name == "HybridL1Loss":
110
+ return loss.TimeFreqL1Loss(**config["kwargs"])
111
+
112
+ raise NameError
113
+
114
+
115
+ def parse_loss_config(config: ConfigDict) -> nn.Module:
116
+ name = config["name"]
117
+
118
+ if name in _LEGACY_LOSS_NAMES:
119
+ return _parse_legacy_loss_config(config)
120
+
121
+ for module in [loss, nn.modules.loss, asteroid_losses]:
122
+ if name in module.__dict__:
123
+ # print(config["kwargs"])
124
+ return module.__dict__[name](**config["kwargs"])
125
+
126
+ raise NameError
127
+
128
+
129
+ def get_metric(config: ConfigDict) -> tm.Metric:
130
+ name = config["name"]
131
+
132
+ for module in [tm, metrics_]:
133
+ if name in module.__dict__:
134
+ return module.__dict__[name](**config["kwargs"])
135
+ raise NameError
136
+
137
+
138
+ def parse_metric_config(config: Dict[str, ConfigDict]) -> tm.MetricCollection:
139
+ metrics = {}
140
+
141
+ for metric in config:
142
+ metrics[metric] = get_metric(config[metric])
143
+
144
+ return tm.MetricCollection(metrics)
145
+
146
+
147
+ def parse_fader_config(config: ConfigDict) -> BaseFader:
148
+ name = config["name"]
149
+
150
+ for module in [audio_]:
151
+ if name in module.__dict__:
152
+ return module.__dict__[name](**config["kwargs"])
153
+
154
+ raise NameError
155
+
156
+
157
+ class LightningSystem(pl.LightningModule):
158
+ _VOX_STEMS = ["speech", "vocals"]
159
+ _BG_STEMS = ["background", "effects", "mne"]
160
+
161
+ def __init__(
162
+ self, config: Dict, loss_adjustment: float = 1.0, attach_fader: bool = False
163
+ ) -> None:
164
+ super().__init__()
165
+ self.optimizer_config = config["optimizer"]
166
+ self.model = parse_model_config(config["model"])
167
+ self.loss = parse_loss_config(config["loss"])
168
+ self.metrics = nn.ModuleDict(
169
+ {
170
+ stem: parse_metric_config(config["metrics"]["dev"])
171
+ for stem in self.model.stems
172
+ }
173
+ )
174
+
175
+ self.metrics.disallow_fsdp = True
176
+
177
+ self.test_metrics = nn.ModuleDict(
178
+ {
179
+ stem: parse_metric_config(config["metrics"]["test"])
180
+ for stem in self.model.stems
181
+ }
182
+ )
183
+
184
+ self.test_metrics.disallow_fsdp = True
185
+
186
+ self.fs = config["model"]["kwargs"]["fs"]
187
+
188
+ self.fader_config = config["inference"]["fader"]
189
+ if attach_fader:
190
+ self.fader = parse_fader_config(config["inference"]["fader"])
191
+ else:
192
+ self.fader = None
193
+
194
+ self.augmentation: Optional[BaseAugmentor]
195
+ if config.get("augmentation", None) is not None:
196
+ self.augmentation = StemAugmentor(**config["augmentation"])
197
+ else:
198
+ self.augmentation = None
199
+
200
+ self.predict_output_path: Optional[str] = None
201
+ self.loss_adjustment = loss_adjustment
202
+
203
+ self.val_prefix = None
204
+ self.test_prefix = None
205
+
206
+ def configure_optimizers(self) -> Any:
207
+ return parse_optimizer_config(
208
+ self.optimizer_config, self.trainer.model.parameters()
209
+ )
210
+
211
+ def compute_loss(
212
+ self, batch: BatchedDataDict, output: OutputType
213
+ ) -> Dict[str, torch.Tensor]:
214
+ return {"loss": self.loss(output, batch)}
215
+
216
+ def update_metrics(
217
+ self, batch: BatchedDataDict, output: OutputType, mode: str
218
+ ) -> None:
219
+
220
+ if mode == "test":
221
+ metrics = self.test_metrics
222
+ else:
223
+ metrics = self.metrics
224
+
225
+ for stem, metric in metrics.items():
226
+
227
+ if stem == "mne:+":
228
+ stem = "mne"
229
+
230
+ # print(f"matching for {stem}")
231
+ if mode == "train":
232
+ metric.update(
233
+ output["audio"][stem], # .cpu(),
234
+ batch["audio"][stem], # .cpu()
235
+ )
236
+ else:
237
+ if stem not in batch["audio"]:
238
+ matched = False
239
+ if stem in self._VOX_STEMS:
240
+ for bstem in self._VOX_STEMS:
241
+ if bstem in batch["audio"]:
242
+ batch["audio"][stem] = batch["audio"][bstem]
243
+ matched = True
244
+ break
245
+ elif stem in self._BG_STEMS:
246
+ for bstem in self._BG_STEMS:
247
+ if bstem in batch["audio"]:
248
+ batch["audio"][stem] = batch["audio"][bstem]
249
+ matched = True
250
+ break
251
+ else:
252
+ matched = True
253
+
254
+ # print(batch["audio"].keys())
255
+
256
+ if matched:
257
+ # print(f"matched {stem}!")
258
+ if stem == "mne" and "mne" not in output["audio"]:
259
+ output["audio"]["mne"] = (
260
+ output["audio"]["music"] + output["audio"]["effects"]
261
+ )
262
+
263
+ metric.update(
264
+ output["audio"][stem], # .cpu(),
265
+ batch["audio"][stem], # .cpu(),
266
+ )
267
+
268
+ # print(metric.compute())
269
+
270
+ def compute_metrics(self, mode: str = "dev") -> Dict[str, torch.Tensor]:
271
+
272
+ if mode == "test":
273
+ metrics = self.test_metrics
274
+ else:
275
+ metrics = self.metrics
276
+
277
+ metric_dict = {}
278
+
279
+ for stem, metric in metrics.items():
280
+ md = metric.compute()
281
+ metric_dict.update({f"{stem}/{k}": v for k, v in md.items()})
282
+
283
+ self.log_dict(metric_dict, prog_bar=True, logger=False)
284
+
285
+ return metric_dict
286
+
287
+ def reset_metrics(self, test_mode: bool = False) -> None:
288
+
289
+ if test_mode:
290
+ metrics = self.test_metrics
291
+ else:
292
+ metrics = self.metrics
293
+
294
+ for _, metric in metrics.items():
295
+ metric.reset()
296
+
297
+ def forward(self, batch: BatchedDataDict) -> Any:
298
+ batch, output = self.model(batch)
299
+
300
+ return batch, output
301
+
302
+ def common_step(self, batch: BatchedDataDict, mode: str) -> Any:
303
+ batch, output = self.forward(batch)
304
+ # print(batch)
305
+ # print(output)
306
+ loss_dict = self.compute_loss(batch, output)
307
+
308
+ with torch.no_grad():
309
+ self.update_metrics(batch, output, mode=mode)
310
+
311
+ if mode == "train":
312
+ self.log("loss", loss_dict["loss"], prog_bar=True)
313
+
314
+ return output, loss_dict
315
+
316
+ def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]:
317
+
318
+ if self.augmentation is not None:
319
+ with torch.no_grad():
320
+ batch = self.augmentation(batch)
321
+
322
+ _, loss_dict = self.common_step(batch, mode="train")
323
+
324
+ with torch.inference_mode():
325
+ self.log_dict_with_prefix(
326
+ loss_dict, "train", batch_size=batch["audio"]["mixture"].shape[0]
327
+ )
328
+
329
+ loss_dict["loss"] *= self.loss_adjustment
330
+
331
+ return loss_dict
332
+
333
+ def on_train_batch_end(
334
+ self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int
335
+ ) -> None:
336
+
337
+ metric_dict = self.compute_metrics()
338
+ self.log_dict_with_prefix(metric_dict, "train")
339
+ self.reset_metrics()
340
+
341
+ def validation_step(
342
+ self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
343
+ ) -> Dict[str, Any]:
344
+
345
+ with torch.inference_mode():
346
+ curr_val_prefix = f"val{dataloader_idx}" if dataloader_idx > 0 else "val"
347
+
348
+ if curr_val_prefix != self.val_prefix:
349
+ # print(f"Switching to validation dataloader {dataloader_idx}")
350
+ if self.val_prefix is not None:
351
+ self._on_validation_epoch_end()
352
+ self.val_prefix = curr_val_prefix
353
+ _, loss_dict = self.common_step(batch, mode="val")
354
+
355
+ self.log_dict_with_prefix(
356
+ loss_dict,
357
+ self.val_prefix,
358
+ batch_size=batch["audio"]["mixture"].shape[0],
359
+ prog_bar=True,
360
+ add_dataloader_idx=False,
361
+ )
362
+
363
+ return loss_dict
364
+
365
+ def on_validation_epoch_end(self) -> None:
366
+ self._on_validation_epoch_end()
367
+
368
+ def _on_validation_epoch_end(self) -> None:
369
+ metric_dict = self.compute_metrics()
370
+ self.log_dict_with_prefix(
371
+ metric_dict, self.val_prefix, prog_bar=True, add_dataloader_idx=False
372
+ )
373
+ # self.logger.save()
374
+ # print(self.val_prefix, "Validation metrics:", metric_dict)
375
+ self.reset_metrics()
376
+
377
+ def old_predtest_step(
378
+ self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
379
+ ) -> Tuple[BatchedDataDict, OutputType]:
380
+
381
+ audio_batch = batch["audio"]["mixture"]
382
+ track_batch = batch.get("track", ["" for _ in range(len(audio_batch))])
383
+
384
+ output_list_of_dicts = [
385
+ self.fader(audio[None, ...], lambda a: self.test_forward(a, track))
386
+ for audio, track in zip(audio_batch, track_batch)
387
+ ]
388
+
389
+ output_dict_of_lists = defaultdict(list)
390
+
391
+ for output_dict in output_list_of_dicts:
392
+ for stem, audio in output_dict.items():
393
+ output_dict_of_lists[stem].append(audio)
394
+
395
+ output = {
396
+ "audio": {
397
+ stem: torch.concat(output_list, dim=0)
398
+ for stem, output_list in output_dict_of_lists.items()
399
+ }
400
+ }
401
+
402
+ return batch, output
403
+
404
+ def predtest_step(
405
+ self, batch: BatchedDataDict, batch_idx: int = -1, dataloader_idx: int = 0
406
+ ) -> Tuple[BatchedDataDict, OutputType]:
407
+
408
+ if getattr(self.model, "bypass_fader", False):
409
+ batch, output = self.model(batch)
410
+ else:
411
+ audio_batch = batch["audio"]["mixture"]
412
+ output = self.fader(
413
+ audio_batch, lambda a: self.test_forward(a, "", batch=batch)
414
+ )
415
+
416
+ return batch, output
417
+
418
+ def test_forward(
419
+ self, audio: torch.Tensor, track: str = "", batch: BatchedDataDict = None
420
+ ) -> torch.Tensor:
421
+
422
+ if self.fader is None:
423
+ self.attach_fader()
424
+
425
+ cond = batch.get("condition", None)
426
+
427
+ if cond is not None and cond.shape[0] == 1:
428
+ cond = cond.repeat(audio.shape[0], 1)
429
+
430
+ _, output = self.forward(
431
+ {
432
+ "audio": {"mixture": audio},
433
+ "track": track,
434
+ "condition": cond,
435
+ }
436
+ ) # TODO: support track properly
437
+
438
+ return output["audio"]
439
+
440
+ def on_test_epoch_start(self) -> None:
441
+ self.attach_fader(force_reattach=True)
442
+
443
+ def test_step(
444
+ self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
445
+ ) -> Any:
446
+ curr_test_prefix = f"test{dataloader_idx}"
447
+
448
+ # print(batch["audio"].keys())
449
+
450
+ if curr_test_prefix != self.test_prefix:
451
+ # print(f"Switching to test dataloader {dataloader_idx}")
452
+ if self.test_prefix is not None:
453
+ self._on_test_epoch_end()
454
+ self.test_prefix = curr_test_prefix
455
+
456
+ with torch.inference_mode():
457
+ _, output = self.predtest_step(batch, batch_idx, dataloader_idx)
458
+ # print(output)
459
+ self.update_metrics(batch, output, mode="test")
460
+
461
+ return output
462
+
463
+ def on_test_epoch_end(self) -> None:
464
+ self._on_test_epoch_end()
465
+
466
+ def _on_test_epoch_end(self) -> None:
467
+ metric_dict = self.compute_metrics(mode="test")
468
+ self.log_dict_with_prefix(
469
+ metric_dict, self.test_prefix, prog_bar=True, add_dataloader_idx=False
470
+ )
471
+ # self.logger.save()
472
+ # print(self.test_prefix, "Test metrics:", metric_dict)
473
+ self.reset_metrics()
474
+
475
+ def predict_step(
476
+ self,
477
+ batch: BatchedDataDict,
478
+ batch_idx: int = 0,
479
+ dataloader_idx: int = 0,
480
+ include_track_name: Optional[bool] = None,
481
+ get_no_vox_combinations: bool = True,
482
+ get_residual: bool = False,
483
+ treat_batch_as_channels: bool = False,
484
+ fs: Optional[int] = None,
485
+ ) -> Any:
486
+ assert self.predict_output_path is not None
487
+
488
+ batch_size = batch["audio"]["mixture"].shape[0]
489
+
490
+ if include_track_name is None:
491
+ include_track_name = batch_size > 1
492
+
493
+ with torch.inference_mode():
494
+ batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
495
+ print("Pred test finished...")
496
+ torch.cuda.empty_cache()
497
+ metric_dict = {}
498
+
499
+ if get_residual:
500
+ mixture = batch["audio"]["mixture"]
501
+ extracted = sum([output["audio"][stem] for stem in output["audio"]])
502
+ residual = mixture - extracted
503
+ print(extracted.shape, mixture.shape, residual.shape)
504
+
505
+ output["audio"]["residual"] = residual
506
+
507
+ if get_no_vox_combinations:
508
+ no_vox_stems = [
509
+ stem for stem in output["audio"] if stem not in self._VOX_STEMS
510
+ ]
511
+ no_vox_combinations = chain.from_iterable(
512
+ combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1)
513
+ )
514
+
515
+ for combination in no_vox_combinations:
516
+ combination_ = list(combination)
517
+ output["audio"]["+".join(combination_)] = sum(
518
+ [output["audio"][stem] for stem in combination_]
519
+ )
520
+
521
+ if treat_batch_as_channels:
522
+ for stem in output["audio"]:
523
+ output["audio"][stem] = output["audio"][stem].reshape(
524
+ 1, -1, output["audio"][stem].shape[-1]
525
+ )
526
+ batch_size = 1
527
+
528
+ for b in range(batch_size):
529
+ print("!!", b)
530
+ for stem in output["audio"]:
531
+ print(f"Saving audio for {stem} to {self.predict_output_path}")
532
+ track_name = batch["track"][b].split("/")[-1]
533
+
534
+ if batch.get("audio", {}).get(stem, None) is not None:
535
+ self.test_metrics[stem].reset()
536
+ metrics = self.test_metrics[stem](
537
+ batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...]
538
+ )
539
+ snr = metrics["snr"]
540
+ sisnr = metrics["sisnr"]
541
+ sdr = metrics["sdr"]
542
+ metric_dict[stem] = metrics
543
+ print(
544
+ track_name,
545
+ f"snr={snr:2.2f} dB",
546
+ f"sisnr={sisnr:2.2f}",
547
+ f"sdr={sdr:2.2f} dB",
548
+ )
549
+ filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
550
+ else:
551
+ filename = f"{stem}.wav"
552
+
553
+ if include_track_name:
554
+ output_dir = os.path.join(self.predict_output_path, track_name)
555
+ else:
556
+ output_dir = self.predict_output_path
557
+
558
+ os.makedirs(output_dir, exist_ok=True)
559
+
560
+ if fs is None:
561
+ fs = self.fs
562
+
563
+ ta.save(
564
+ os.path.join(output_dir, filename),
565
+ output["audio"][stem][b, ...].cpu(),
566
+ fs,
567
+ )
568
+
569
+ return metric_dict
570
+
571
+ def get_stems(
572
+ self,
573
+ batch: BatchedDataDict,
574
+ batch_idx: int = 0,
575
+ dataloader_idx: int = 0,
576
+ include_track_name: Optional[bool] = None,
577
+ get_no_vox_combinations: bool = True,
578
+ get_residual: bool = False,
579
+ treat_batch_as_channels: bool = False,
580
+ fs: Optional[int] = None,
581
+ ) -> Any:
582
+ assert self.predict_output_path is not None
583
+
584
+ batch_size = batch["audio"]["mixture"].shape[0]
585
+
586
+ if include_track_name is None:
587
+ include_track_name = batch_size > 1
588
+
589
+ with torch.inference_mode():
590
+ batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
591
+ torch.cuda.empty_cache()
592
+ metric_dict = {}
593
+
594
+ if get_residual:
595
+ mixture = batch["audio"]["mixture"]
596
+ extracted = sum([output["audio"][stem] for stem in output["audio"]])
597
+ residual = mixture - extracted
598
+ # print(extracted.shape, mixture.shape, residual.shape)
599
+
600
+ output["audio"]["residual"] = residual
601
+
602
+ if get_no_vox_combinations:
603
+ no_vox_stems = [
604
+ stem for stem in output["audio"] if stem not in self._VOX_STEMS
605
+ ]
606
+ no_vox_combinations = chain.from_iterable(
607
+ combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1)
608
+ )
609
+
610
+ for combination in no_vox_combinations:
611
+ combination_ = list(combination)
612
+ output["audio"]["+".join(combination_)] = sum(
613
+ [output["audio"][stem] for stem in combination_]
614
+ )
615
+
616
+ if treat_batch_as_channels:
617
+ for stem in output["audio"]:
618
+ output["audio"][stem] = output["audio"][stem].reshape(
619
+ 1, -1, output["audio"][stem].shape[-1]
620
+ )
621
+ batch_size = 1
622
+
623
+ result = {}
624
+ for b in range(batch_size):
625
+ for stem in output["audio"]:
626
+ track_name = batch["track"][b].split("/")[-1]
627
+
628
+ if batch.get("audio", {}).get(stem, None) is not None:
629
+ self.test_metrics[stem].reset()
630
+ metrics = self.test_metrics[stem](
631
+ batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...]
632
+ )
633
+ snr = metrics["snr"]
634
+ sisnr = metrics["sisnr"]
635
+ sdr = metrics["sdr"]
636
+ metric_dict[stem] = metrics
637
+ print(
638
+ track_name,
639
+ f"snr={snr:2.2f} dB",
640
+ f"sisnr={sisnr:2.2f}",
641
+ f"sdr={sdr:2.2f} dB",
642
+ )
643
+ filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
644
+ else:
645
+ filename = f"{stem}.wav"
646
+
647
+ if include_track_name:
648
+ output_dir = os.path.join(self.predict_output_path, track_name)
649
+ else:
650
+ output_dir = self.predict_output_path
651
+
652
+ os.makedirs(output_dir, exist_ok=True)
653
+
654
+ if fs is None:
655
+ fs = self.fs
656
+
657
+ result[stem] = output["audio"][stem][b, ...].cpu().numpy()
658
+
659
+ return result
660
+
661
+ def load_state_dict(
662
+ self, state_dict: Mapping[str, Any], strict: bool = False
663
+ ) -> Any:
664
+
665
+ return super().load_state_dict(state_dict, strict=False)
666
+
667
+ def set_predict_output_path(self, path: str) -> None:
668
+ self.predict_output_path = path
669
+ os.makedirs(self.predict_output_path, exist_ok=True)
670
+
671
+ self.attach_fader()
672
+
673
+ def attach_fader(self, force_reattach=False) -> None:
674
+ if self.fader is None or force_reattach:
675
+ self.fader = parse_fader_config(self.fader_config)
676
+ self.fader.to(self.device)
677
+
678
+ def log_dict_with_prefix(
679
+ self,
680
+ dict_: Dict[str, torch.Tensor],
681
+ prefix: str,
682
+ batch_size: Optional[int] = None,
683
+ **kwargs: Any,
684
+ ) -> None:
685
+ self.log_dict(
686
+ {f"{prefix}/{k}": v for k, v in dict_.items()},
687
+ batch_size=batch_size,
688
+ logger=True,
689
+ sync_dist=True,
690
+ **kwargs,
691
+ )
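As a usage sketch of the config parsing above: a minimal optimizer/scheduler config for `parse_optimizer_config` could look like the following. The names resolve against `torch.optim` and `torch.optim.lr_scheduler`; the hyperparameter values and the `val/loss` monitor key are illustrative placeholders, not values taken from this repo's configs.

```python
from torch import nn

# Hypothetical config following OptimizerSchedulerConfigDict above.
config = {
    "optimizer": {"name": "AdamW", "kwargs": {"lr": 1e-3, "weight_decay": 1e-2}},
    "scheduler": {
        "name": "ReduceLROnPlateau",
        "kwargs": {"factor": 0.5, "patience": 2},
        "monitor": "val/loss",  # only consumed for ReduceLROnPlateau
    },
}

module = nn.Linear(4, 4)  # stand-in for the actual separation model
optim_dict = parse_optimizer_config(config, module.parameters())
# optim_dict["optimizer"] is a torch.optim.AdamW instance;
# optim_dict["lr_scheduler"]["monitor"] == "val/loss"
```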
mvsepless/models/bandit/core/data/__init__.py ADDED
@@ -0,0 +1,2 @@
from .dnr.datamodule import DivideAndRemasterDataModule
from .musdb.datamodule import MUSDB18DataModule
mvsepless/models/bandit/core/data/_types.py ADDED
@@ -0,0 +1,17 @@
from typing import Dict, Sequence, TypedDict

import torch

AudioDict = Dict[str, torch.Tensor]

DataDict = TypedDict("DataDict", {"audio": AudioDict, "track": str})

BatchedDataDict = TypedDict(
    "BatchedDataDict", {"audio": AudioDict, "track": Sequence[str]}
)


class DataDictWithLanguage(TypedDict):
    audio: AudioDict
    track: str
    language: str
mvsepless/models/bandit/core/data/augmentation.py ADDED
@@ -0,0 +1,102 @@
from abc import ABC
from typing import Any, Dict, Union

import torch
import torch_audiomentations as tam
from torch import nn

from ._types import BatchedDataDict, DataDict


class BaseAugmentor(nn.Module, ABC):
    def forward(
        self, item: Union[DataDict, BatchedDataDict]
    ) -> Union[DataDict, BatchedDataDict]:
        raise NotImplementedError


class StemAugmentor(BaseAugmentor):
    def __init__(
        self,
        audiomentations: Dict[str, Dict[str, Any]],
        fix_clipping: bool = True,
        scaler_margin: float = 0.5,
        apply_both_default_and_common: bool = False,
    ) -> None:
        super().__init__()

        augmentations = {}

        self.has_default = "[default]" in audiomentations
        self.has_common = "[common]" in audiomentations
        self.apply_both_default_and_common = apply_both_default_and_common

        for stem in audiomentations:
            if audiomentations[stem]["name"] == "Compose":
                augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
                    [
                        getattr(tam, aug["name"])(**aug["kwargs"])
                        for aug in audiomentations[stem]["kwargs"]["transforms"]
                    ],
                    **audiomentations[stem]["kwargs"]["kwargs"],
                )
            else:
                augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
                    **audiomentations[stem]["kwargs"]
                )

        self.augmentations = nn.ModuleDict(augmentations)
        self.fix_clipping = fix_clipping
        self.scaler_margin = scaler_margin

    def check_and_fix_clipping(
        self, item: Union[DataDict, BatchedDataDict]
    ) -> Union[DataDict, BatchedDataDict]:
        max_abs = []

        for stem in item["audio"]:
            max_abs.append(item["audio"][stem].abs().max().item())

        if max(max_abs) > 1.0:
            scaler = 1.0 / (
                max(max_abs)
                + torch.rand((1,), device=item["audio"]["mixture"].device)
                * self.scaler_margin
            )

            for stem in item["audio"]:
                item["audio"][stem] *= scaler

        return item

    def forward(
        self, item: Union[DataDict, BatchedDataDict]
    ) -> Union[DataDict, BatchedDataDict]:

        for stem in item["audio"]:
            if stem == "mixture":
                continue

            if self.has_common:
                item["audio"][stem] = self.augmentations["[common]"](
                    item["audio"][stem]
                ).samples

            if stem in self.augmentations:
                item["audio"][stem] = self.augmentations[stem](
                    item["audio"][stem]
                ).samples
            elif self.has_default:
                if not self.has_common or self.apply_both_default_and_common:
                    item["audio"][stem] = self.augmentations["[default]"](
                        item["audio"][stem]
                    ).samples

        item["audio"]["mixture"] = sum(
            [item["audio"][stem] for stem in item["audio"] if stem != "mixture"]
        )  # type: ignore[call-overload, assignment]

        if self.fix_clipping:
            item = self.check_and_fix_clipping(item)

        return item
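For reference, a config for `StemAugmentor` might be shaped as below. The transform names are resolved against `torch_audiomentations` (`Gain`, `PolarityInversion`, and `Compose` do exist there, but the values here are illustrative placeholders). Note that the `.samples` access in `forward` assumes the transforms return an object dict, which on recent `torch_audiomentations` versions may require passing `output_type="object_dict"` to each transform.

```python
# Sketch of the expected `audiomentations` mapping; "[default]" applies to any
# stem without its own entry, "[common]" (not shown) to every stem first.
audiomentations = {
    "[default]": {
        "name": "Gain",
        "kwargs": {"min_gain_in_db": -6.0, "max_gain_in_db": 6.0, "p": 0.5},
    },
    "vocals": {
        "name": "Compose",
        "kwargs": {
            "transforms": [
                {
                    "name": "Gain",
                    "kwargs": {"min_gain_in_db": -3.0, "max_gain_in_db": 3.0, "p": 0.5},
                },
                {"name": "PolarityInversion", "kwargs": {"p": 0.5}},
            ],
            "kwargs": {"shuffle": False},  # forwarded to Compose itself
        },
    },
}

augmentor = StemAugmentor(audiomentations=audiomentations)
```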
mvsepless/models/bandit/core/data/augmented.py ADDED
@@ -0,0 +1,34 @@
import warnings
from typing import Dict, Optional, Union

import torch
from torch import nn
from torch.utils import data


class AugmentedDataset(data.Dataset):
    def __init__(
        self,
        dataset: data.Dataset,
        augmentation: nn.Module = nn.Identity(),
        target_length: Optional[int] = None,
    ) -> None:
        warnings.warn(
            "This class is no longer used. Attach augmentation to "
            "the LightningSystem instead.",
            DeprecationWarning,
        )

        self.dataset = dataset
        self.augmentation = augmentation

        self.ds_length: int = len(dataset)  # type: ignore[arg-type]
        self.length = target_length if target_length is not None else self.ds_length

    def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str, torch.Tensor]]]:
        item = self.dataset[index % self.ds_length]
        item = self.augmentation(item)
        return item

    def __len__(self) -> int:
        return self.length
mvsepless/models/bandit/core/data/base.py ADDED
@@ -0,0 +1,60 @@
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

import numpy as np
import pedalboard as pb
import torch
import torchaudio as ta
from torch.utils import data

from ._types import AudioDict, DataDict


class BaseSourceSeparationDataset(data.Dataset, ABC):
    def __init__(
        self,
        split: str,
        stems: List[str],
        files: List[str],
        data_path: str,
        fs: int,
        npy_memmap: bool,
        recompute_mixture: bool,
    ):
        self.split = split
        self.stems = stems
        self.stems_no_mixture = [s for s in stems if s != "mixture"]
        self.files = files
        self.data_path = data_path
        self.fs = fs
        self.npy_memmap = npy_memmap
        self.recompute_mixture = recompute_mixture

    @abstractmethod
    def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
        raise NotImplementedError

    def _get_audio(self, stems, identifier: Dict[str, Any]):
        audio = {}
        for stem in stems:
            audio[stem] = self.get_stem(stem=stem, identifier=identifier)

        return audio

    def get_audio(self, identifier: Dict[str, Any]) -> AudioDict:

        if self.recompute_mixture:
            audio = self._get_audio(self.stems_no_mixture, identifier=identifier)
            audio["mixture"] = self.compute_mixture(audio)
            return audio
        else:
            return self._get_audio(self.stems, identifier=identifier)

    @abstractmethod
    def get_identifier(self, index: int) -> Dict[str, Any]:
        pass

    def compute_mixture(self, audio: AudioDict) -> torch.Tensor:

        return sum(audio[stem] for stem in audio if stem != "mixture")
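A minimal (hypothetical) subclass sketch, to show the two abstract hooks in use: a flat layout where each track directory contains one `<stem>.wav` per stem. The class name and layout are illustrative, not part of this repo.

```python
import os
from typing import Any, Dict

import torch
import torchaudio as ta


class FlatFolderDataset(BaseSourceSeparationDataset):
    """Each entry in `files` is a directory holding `<stem>.wav` files."""

    def get_identifier(self, index: int) -> Dict[str, Any]:
        return {"track": self.files[index]}

    def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
        path = os.path.join(self.data_path, identifier["track"], f"{stem}.wav")
        audio, _ = ta.load(path)
        return audio

    def __len__(self) -> int:
        return len(self.files)
```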
mvsepless/models/bandit/core/data/dnr/__init__.py ADDED
File without changes
mvsepless/models/bandit/core/data/dnr/datamodule.py ADDED
@@ -0,0 +1,68 @@
import os
from typing import Mapping, Optional

import pytorch_lightning as pl

from .dataset import (
    DivideAndRemasterDataset,
    DivideAndRemasterDeterministicChunkDataset,
    DivideAndRemasterRandomChunkDataset,
    DivideAndRemasterRandomChunkDatasetWithSpeechReverb,
)


def DivideAndRemasterDataModule(
    data_root: str = "$DATA_ROOT/DnR/v2",
    batch_size: int = 2,
    num_workers: int = 8,
    train_kwargs: Optional[Mapping] = None,
    val_kwargs: Optional[Mapping] = None,
    test_kwargs: Optional[Mapping] = None,
    datamodule_kwargs: Optional[Mapping] = None,
    use_speech_reverb: bool = False,
    # augmentor=None
) -> pl.LightningDataModule:
    if train_kwargs is None:
        train_kwargs = {}

    if val_kwargs is None:
        val_kwargs = {}

    if test_kwargs is None:
        test_kwargs = {}

    if datamodule_kwargs is None:
        datamodule_kwargs = {}

    if num_workers is None:
        num_workers = os.cpu_count()

    # os.cpu_count() itself may return None; fall back to a fixed default.
    if num_workers is None:
        num_workers = 32

    num_workers = min(num_workers, 64)

    if use_speech_reverb:
        train_cls = DivideAndRemasterRandomChunkDatasetWithSpeechReverb
    else:
        train_cls = DivideAndRemasterRandomChunkDataset

    train_dataset = train_cls(data_root, "train", **train_kwargs)

    # if augmentor is not None:
    #     train_dataset = AugmentedDataset(train_dataset, augmentor)

    datamodule = pl.LightningDataModule.from_datasets(
        train_dataset=train_dataset,
        val_dataset=DivideAndRemasterDeterministicChunkDataset(
            data_root, "val", **val_kwargs
        ),
        test_dataset=DivideAndRemasterDataset(data_root, "test", **test_kwargs),
        batch_size=batch_size,
        num_workers=num_workers,
        **datamodule_kwargs
    )

    datamodule.predict_dataloader = datamodule.test_dataloader  # type: ignore[method-assign]

    return datamodule
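Illustrative usage (paths and chunking values are placeholders): `train_kwargs`/`val_kwargs` are forwarded to the chunked dataset constructors above, so they must carry the chunking arguments those classes require.

```python
datamodule = DivideAndRemasterDataModule(
    data_root="/data/DnR/v2np",  # placeholder; expects tr/cv/tt subfolders
    batch_size=4,
    num_workers=8,
    train_kwargs={"target_length": 20000, "chunk_size_second": 6.0},
    val_kwargs={"chunk_size_second": 6.0, "hop_size_second": 3.0},
)
```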
mvsepless/models/bandit/core/data/dnr/dataset.py ADDED
@@ -0,0 +1,366 @@
import os
from abc import ABC
from typing import Any, Dict, List, Optional

import numpy as np
import pedalboard as pb
import torch
import torchaudio as ta
from torch.utils import data

from .._types import AudioDict, DataDict
from ..base import BaseSourceSeparationDataset


class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC):
    ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"]
    STEM_NAME_MAP = {
        "mixture": "mix",
        "speech": "speech",
        "music": "music",
        "effects": "sfx",
    }
    SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"}

    FULL_TRACK_LENGTH_SECOND = 60
    FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100

    def __init__(
        self,
        split: str,
        stems: List[str],
        files: List[str],
        data_path: str,
        fs: int = 44100,
        npy_memmap: bool = True,
        recompute_mixture: bool = False,
    ) -> None:
        super().__init__(
            split=split,
            stems=stems,
            files=files,
            data_path=data_path,
            fs=fs,
            npy_memmap=npy_memmap,
            recompute_mixture=recompute_mixture,
        )

    def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:

        # "mne" (music and effects) is a derived stem: music + effects.
        if stem == "mne":
            return self.get_stem(stem="music", identifier=identifier) + self.get_stem(
                stem="effects", identifier=identifier
            )

        track = identifier["track"]
        path = os.path.join(self.data_path, track)

        if self.npy_memmap:
            audio = np.load(
                os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"), mmap_mode="r"
            )
        else:
            # noinspection PyUnresolvedReferences
            audio, _ = ta.load(os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav"))

        return audio

    def get_identifier(self, index):
        return dict(track=self.files[index])

    def __getitem__(self, index: int) -> DataDict:
        identifier = self.get_identifier(index)
        audio = self.get_audio(identifier)

        return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}


class DivideAndRemasterDataset(DivideAndRemasterBaseDataset):
    def __init__(
        self,
        data_root: str,
        split: str,
        stems: Optional[List[str]] = None,
        fs: int = 44100,
        npy_memmap: bool = True,
    ) -> None:

        if stems is None:
            stems = self.ALLOWED_STEMS
        self.stems = stems

        data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])

        files = sorted(os.listdir(data_path))
        files = [
            f
            for f in files
            if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
        ]
        # pprint(list(enumerate(files)))
        if split == "train":
            assert len(files) == 3406, len(files)
        elif split == "val":
            assert len(files) == 487, len(files)
        elif split == "test":
            assert len(files) == 973, len(files)

        self.n_tracks = len(files)

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=files,
            fs=fs,
            npy_memmap=npy_memmap,
        )

    def __len__(self) -> int:
        return self.n_tracks


class DivideAndRemasterRandomChunkDataset(DivideAndRemasterBaseDataset):
    def __init__(
        self,
        data_root: str,
        split: str,
        target_length: int,
        chunk_size_second: float,
        stems: Optional[List[str]] = None,
        fs: int = 44100,
        npy_memmap: bool = True,
    ) -> None:

        if stems is None:
            stems = self.ALLOWED_STEMS
        self.stems = stems

        data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])

        files = sorted(os.listdir(data_path))
        files = [
            f
            for f in files
            if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
        ]

        if split == "train":
            assert len(files) == 3406, len(files)
        elif split == "val":
            assert len(files) == 487, len(files)
        elif split == "test":
            assert len(files) == 973, len(files)

        self.n_tracks = len(files)

        self.target_length = target_length
        self.chunk_size = int(chunk_size_second * fs)

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=files,
            fs=fs,
            npy_memmap=npy_memmap,
        )

    def __len__(self) -> int:
        return self.target_length

    def get_identifier(self, index):
        return super().get_identifier(index % self.n_tracks)

    def get_stem(
        self,
        *,
        stem: str,
        identifier: Dict[str, Any],
        chunk_here: bool = False,
    ) -> torch.Tensor:

        stem = super().get_stem(stem=stem, identifier=identifier)

        if chunk_here:
            start = np.random.randint(
                0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
            )
            end = start + self.chunk_size

            stem = stem[:, start:end]

        return stem

    def __getitem__(self, index: int) -> DataDict:
        identifier = self.get_identifier(index)
        # self.index_lock = index
        audio = self.get_audio(identifier)
        # self.index_lock = None

        start = np.random.randint(0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size)
        end = start + self.chunk_size

        audio = {k: v[:, start:end] for k, v in audio.items()}

        return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}


class DivideAndRemasterDeterministicChunkDataset(DivideAndRemasterBaseDataset):
    def __init__(
        self,
        data_root: str,
        split: str,
        chunk_size_second: float,
        hop_size_second: float,
        stems: Optional[List[str]] = None,
        fs: int = 44100,
        npy_memmap: bool = True,
    ) -> None:

        if stems is None:
            stems = self.ALLOWED_STEMS
        self.stems = stems

        data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])

        files = sorted(os.listdir(data_path))
        files = [
            f
            for f in files
            if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
        ]
        # pprint(list(enumerate(files)))
        if split == "train":
            assert len(files) == 3406, len(files)
        elif split == "val":
            assert len(files) == 487, len(files)
        elif split == "test":
            assert len(files) == 973, len(files)

        self.n_tracks = len(files)

        self.chunk_size = int(chunk_size_second * fs)
        self.hop_size = int(hop_size_second * fs)
        self.n_chunks_per_track = int(
            (self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second
        )

        self.length = self.n_tracks * self.n_chunks_per_track

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=files,
            fs=fs,
            npy_memmap=npy_memmap,
        )

    def get_identifier(self, index):
        return super().get_identifier(index % self.n_tracks)

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, item: int) -> DataDict:

        index = item % self.n_tracks
        chunk = item // self.n_tracks

        data_ = super().__getitem__(index)

        audio = data_["audio"]

        start = chunk * self.hop_size
        end = start + self.chunk_size

        for stem in self.stems:
            data_["audio"][stem] = audio[stem][:, start:end]

        return data_


class DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
    DivideAndRemasterRandomChunkDataset
):
    def __init__(
        self,
        data_root: str,
        split: str,
        target_length: int,
        chunk_size_second: float,
        stems: Optional[List[str]] = None,
        fs: int = 44100,
        npy_memmap: bool = True,
    ) -> None:

        if stems is None:
            stems = self.ALLOWED_STEMS

        stems_no_mixture = [s for s in stems if s != "mixture"]

        super().__init__(
            data_root=data_root,
            split=split,
            target_length=target_length,
            chunk_size_second=chunk_size_second,
            stems=stems_no_mixture,
            fs=fs,
            npy_memmap=npy_memmap,
        )

        self.stems = stems
        self.stems_no_mixture = stems_no_mixture

    def __getitem__(self, index: int) -> DataDict:

        data_ = super().__getitem__(index)

        dry = data_["audio"]["speech"][:]
        n_samples = dry.shape[-1]

        wet_level = np.random.rand()

        # Apply a randomly parameterized pedalboard reverb to the speech stem,
        # then recompute the mixture from the (possibly reverberated) stems.
        speech = pb.Reverb(
            room_size=np.random.rand(),
            damping=np.random.rand(),
            wet_level=wet_level,
            dry_level=(1 - wet_level),
            width=np.random.rand(),
        ).process(dry, self.fs, buffer_size=8192 * 4)[..., :n_samples]

        data_["audio"]["speech"] = speech

        data_["audio"]["mixture"] = sum(
            [data_["audio"][s] for s in self.stems_no_mixture]
        )

        return data_

    def __len__(self) -> int:
        return super().__len__()


if __name__ == "__main__":

    from pprint import pprint
    from tqdm.auto import tqdm

    for split_ in ["train", "val", "test"]:
        ds = DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
            data_root="$DATA_ROOT/DnR/v2np",
            split=split_,
            target_length=100,
            chunk_size_second=6.0,
        )

        print(split_, len(ds))

        for track_ in tqdm(ds):  # type: ignore
            pprint(track_)
            track_["audio"] = {k: v.shape for k, v in track_["audio"].items()}
            pprint(track_)
            # break

        break
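As a sanity check on the deterministic chunking above: DnR tracks are 60 s, so with, e.g., 6 s chunks and a 3 s hop, `n_chunks_per_track = int((60 - 6) / 3) = 18`, and the 487-track validation split yields 487 × 18 = 8766 chunks per epoch.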
mvsepless/models/bandit/core/data/dnr/preprocess.py ADDED
@@ -0,0 +1,51 @@
import glob
import os
from typing import Tuple

import numpy as np
import torchaudio as ta
from tqdm.contrib.concurrent import process_map


def process_one(inputs: Tuple[str, str, int]) -> None:
    infile, outfile, target_fs = inputs

    out_dir = os.path.dirname(outfile)  # renamed from `dir`, which shadows the builtin
    os.makedirs(out_dir, exist_ok=True)

    data, fs = ta.load(infile)

    if fs != target_fs:
        data = ta.functional.resample(
            data, fs, target_fs, resampling_method="sinc_interp_kaiser"
        )
        fs = target_fs

    data = data.numpy()
    data = data.astype(np.float32)

    # Skip rewriting if an identical conversion already exists.
    if os.path.exists(outfile):
        data_ = np.load(outfile)
        if np.allclose(data, data_):
            return

    np.save(outfile, data)


def preprocess(data_path: str, output_path: str, fs: int) -> None:
    files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
    print(files)
    outfiles = [
        f.replace(data_path, output_path).replace(".wav", ".npy") for f in files
    ]

    os.makedirs(output_path, exist_ok=True)
    inputs = list(zip(files, outfiles, [fs] * len(files)))

    process_map(process_one, inputs, chunksize=32)


if __name__ == "__main__":
    import fire

    fire.Fire()
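Since the module exposes its functions through `fire.Fire()`, the resampler can also be called directly from Python. The paths below are placeholders, and the import path assumes this repo's package layout.

```python
from mvsepless.models.bandit.core.data.dnr.preprocess import preprocess

# Resample every .wav under data_path into float32 .npy mirrors at 44.1 kHz.
preprocess(
    data_path="/data/DnR/v2",
    output_path="/data/DnR/v2np",
    fs=44100,
)
```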
mvsepless/models/bandit/core/data/musdb/__init__.py ADDED
File without changes
mvsepless/models/bandit/core/data/musdb/datamodule.py ADDED
@@ -0,0 +1,75 @@
import os.path
from typing import Mapping, Optional

import pytorch_lightning as pl

from .dataset import (
    MUSDB18BaseDataset,
    MUSDB18FullTrackDataset,
    MUSDB18SadDataset,
    MUSDB18SadOnTheFlyAugmentedDataset,
)


def MUSDB18DataModule(
    data_root: str = "$DATA_ROOT/MUSDB18/HQ",
    target_stem: str = "vocals",
    batch_size: int = 2,
    num_workers: int = 8,
    train_kwargs: Optional[Mapping] = None,
    val_kwargs: Optional[Mapping] = None,
    test_kwargs: Optional[Mapping] = None,
    datamodule_kwargs: Optional[Mapping] = None,
    use_on_the_fly: bool = True,
    npy_memmap: bool = True,
) -> pl.LightningDataModule:
    if train_kwargs is None:
        train_kwargs = {}

    if val_kwargs is None:
        val_kwargs = {}

    if test_kwargs is None:
        test_kwargs = {}

    if datamodule_kwargs is None:
        datamodule_kwargs = {}

    train_dataset: MUSDB18BaseDataset

    if use_on_the_fly:
        # Note: the on-the-fly augmented dataset does not expose `npy_memmap`.
        train_dataset = MUSDB18SadOnTheFlyAugmentedDataset(
            data_root=os.path.join(data_root, "saded-np"),
            split="train",
            target_stem=target_stem,
            **train_kwargs
        )
    else:
        train_dataset = MUSDB18SadDataset(
            data_root=os.path.join(data_root, "saded-np"),
            split="train",
            target_stem=target_stem,
            # `npy_memmap` was previously accepted but never forwarded;
            # pass it through where the dataset supports it.
            npy_memmap=npy_memmap,
            **train_kwargs
        )

    datamodule = pl.LightningDataModule.from_datasets(
        train_dataset=train_dataset,
        val_dataset=MUSDB18SadDataset(
            data_root=os.path.join(data_root, "saded-np"),
            split="val",
            target_stem=target_stem,
            npy_memmap=npy_memmap,
            **val_kwargs
        ),
        test_dataset=MUSDB18FullTrackDataset(
            data_root=os.path.join(data_root, "canonical"), split="test", **test_kwargs
        ),
        batch_size=batch_size,
        num_workers=num_workers,
        **datamodule_kwargs
    )

    datamodule.predict_dataloader = (  # type: ignore[method-assign]
        datamodule.test_dataloader
    )

    return datamodule
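Illustrative usage (placeholder paths): the datamodule expects the preprocessed `saded-np` layout produced by `musdb/preprocess.py` under `data_root`, plus the `canonical` full tracks for testing.

```python
datamodule = MUSDB18DataModule(
    data_root="/data/MUSDB18/HQ",
    target_stem="vocals",
    batch_size=4,
    num_workers=8,
)
```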
mvsepless/models/bandit/core/data/musdb/dataset.py ADDED
@@ -0,0 +1,273 @@
import os
from abc import ABC
from typing import List, Optional, Tuple

import numpy as np
import torch
import torchaudio as ta
from torch.utils import data

from .._types import AudioDict, DataDict
from ..base import BaseSourceSeparationDataset


class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC):

    ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"]

    def __init__(
        self,
        split: str,
        stems: List[str],
        files: List[str],
        data_path: str,
        fs: int = 44100,
        npy_memmap=False,
    ) -> None:
        super().__init__(
            split=split,
            stems=stems,
            files=files,
            data_path=data_path,
            fs=fs,
            npy_memmap=npy_memmap,
            recompute_mixture=False,
        )

    def get_stem(self, *, stem: str, identifier) -> torch.Tensor:
        track = identifier["track"]
        path = os.path.join(self.data_path, track)
        # noinspection PyUnresolvedReferences

        if self.npy_memmap:
            audio = np.load(os.path.join(path, f"{stem}.wav.npy"), mmap_mode="r")
        else:
            audio, _ = ta.load(os.path.join(path, f"{stem}.wav"))

        return audio

    def get_identifier(self, index):
        return dict(track=self.files[index])

    def __getitem__(self, index: int) -> DataDict:
        identifier = self.get_identifier(index)
        audio = self.get_audio(identifier)

        return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}


class MUSDB18FullTrackDataset(MUSDB18BaseDataset):

    N_TRAIN_TRACKS = 100
    N_TEST_TRACKS = 50
    VALIDATION_FILES = [
        "Actions - One Minute Smile",
        "Clara Berry And Wooldog - Waltz For My Victims",
        "Johnny Lokke - Promises & Lies",
        "Patrick Talbot - A Reason To Leave",
        "Triviul - Angelsaint",
        "Alexander Ross - Goodbye Bolero",
        "Fergessen - Nos Palpitants",
        "Leaf - Summerghost",
        "Skelpolu - Human Mistakes",
        "Young Griffo - Pennies",
        "ANiMAL - Rockshow",
        "James May - On The Line",
        "Meaxic - Take A Step",
        "Traffic Experiment - Sirens",
    ]

    def __init__(
        self, data_root: str, split: str, stems: Optional[List[str]] = None
    ) -> None:

        if stems is None:
            stems = self.ALLOWED_STEMS
        self.stems = stems

        if split == "test":
            subset = "test"
        elif split in ["train", "val"]:
            subset = "train"
        else:
            raise NameError(f"Unknown split: {split}")

        data_path = os.path.join(data_root, subset)

        files = sorted(os.listdir(data_path))
        files = [f for f in files if not f.startswith(".")]
        # pprint(list(enumerate(files)))
        if subset == "train":
            assert len(files) == 100, len(files)
            if split == "train":
                files = [f for f in files if f not in self.VALIDATION_FILES]
                assert len(files) == 100 - len(self.VALIDATION_FILES)
            else:
                files = [f for f in files if f in self.VALIDATION_FILES]
                assert len(files) == len(self.VALIDATION_FILES)
        else:
            split = "test"
            assert len(files) == 50

        self.n_tracks = len(files)

        super().__init__(data_path=data_path, split=split, stems=stems, files=files)

    def __len__(self) -> int:
        return self.n_tracks


class MUSDB18SadDataset(MUSDB18BaseDataset):
    def __init__(
        self,
        data_root: str,
        split: str,
        target_stem: str,
        stems: Optional[List[str]] = None,
        target_length: Optional[int] = None,
        npy_memmap=False,
    ) -> None:

        if stems is None:
            stems = self.ALLOWED_STEMS

        data_path = os.path.join(data_root, target_stem, split)

        files = sorted(os.listdir(data_path))
        files = [f for f in files if not f.startswith(".")]

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=files,
            npy_memmap=npy_memmap,
        )
        self.n_segments = len(files)
        self.target_stem = target_stem
        self.target_length = (
            target_length if target_length is not None else self.n_segments
        )

    def __len__(self) -> int:
        return self.target_length

    def __getitem__(self, index: int) -> DataDict:

        index = index % self.n_segments

        return super().__getitem__(index)

    def get_identifier(self, index):
        return super().get_identifier(index % self.n_segments)


class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset):
    def __init__(
        self,
        data_root: str,
        split: str,
        target_stem: str,
        stems: Optional[List[str]] = None,
        target_length: int = 20000,
        apply_probability: Optional[float] = None,
        chunk_size_second: float = 3.0,
        random_scale_range_db: Tuple[float, float] = (-10, 10),
        drop_probability: float = 0.1,
        rescale: bool = True,
    ) -> None:
        super().__init__(data_root, split, target_stem, stems)

        if apply_probability is None:
            apply_probability = (target_length - self.n_segments) / target_length

        self.apply_probability = apply_probability
        self.drop_probability = drop_probability
        self.chunk_size_second = chunk_size_second
        self.random_scale_range_db = random_scale_range_db
        self.rescale = rescale

        self.chunk_size_sample = int(self.chunk_size_second * self.fs)
        self.target_length = target_length

    def __len__(self) -> int:
        return self.target_length

    def __getitem__(self, index: int) -> DataDict:

        index = index % self.n_segments

        # if np.random.rand() > self.apply_probability:
        #     return super().__getitem__(index)

        audio = {}
        identifier = self.get_identifier(index)

        # assert self.target_stem in self.stems_no_mixture
        for stem in self.stems_no_mixture:
            if stem == self.target_stem:
                identifier_ = identifier
            else:
                # With some probability, swap this stem in from a random segment.
                if np.random.rand() < self.apply_probability:
                    index_ = np.random.randint(self.n_segments)
                    identifier_ = self.get_identifier(index_)
                else:
                    identifier_ = identifier

            audio[stem] = self.get_stem(stem=stem, identifier=identifier_)

            # if stem == self.target_stem:

            if self.chunk_size_sample < audio[stem].shape[-1]:
                chunk_start = np.random.randint(
                    audio[stem].shape[-1] - self.chunk_size_sample
                )
            else:
                chunk_start = 0

            if np.random.rand() < self.drop_probability:
                # db_scale = "-inf"
                linear_scale = 0.0
            else:
                db_scale = np.random.uniform(*self.random_scale_range_db)
                linear_scale = np.power(10, db_scale / 20)
                # db_scale = f"{db_scale:+2.1f}"
            # print(linear_scale)
            audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample] = (
                linear_scale
                * audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample]
            )

        audio["mixture"] = self.compute_mixture(audio)

        if self.rescale:
            max_abs_val = max(
                [torch.max(torch.abs(audio[stem])) for stem in self.stems]
            )  # type: ignore[type-var]
            if max_abs_val > 1:
                audio = {k: v / max_abs_val for k, v in audio.items()}

        track = identifier["track"]

        return {"audio": audio, "track": f"{self.split}/{track}"}


# if __name__ == "__main__":
#
#     from pprint import pprint
#     from tqdm.auto import tqdm
#
#     for split_ in ["train", "val", "test"]:
#         ds = MUSDB18SadOnTheFlyAugmentedDataset(
#             data_root="$DATA_ROOT/MUSDB18/HQ/saded",
#             split=split_,
#             target_stem="vocals"
#         )
#
#         print(split_, len(ds))
#
#         for track_ in tqdm(ds):
#             track_["audio"] = {
#                 k: v.shape for k, v in track_["audio"].items()
#             }
#             pprint(track_)
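Two small worked numbers for the on-the-fly augmentation above: the per-stem scale is drawn in dB and mapped through `linear = 10**(dB / 20)`, so +6 dB is roughly ×2.0 and −10 dB roughly ×0.316; and with the default `target_length = 20000` and, say, 5000 extracted segments, the default `apply_probability = (20000 - 5000) / 20000 = 0.75`.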
mvsepless/models/bandit/core/data/musdb/preprocess.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torchaudio as ta
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from tqdm.contrib.concurrent import process_map
10
+
11
+ from .._types import DataDict
12
+ from .dataset import MUSDB18FullTrackDataset
13
+ import pyloudnorm as pyln
14
+
15
+
16
+ class SourceActivityDetector(nn.Module):
17
+ def __init__(
18
+ self,
19
+ analysis_stem: str,
20
+ output_path: str,
21
+ fs: int = 44100,
22
+ segment_length_second: float = 6.0,
23
+ hop_length_second: float = 3.0,
24
+ n_chunks: int = 10,
25
+ chunk_epsilon: float = 1e-5,
26
+ energy_threshold_quantile: float = 0.15,
27
+ segment_epsilon: float = 1e-3,
28
+ salient_proportion_threshold: float = 0.5,
29
+ target_lufs: float = -24,
30
+ ) -> None:
31
+ super().__init__()
32
+
33
+ self.fs = fs
34
+ self.segment_length = int(segment_length_second * self.fs)
35
+ self.hop_length = int(hop_length_second * self.fs)
36
+ self.n_chunks = n_chunks
37
+ assert self.segment_length % self.n_chunks == 0
38
+ self.chunk_size = self.segment_length // self.n_chunks
39
+ self.chunk_epsilon = chunk_epsilon
40
+ self.energy_threshold_quantile = energy_threshold_quantile
41
+ self.segment_epsilon = segment_epsilon
42
+ self.salient_proportion_threshold = salient_proportion_threshold
43
+ self.analysis_stem = analysis_stem
44
+
45
+ self.meter = pyln.Meter(self.fs)
46
+ self.target_lufs = target_lufs
47
+
48
+ self.output_path = output_path
49
+
50
+ def forward(self, data: DataDict) -> None:
51
+
52
+ stem_ = self.analysis_stem if (self.analysis_stem != "none") else "mixture"
53
+
54
+ x = data["audio"][stem_]
55
+
56
+ xnp = x.numpy()
57
+ loudness = self.meter.integrated_loudness(xnp.T)
58
+
59
+ for stem in data["audio"]:
60
+ s = data["audio"][stem]
61
+ s = pyln.normalize.loudness(s.numpy().T, loudness, self.target_lufs).T
62
+ s = torch.as_tensor(s)
63
+ data["audio"][stem] = s
64
+
65
+ if x.ndim == 3:
66
+ assert x.shape[0] == 1
67
+ x = x[0]
68
+
69
+ n_chan, n_samples = x.shape
70
+
71
+ n_segments = (
72
+ int(np.ceil((n_samples - self.segment_length) / self.hop_length)) + 1
73
+ )
74
+
75
+ segments = torch.zeros((n_segments, n_chan, self.segment_length))
76
+ for i in range(n_segments):
77
+ start = i * self.hop_length
78
+ end = start + self.segment_length
79
+ end = min(end, n_samples)
80
+
81
+ xseg = x[:, start:end]
82
+
83
+ if end - start < self.segment_length:
84
+ xseg = F.pad(
85
+ xseg, pad=(0, self.segment_length - (end - start)), value=torch.nan
86
+ )
87
+
88
+ segments[i, :, :] = xseg
89
+
90
+ chunks = segments.reshape((n_segments, n_chan, self.n_chunks, self.chunk_size))
91
+
92
+ if self.analysis_stem != "none":
93
+ chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3))
94
+ chunk_energies = torch.nan_to_num(chunk_energies, nan=0)
95
+ chunk_energies[chunk_energies == 0] = self.chunk_epsilon
96
+
97
+ energy_threshold = torch.nanquantile(
98
+ chunk_energies, q=self.energy_threshold_quantile
99
+ )
100
+
101
+ if energy_threshold < self.segment_epsilon:
102
+ energy_threshold = self.segment_epsilon # type: ignore[assignment]
103
+
104
+ chunks_above_threshold = chunk_energies > energy_threshold
105
+ n_chunks_above_threshold = torch.mean(
106
+ chunks_above_threshold.to(torch.float), dim=-1
107
+ )
108
+
109
+ segment_above_threshold = (
110
+ n_chunks_above_threshold > self.salient_proportion_threshold
111
+ )
112
+
113
+ if torch.sum(segment_above_threshold) == 0:
114
+ return
115
+
116
+ else:
117
+ segment_above_threshold = torch.ones((n_segments,))
118
+
119
+ for i in range(n_segments):
120
+ if not segment_above_threshold[i]:
121
+ continue
122
+
123
+ outpath = os.path.join(
124
+ self.output_path,
125
+ self.analysis_stem,
126
+ f"{data['track']} - {self.analysis_stem}{i:03d}",
127
+ )
128
+ os.makedirs(outpath, exist_ok=True)
129
+
130
+ for stem in data["audio"]:
131
+ if stem == self.analysis_stem:
132
+ segment = torch.nan_to_num(segments[i, :, :], nan=0)
133
+ else:
134
+ start = i * self.hop_length
135
+ end = start + self.segment_length
136
+ end = min(n_samples, end)
137
+
138
+ segment = data["audio"][stem][:, start:end]
139
+
140
+ if end - start < self.segment_length:
141
+ segment = F.pad(
142
+ segment, (0, self.segment_length - (end - start))
143
+ )
144
+
145
+ assert segment.shape[-1] == self.segment_length, segment.shape
146
+
147
+ # ta.save(os.path.join(outpath, f"{stem}.wav"), segment, self.fs)
148
+
149
+ np.save(os.path.join(outpath, f"{stem}.wav"), segment)
150
+
151
+
152
+ def preprocess(
153
+ analysis_stem: str,
154
+ output_path: str = "/data/MUSDB18/HQ/saded-np",
155
+ fs: int = 44100,
156
+ segment_length_second: float = 6.0,
157
+ hop_length_second: float = 3.0,
158
+ n_chunks: int = 10,
159
+ chunk_epsilon: float = 1e-5,
160
+ energy_threshold_quantile: float = 0.15,
161
+ segment_epsilon: float = 1e-3,
162
+ salient_proportion_threshold: float = 0.5,
163
+ ) -> None:
164
+
165
+ sad = SourceActivityDetector(
166
+ analysis_stem=analysis_stem,
167
+ output_path=output_path,
168
+ fs=fs,
169
+ segment_length_second=segment_length_second,
170
+ hop_length_second=hop_length_second,
171
+ n_chunks=n_chunks,
172
+ chunk_epsilon=chunk_epsilon,
173
+ energy_threshold_quantile=energy_threshold_quantile,
174
+ segment_epsilon=segment_epsilon,
175
+ salient_proportion_threshold=salient_proportion_threshold,
176
+ )
177
+
178
+ for split in ["train", "val", "test"]:
179
+ ds = MUSDB18FullTrackDataset(
180
+ data_root="/data/MUSDB18/HQ/canonical",
181
+ split=split,
182
+ )
183
+
184
+ tracks = []
185
+ for i, track in enumerate(tqdm(ds, total=len(ds))):
186
+ if i % 32 == 0 and tracks:
187
+ process_map(sad, tracks, max_workers=8)
188
+ tracks = []
189
+ tracks.append(track)
190
+ process_map(sad, tracks, max_workers=8)
191
+
192
+
193
+ def loudness_norm_one(inputs):
194
+ infile, outfile, target_lufs = inputs
195
+
196
+ audio, fs = ta.load(infile)
197
+ audio = audio.mean(dim=0, keepdim=True).numpy().T
198
+
199
+ meter = pyln.Meter(fs)
200
+ loudness = meter.integrated_loudness(audio)
201
+ audio = pyln.normalize.loudness(audio, loudness, target_lufs)
202
+
203
+ os.makedirs(os.path.dirname(outfile), exist_ok=True)
204
+ np.save(outfile, audio.T)
205
+
206
+
207
+ def loudness_norm(
208
+ data_path: str,
209
+ # output_path: str,
210
+ target_lufs=-17.0,
211
+ ):
212
+ files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
213
+
214
+ outfiles = [f.replace(".wav", ".npy").replace("saded", "saded-np") for f in files]
215
+
216
+ files = [(f, o, target_lufs) for f, o in zip(files, outfiles)]
217
+
218
+ process_map(loudness_norm_one, files, chunksize=2)
219
+
220
+
221
+ if __name__ == "__main__":
222
+
223
+ from tqdm.auto import tqdm
224
+ import fire
225
+
226
+ fire.Fire()
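For orientation, here is a minimal, self-contained sketch (not part of this commit) of the saliency rule implemented above: a segment is kept when more than `salient_proportion_threshold` of its equal-length chunks exceed the `energy_threshold_quantile` quantile of chunk energies. The function name is hypothetical, and it assumes the segment length divides evenly by `n_chunks`.

import torch

def segment_is_salient(
    segment: torch.Tensor,  # (channels, samples), samples divisible by n_chunks
    n_chunks: int = 10,
    chunk_epsilon: float = 1e-5,
    energy_threshold_quantile: float = 0.15,
    segment_epsilon: float = 1e-3,
    salient_proportion_threshold: float = 0.5,
) -> bool:
    chunks = segment.reshape(segment.shape[0], n_chunks, -1)
    energies = torch.mean(torch.square(chunks), dim=(0, -1))  # one energy per chunk
    energies[energies == 0] = chunk_epsilon
    threshold = torch.clamp(
        torch.quantile(energies, q=energy_threshold_quantile), min=segment_epsilon
    )
    proportion = torch.mean((energies > threshold).float())
    return bool(proportion > salient_proportion_threshold)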
mvsepless/models/bandit/core/data/musdb/validation.yaml ADDED
@@ -0,0 +1,15 @@
1
+ validation:
2
+ - 'Actions - One Minute Smile'
3
+ - 'Clara Berry And Wooldog - Waltz For My Victims'
4
+ - 'Johnny Lokke - Promises & Lies'
5
+ - 'Patrick Talbot - A Reason To Leave'
6
+ - 'Triviul - Angelsaint'
7
+ - 'Alexander Ross - Goodbye Bolero'
8
+ - 'Fergessen - Nos Palpitants'
9
+ - 'Leaf - Summerghost'
10
+ - 'Skelpolu - Human Mistakes'
11
+ - 'Young Griffo - Pennies'
12
+ - 'ANiMAL - Rockshow'
13
+ - 'James May - On The Line'
14
+ - 'Meaxic - Take A Step'
15
+ - 'Traffic Experiment - Sirens'
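A sketch of how a split file like this can be read (requires PyYAML; the datamodule in this commit presumably does the equivalent):

import yaml

with open("mvsepless/models/bandit/core/data/musdb/validation.yaml") as f:
    validation_tracks = yaml.safe_load(f)["validation"]

print(len(validation_tracks))  # 14 MUSDB18 train-set tracks held out for validation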
mvsepless/models/bandit/core/loss/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from ._multistem import MultiStemWrapperFromConfig
2
+ from ._timefreq import (
3
+ ReImL1Loss,
4
+ ReImL2Loss,
5
+ TimeFreqL1Loss,
6
+ TimeFreqL2Loss,
7
+ TimeFreqSignalNoisePNormRatioLoss,
8
+ )
mvsepless/models/bandit/core/loss/_complex.py ADDED
@@ -0,0 +1,27 @@
1
+ from typing import Any
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn.modules import loss as _loss
6
+ from torch.nn.modules.loss import _Loss
7
+
8
+
9
+ class ReImLossWrapper(_Loss):
10
+ def __init__(self, module: _Loss) -> None:
11
+ super().__init__()
12
+ self.module = module
13
+
14
+ def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
15
+ return self.module(torch.view_as_real(preds), torch.view_as_real(target))
16
+
17
+
18
+ class ReImL1Loss(ReImLossWrapper):
19
+ def __init__(self, **kwargs: Any) -> None:
20
+ l1_loss = _loss.L1Loss(**kwargs)
21
+ super().__init__(module=(l1_loss))
22
+
23
+
24
+ class ReImL2Loss(ReImLossWrapper):
25
+ def __init__(self, **kwargs: Any) -> None:
26
+ l2_loss = _loss.MSELoss(**kwargs)
27
+ super().__init__(module=(l2_loss))
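A hypothetical usage sketch for the wrappers above: ReImLossWrapper converts a complex tensor to a real one with a trailing (real, imag) axis via torch.view_as_real, then applies the wrapped real-valued loss. The shapes below are illustrative.

import torch

loss_fn = ReImL1Loss()
preds = torch.randn(2, 1025, 87, dtype=torch.complex64)   # e.g. complex spectrograms
target = torch.randn(2, 1025, 87, dtype=torch.complex64)
loss = loss_fn(preds, target)  # scalar: mean L1 over real and imaginary parts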
mvsepless/models/bandit/core/loss/_multistem.py ADDED
@@ -0,0 +1,43 @@
1
+ from typing import Any, Dict
2
+
3
+ import torch
4
+ from asteroid import losses as asteroid_losses
5
+ from torch import nn
6
+ from torch.nn.modules.loss import _Loss
7
+
8
+ from . import snr
9
+
10
+
11
+ def parse_loss(name: str, kwargs: Dict[str, Any]) -> _Loss:
12
+
13
+ for module in [nn.modules.loss, snr, asteroid_losses, asteroid_losses.sdr]:
14
+ if name in module.__dict__:
15
+ return module.__dict__[name](**kwargs)
16
+
17
+ raise NameError(f"No loss named {name!r} found in torch.nn, snr, or asteroid losses")
18
+
19
+
20
+ class MultiStemWrapper(_Loss):
21
+ def __init__(self, module: _Loss, modality: str = "audio") -> None:
22
+ super().__init__()
23
+ self.loss = module
24
+ self.modality = modality
25
+
26
+ def forward(
27
+ self,
28
+ preds: Dict[str, Dict[str, torch.Tensor]],
29
+ target: Dict[str, Dict[str, torch.Tensor]],
30
+ ) -> torch.Tensor:
31
+ loss = {
32
+ stem: self.loss(preds[self.modality][stem], target[self.modality][stem])
33
+ for stem in preds[self.modality]
34
+ if stem in target[self.modality]
35
+ }
36
+
37
+ return sum(list(loss.values()))
38
+
39
+
40
+ class MultiStemWrapperFromConfig(MultiStemWrapper):
41
+ def __init__(self, name: str, kwargs: Any, modality: str = "audio") -> None:
42
+ loss = parse_loss(name, kwargs)
43
+ super().__init__(module=loss, modality=modality)
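A hypothetical sketch of the wrapper in use: "L1Loss" is resolved by parse_loss from torch.nn, and the loss is summed over whichever stems appear in both dictionaries. Shapes are illustrative.

import torch

loss_fn = MultiStemWrapperFromConfig(name="L1Loss", kwargs={})
preds = {"audio": {"speech": torch.randn(2, 2, 44100),
                   "music": torch.randn(2, 2, 44100)}}
target = {"audio": {"speech": torch.randn(2, 2, 44100),
                    "music": torch.randn(2, 2, 44100)}}
loss = loss_fn(preds, target)  # sum of the per-stem L1 losses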
mvsepless/models/bandit/core/loss/_timefreq.py ADDED
@@ -0,0 +1,95 @@
1
+ from typing import Any, Dict, Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn.modules.loss import _Loss
6
+
7
+ from ._multistem import MultiStemWrapper
8
+ from ._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper
9
+ from .snr import SignalNoisePNormRatio
10
+
11
+
12
+ class TimeFreqWrapper(_Loss):
13
+ def __init__(
14
+ self,
15
+ time_module: _Loss,
16
+ freq_module: Optional[_Loss] = None,
17
+ time_weight: float = 1.0,
18
+ freq_weight: float = 1.0,
19
+ multistem: bool = True,
20
+ ) -> None:
21
+ super().__init__()
22
+
23
+ if freq_module is None:
24
+ freq_module = time_module
25
+
26
+ if multistem:
27
+ time_module = MultiStemWrapper(time_module, modality="audio")
28
+ freq_module = MultiStemWrapper(freq_module, modality="spectrogram")
29
+
30
+ self.time_module = time_module
31
+ self.freq_module = freq_module
32
+
33
+ self.time_weight = time_weight
34
+ self.freq_weight = freq_weight
35
+
36
+ # TODO: add better type hints
37
+ def forward(self, preds: Any, target: Any) -> torch.Tensor:
38
+
39
+ return self.time_weight * self.time_module(
40
+ preds, target
41
+ ) + self.freq_weight * self.freq_module(preds, target)
42
+
43
+
44
+ class TimeFreqL1Loss(TimeFreqWrapper):
45
+ def __init__(
46
+ self,
47
+ time_weight: float = 1.0,
48
+ freq_weight: float = 1.0,
49
+ tkwargs: Optional[Dict[str, Any]] = None,
50
+ fkwargs: Optional[Dict[str, Any]] = None,
51
+ multistem: bool = True,
52
+ ) -> None:
53
+ if tkwargs is None:
54
+ tkwargs = {}
55
+ if fkwargs is None:
56
+ fkwargs = {}
57
+ time_module = nn.L1Loss(**tkwargs)
58
+ freq_module = ReImL1Loss(**fkwargs)
59
+ super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
60
+
61
+
62
+ class TimeFreqL2Loss(TimeFreqWrapper):
63
+ def __init__(
64
+ self,
65
+ time_weight: float = 1.0,
66
+ freq_weight: float = 1.0,
67
+ tkwargs: Optional[Dict[str, Any]] = None,
68
+ fkwargs: Optional[Dict[str, Any]] = None,
69
+ multistem: bool = True,
70
+ ) -> None:
71
+ if tkwargs is None:
72
+ tkwargs = {}
73
+ if fkwargs is None:
74
+ fkwargs = {}
75
+ time_module = nn.MSELoss(**tkwargs)
76
+ freq_module = ReImL2Loss(**fkwargs)
77
+ super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
78
+
79
+
80
+ class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper):
81
+ def __init__(
82
+ self,
83
+ time_weight: float = 1.0,
84
+ freq_weight: float = 1.0,
85
+ tkwargs: Optional[Dict[str, Any]] = None,
86
+ fkwargs: Optional[Dict[str, Any]] = None,
87
+ multistem: bool = True,
88
+ ) -> None:
89
+ if tkwargs is None:
90
+ tkwargs = {}
91
+ if fkwargs is None:
92
+ fkwargs = {}
93
+ time_module = SignalNoisePNormRatio(**tkwargs)
94
+ freq_module = SignalNoisePNormRatio(**fkwargs)
95
+ super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
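With multistem=True (the default), these losses expect nested dictionaries keyed first by modality ("audio" for waveforms, "spectrogram" for complex STFTs) and then by stem, as in this hypothetical sketch:

import torch

loss_fn = TimeFreqL1Loss(time_weight=1.0, freq_weight=1.0)
preds = {
    "audio": {"speech": torch.randn(2, 2, 44100)},
    "spectrogram": {"speech": torch.randn(2, 2, 1025, 87, dtype=torch.complex64)},
}
target = {
    "audio": {"speech": torch.randn(2, 2, 44100)},
    "spectrogram": {"speech": torch.randn(2, 2, 1025, 87, dtype=torch.complex64)},
}
loss = loss_fn(preds, target)  # time_weight * L1(audio) + freq_weight * ReImL1(spectrogram)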
mvsepless/models/bandit/core/loss/snr.py ADDED
@@ -0,0 +1,139 @@
1
+ import torch
2
+ from torch.nn.modules.loss import _Loss
3
+ from torch.nn import functional as F
4
+
5
+
6
+ class SignalNoisePNormRatio(_Loss):
7
+ def __init__(
8
+ self,
9
+ p: float = 1.0,
10
+ scale_invariant: bool = False,
11
+ zero_mean: bool = False,
12
+ take_log: bool = True,
13
+ reduction: str = "mean",
14
+ EPS: float = 1e-3,
15
+ ) -> None:
16
+ if reduction == "sum":
+     raise NotImplementedError('reduction="sum" is not supported')
17
+ super().__init__(reduction=reduction)
18
+ assert not zero_mean
19
+
20
+ self.p = p
21
+
22
+ self.EPS = EPS
23
+ self.take_log = take_log
24
+
25
+ self.scale_invariant = scale_invariant
26
+
27
+ def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
28
+
29
+ target_ = target
30
+ if self.scale_invariant:
31
+ ndim = target.ndim
32
+ dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True)
33
+ s_target_energy = torch.sum(
34
+ target * torch.conj(target), dim=-1, keepdim=True
35
+ )
36
+
37
+ if ndim > 2:
38
+ dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True)
39
+ s_target_energy = torch.sum(
40
+ s_target_energy, dim=list(range(1, ndim)), keepdim=True
41
+ )
42
+
43
+ target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8)
44
+ target = target_ * target_scaler
45
+
46
+ if torch.is_complex(est_target):
47
+ est_target = torch.view_as_real(est_target)
48
+ target = torch.view_as_real(target)
49
+
50
+ batch_size = est_target.shape[0]
51
+ est_target = est_target.reshape(batch_size, -1)
52
+ target = target.reshape(batch_size, -1)
53
+ # target_ = target_.reshape(batch_size, -1)
54
+
55
+ if self.p == 1:
56
+ e_error = torch.abs(est_target - target).mean(dim=-1)
57
+ e_target = torch.abs(target).mean(dim=-1)
58
+ elif self.p == 2:
59
+ e_error = torch.square(est_target - target).mean(dim=-1)
60
+ e_target = torch.square(target).mean(dim=-1)
61
+ else:
62
+ raise NotImplementedError
63
+
64
+ if self.take_log:
65
+ loss = 10 * (
66
+ torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS)
67
+ )
68
+ else:
69
+ loss = (e_error + self.EPS) / (e_target + self.EPS)
70
+
71
+ if self.reduction == "mean":
72
+ loss = loss.mean()
73
+ elif self.reduction == "sum":
74
+ loss = loss.sum()
75
+
76
+ return loss
77
+
78
+
79
+ class MultichannelSingleSrcNegSDR(_Loss):
80
+ def __init__(
81
+ self,
82
+ sdr_type: str,
83
+ p: float = 2.0,
84
+ zero_mean: bool = True,
85
+ take_log: bool = True,
86
+ reduction: str = "mean",
87
+ EPS: float = 1e-8,
88
+ ) -> None:
89
+ if reduction == "sum":
+     raise NotImplementedError('reduction="sum" is not supported')
90
+ super().__init__(reduction=reduction)
91
+
92
+ assert sdr_type in ["snr", "sisdr", "sdsdr"]
93
+ self.sdr_type = sdr_type
94
+ self.zero_mean = zero_mean
95
+ self.take_log = take_log
96
+ self.EPS = EPS
97
+
98
+ self.p = p
99
+
100
+ def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
101
+ if target.size() != est_target.size() or target.ndim != 3:
102
+ raise TypeError(
103
+ f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead"
104
+ )
105
+ # Step 1. Zero-mean norm
106
+ if self.zero_mean:
107
+ mean_source = torch.mean(target, dim=[1, 2], keepdim=True)
108
+ mean_estimate = torch.mean(est_target, dim=[1, 2], keepdim=True)
109
+ target = target - mean_source
110
+ est_target = est_target - mean_estimate
111
+ # Step 2. Pair-wise SI-SDR.
112
+ if self.sdr_type in ["sisdr", "sdsdr"]:
113
+ # [batch, 1]
114
+ dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True)
115
+ # [batch, 1]
116
+ s_target_energy = torch.sum(target**2, dim=[1, 2], keepdim=True) + self.EPS
117
+ # [batch, time]
118
+ scaled_target = dot * target / s_target_energy
119
+ else:
120
+ # [batch, time]
121
+ scaled_target = target
122
+ if self.sdr_type in ["sdsdr", "snr"]:
123
+ e_noise = est_target - target
124
+ else:
125
+ e_noise = est_target - scaled_target
126
+ # [batch]
127
+
128
+ if self.p == 2.0:
129
+ losses = torch.sum(scaled_target**2, dim=[1, 2]) / (
130
+ torch.sum(e_noise**2, dim=[1, 2]) + self.EPS
131
+ )
132
+ else:
133
+ losses = torch.linalg.vector_norm(scaled_target, ord=self.p, dim=[1, 2]) / (
134
+ torch.linalg.vector_norm(e_noise, ord=self.p, dim=[1, 2]) + self.EPS
135
+ )
136
+ if self.take_log:
137
+ losses = 10 * torch.log10(losses + self.EPS)
138
+ losses = losses.mean() if self.reduction == "mean" else losses
139
+ return -losses
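A quick sanity check of SignalNoisePNormRatio (shapes illustrative): with p=1 and take_log=True it is 10 * log10 of the ratio of mean absolute error to mean absolute target, so lower is better and a near-perfect estimate scores strongly negative.

import torch

loss_fn = SignalNoisePNormRatio(p=1.0, take_log=True)
target = torch.randn(4, 2, 44100)
est = target + 0.01 * torch.randn_like(target)
print(loss_fn(est, target))  # roughly -20 for this noise level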
mvsepless/models/bandit/core/metrics/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .snr import (
2
+ ChunkMedianScaleInvariantSignalDistortionRatio,
3
+ ChunkMedianScaleInvariantSignalNoiseRatio,
4
+ ChunkMedianSignalDistortionRatio,
5
+ ChunkMedianSignalNoiseRatio,
6
+ SafeSignalDistortionRatio,
7
+ )
8
+
9
+ # from .mushra import EstimatedMushraScore
mvsepless/models/bandit/core/metrics/_squim.py ADDED
@@ -0,0 +1,443 @@
1
+ from dataclasses import dataclass
2
+
3
+ from torchaudio._internal import load_state_dict_from_url
4
+
5
+ import math
6
+ from typing import List, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ def transform_wb_pesq_range(x: float) -> float:
14
+ """The metric defined by ITU-T P.862 is often called 'PESQ score', which is defined
15
+ for narrow-band signals and has a value range of [-0.5, 4.5] exactly. Here, we use the metric
16
+ defined by ITU-T P.862.2, commonly known as 'wide-band PESQ' and will be referred to as "PESQ score".
17
+
18
+ Args:
19
+ x (float): Narrow-band PESQ score.
20
+
21
+ Returns:
22
+ (float): Wide-band PESQ score.
23
+ """
24
+ return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224))
25
+
26
+
27
+ PESQRange: Tuple[float, float] = (
28
+ 1.0, # P.862.2 uses a different input filter than P.862, and the lower bound of
29
+ # the raw score is not -0.5 anymore. It's hard to figure out the true lower bound.
30
+ # We are using 1.0 as a reasonable approximation.
31
+ transform_wb_pesq_range(4.5),
32
+ )
33
+
34
+
35
+ class RangeSigmoid(nn.Module):
36
+ def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
37
+ super(RangeSigmoid, self).__init__()
38
+ assert isinstance(val_range, tuple) and len(val_range) == 2
39
+ self.val_range: Tuple[float, float] = val_range
40
+ self.sigmoid: nn.modules.Module = nn.Sigmoid()
41
+
42
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
43
+ out = (
44
+ self.sigmoid(x) * (self.val_range[1] - self.val_range[0])
45
+ + self.val_range[0]
46
+ )
47
+ return out
48
+
49
+
50
+ class Encoder(nn.Module):
51
+ """Encoder module that transform 1D waveform to 2D representations.
52
+
53
+ Args:
54
+ feat_dim (int, optional): The feature dimension after Encoder module. (Default: 512)
55
+ win_len (int, optional): kernel size in the Conv1D layer. (Default: 32)
56
+ """
57
+
58
+ def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
59
+ super(Encoder, self).__init__()
60
+
61
+ self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False)
62
+
63
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
64
+ """Apply waveforms to convolutional layer and ReLU layer.
65
+
66
+ Args:
67
+ x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
68
+
69
+ Returns:
70
+ (torch.Tensor): Feature Tensor with dimensions `(batch, channel, frame)`.
71
+ """
72
+ out = x.unsqueeze(dim=1)
73
+ out = F.relu(self.conv1d(out))
74
+ return out
75
+
76
+
77
+ class SingleRNN(nn.Module):
78
+ def __init__(
79
+ self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0
80
+ ) -> None:
81
+ super(SingleRNN, self).__init__()
82
+
83
+ self.rnn_type = rnn_type
84
+ self.input_size = input_size
85
+ self.hidden_size = hidden_size
86
+
87
+ self.rnn: nn.modules.Module = getattr(nn, rnn_type)(
88
+ input_size,
89
+ hidden_size,
90
+ 1,
91
+ dropout=dropout,
92
+ batch_first=True,
93
+ bidirectional=True,
94
+ )
95
+
96
+ self.proj = nn.Linear(hidden_size * 2, input_size)
97
+
98
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
99
+ # input shape: batch, seq, dim
100
+ out, _ = self.rnn(x)
101
+ out = self.proj(out)
102
+ return out
103
+
104
+
105
+ class DPRNN(nn.Module):
106
+ """*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`.
107
+
108
+ Args:
109
+ feat_dim (int, optional): The feature dimension after Encoder module. (Default: 64)
110
+ hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128)
111
+ num_blocks (int, optional): Number of DPRNN layers. (Default: 6)
112
+ rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM")
113
+ d_model (int, optional): The number of expected features in the input. (Default: 256)
114
+ chunk_size (int, optional): Chunk size of input for DPRNN. (Default: 100)
115
+ chunk_stride (int, optional): Stride of chunk input for DPRNN. (Default: 50)
116
+ """
117
+
118
+ def __init__(
119
+ self,
120
+ feat_dim: int = 64,
121
+ hidden_dim: int = 128,
122
+ num_blocks: int = 6,
123
+ rnn_type: str = "LSTM",
124
+ d_model: int = 256,
125
+ chunk_size: int = 100,
126
+ chunk_stride: int = 50,
127
+ ) -> None:
128
+ super(DPRNN, self).__init__()
129
+
130
+ self.num_blocks = num_blocks
131
+
132
+ self.row_rnn = nn.ModuleList([])
133
+ self.col_rnn = nn.ModuleList([])
134
+ self.row_norm = nn.ModuleList([])
135
+ self.col_norm = nn.ModuleList([])
136
+ for _ in range(num_blocks):
137
+ self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
138
+ self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
139
+ self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
140
+ self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
141
+ self.conv = nn.Sequential(
142
+ nn.Conv2d(feat_dim, d_model, 1),
143
+ nn.PReLU(),
144
+ )
145
+ self.chunk_size = chunk_size
146
+ self.chunk_stride = chunk_stride
147
+
148
+ def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
149
+ # input shape: (B, N, T)
150
+ seq_len = x.shape[-1]
151
+
152
+ rest = (
153
+ self.chunk_size
154
+ - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
155
+ )
156
+ out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])
157
+
158
+ return out, rest
159
+
160
+ def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
161
+ out, rest = self.pad_chunk(x)
162
+ batch_size, feat_dim, seq_len = out.shape
163
+
164
+ segments1 = (
165
+ out[:, :, : -self.chunk_stride]
166
+ .contiguous()
167
+ .view(batch_size, feat_dim, -1, self.chunk_size)
168
+ )
169
+ segments2 = (
170
+ out[:, :, self.chunk_stride :]
171
+ .contiguous()
172
+ .view(batch_size, feat_dim, -1, self.chunk_size)
173
+ )
174
+ out = torch.cat([segments1, segments2], dim=3)
175
+ out = (
176
+ out.view(batch_size, feat_dim, -1, self.chunk_size)
177
+ .transpose(2, 3)
178
+ .contiguous()
179
+ )
180
+
181
+ return out, rest
182
+
183
+ def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
184
+ batch_size, dim, _, _ = x.shape
185
+ out = (
186
+ x.transpose(2, 3)
187
+ .contiguous()
188
+ .view(batch_size, dim, -1, self.chunk_size * 2)
189
+ )
190
+ out1 = (
191
+ out[:, :, :, : self.chunk_size]
192
+ .contiguous()
193
+ .view(batch_size, dim, -1)[:, :, self.chunk_stride :]
194
+ )
195
+ out2 = (
196
+ out[:, :, :, self.chunk_size :]
197
+ .contiguous()
198
+ .view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
199
+ )
200
+ out = out1 + out2
201
+ if rest > 0:
202
+ out = out[:, :, :-rest]
203
+ out = out.contiguous()
204
+ return out
205
+
206
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
207
+ x, rest = self.chunking(x)
208
+ batch_size, _, dim1, dim2 = x.shape
209
+ out = x
210
+ for row_rnn, row_norm, col_rnn, col_norm in zip(
211
+ self.row_rnn, self.row_norm, self.col_rnn, self.col_norm
212
+ ):
213
+ row_in = (
214
+ out.permute(0, 3, 2, 1)
215
+ .contiguous()
216
+ .view(batch_size * dim2, dim1, -1)
217
+ .contiguous()
218
+ )
219
+ row_out = row_rnn(row_in)
220
+ row_out = (
221
+ row_out.view(batch_size, dim2, dim1, -1)
222
+ .permute(0, 3, 2, 1)
223
+ .contiguous()
224
+ )
225
+ row_out = row_norm(row_out)
226
+ out = out + row_out
227
+
228
+ col_in = (
229
+ out.permute(0, 2, 3, 1)
230
+ .contiguous()
231
+ .view(batch_size * dim1, dim2, -1)
232
+ .contiguous()
233
+ )
234
+ col_out = col_rnn(col_in)
235
+ col_out = (
236
+ col_out.view(batch_size, dim1, dim2, -1)
237
+ .permute(0, 3, 1, 2)
238
+ .contiguous()
239
+ )
240
+ col_out = col_norm(col_out)
241
+ out = out + col_out
242
+ out = self.conv(out)
243
+ out = self.merging(out, rest)
244
+ out = out.transpose(1, 2).contiguous()
245
+ return out
246
+
247
+
248
+ class AutoPool(nn.Module):
249
+ def __init__(self, pool_dim: int = 1) -> None:
250
+ super(AutoPool, self).__init__()
251
+ self.pool_dim: int = pool_dim
252
+ self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
253
+ self.register_parameter("alpha", nn.Parameter(torch.ones(1)))
254
+
255
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
256
+ weight = self.softmax(torch.mul(x, self.alpha))
257
+ out = torch.sum(torch.mul(x, weight), dim=self.pool_dim)
258
+ return out
259
+
260
+
261
+ class SquimObjective(nn.Module):
262
+ """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
263
+ for speech enhancement (e.g., STOI, PESQ, and SI-SDR).
264
+
265
+ Args:
266
+ encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation.
267
+ dprnn (torch.nn.Module): DPRNN module to model sequential feature.
268
+ branches (torch.nn.ModuleList): Transformer branches, each of which estimates one objective metric score.
269
+ """
270
+
271
+ def __init__(
272
+ self,
273
+ encoder: nn.Module,
274
+ dprnn: nn.Module,
275
+ branches: nn.ModuleList,
276
+ ):
277
+ super(SquimObjective, self).__init__()
278
+ self.encoder = encoder
279
+ self.dprnn = dprnn
280
+ self.branches = branches
281
+
282
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
283
+ """
284
+ Args:
285
+ x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
286
+
287
+ Returns:
288
+ List(torch.Tensor): List of score Tensors, each with dimension `(batch,)`.
289
+ """
290
+ if x.ndim != 2:
291
+ raise ValueError(
292
+ f"The input must be a 2D Tensor. Found dimension {x.ndim}."
293
+ )
294
+ x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
295
+ out = self.encoder(x)
296
+ out = self.dprnn(out)
297
+ scores = []
298
+ for branch in self.branches:
299
+ scores.append(branch(out).squeeze(dim=1))
300
+ return scores
301
+
302
+
303
+ def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
304
+ """Create branch module after DPRNN model for predicting metric score.
305
+
306
+ Args:
307
+ d_model (int): The number of expected features in the input.
308
+ nhead (int): Number of heads in the multi-head attention model.
309
+ metric (str): The metric name to predict.
310
+
311
+ Returns:
312
+ (nn.Module): Returned module to predict corresponding metric score.
313
+ """
314
+ layer1 = nn.TransformerEncoderLayer(
315
+ d_model, nhead, d_model * 4, dropout=0.0, batch_first=True
316
+ )
317
+ layer2 = AutoPool()
318
+ if metric == "stoi":
319
+ layer3 = nn.Sequential(
320
+ nn.Linear(d_model, d_model),
321
+ nn.PReLU(),
322
+ nn.Linear(d_model, 1),
323
+ RangeSigmoid(),
324
+ )
325
+ elif metric == "pesq":
326
+ layer3 = nn.Sequential(
327
+ nn.Linear(d_model, d_model),
328
+ nn.PReLU(),
329
+ nn.Linear(d_model, 1),
330
+ RangeSigmoid(val_range=PESQRange),
331
+ )
332
+ else:
333
+ layer3: nn.modules.Module = nn.Sequential(
334
+ nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1)
335
+ )
336
+ return nn.Sequential(layer1, layer2, layer3)
337
+
338
+
339
+ def squim_objective_model(
340
+ feat_dim: int,
341
+ win_len: int,
342
+ d_model: int,
343
+ nhead: int,
344
+ hidden_dim: int,
345
+ num_blocks: int,
346
+ rnn_type: str,
347
+ chunk_size: int,
348
+ chunk_stride: Optional[int] = None,
349
+ ) -> SquimObjective:
350
+ """Build a custome :class:`torchaudio.prototype.models.SquimObjective` model.
351
+
352
+ Args:
353
+ feat_dim (int): The feature dimension after the Encoder module.
354
+ win_len (int): Kernel size in the Encoder module.
355
+ d_model (int): The number of expected features in the input.
356
+ nhead (int): Number of heads in the multi-head attention model.
357
+ hidden_dim (int): Hidden dimension in the RNN layer of DPRNN.
358
+ num_blocks (int): Number of DPRNN layers.
359
+ rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"].
360
+ chunk_size (int): Chunk size of input for DPRNN.
361
+ chunk_stride (int or None, optional): Stride of chunk input for DPRNN.
362
+ """
363
+ if chunk_stride is None:
364
+ chunk_stride = chunk_size // 2
365
+ encoder = Encoder(feat_dim, win_len)
366
+ dprnn = DPRNN(
367
+ feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride
368
+ )
369
+ branches = nn.ModuleList(
370
+ [
371
+ _create_branch(d_model, nhead, "stoi"),
372
+ _create_branch(d_model, nhead, "pesq"),
373
+ _create_branch(d_model, nhead, "sisdr"),
374
+ ]
375
+ )
376
+ return SquimObjective(encoder, dprnn, branches)
377
+
378
+
379
+ def squim_objective_base() -> SquimObjective:
380
+ """Build :class:`torchaudio.prototype.models.SquimObjective` model with default arguments."""
381
+ return squim_objective_model(
382
+ feat_dim=256,
383
+ win_len=64,
384
+ d_model=256,
385
+ nhead=4,
386
+ hidden_dim=256,
387
+ num_blocks=2,
388
+ rnn_type="LSTM",
389
+ chunk_size=71,
390
+ )
391
+
392
+
393
+ @dataclass
394
+ class SquimObjectiveBundle:
395
+
396
+ _path: str
397
+ _sample_rate: float
398
+
399
+ def _get_state_dict(self, dl_kwargs):
400
+ url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
401
+ dl_kwargs = {} if dl_kwargs is None else dl_kwargs
402
+ state_dict = load_state_dict_from_url(url, **dl_kwargs)
403
+ return state_dict
404
+
405
+ def get_model(self, *, dl_kwargs=None) -> SquimObjective:
406
+ """Construct the SquimObjective model, and load the pretrained weight.
407
+
408
+ The weight file is downloaded from the internet and cached with
409
+ :func:`torch.hub.load_state_dict_from_url`
410
+
411
+ Args:
412
+ dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
413
+
414
+ Returns:
415
+ Variation of :py:class:`~torchaudio.models.SquimObjective`.
416
+ """
417
+ model = squim_objective_base()
418
+ model.load_state_dict(self._get_state_dict(dl_kwargs))
419
+ model.eval()
420
+ return model
421
+
422
+ @property
423
+ def sample_rate(self):
424
+ """Sample rate of the audio that the model is trained on.
425
+
426
+ :type: float
427
+ """
428
+ return self._sample_rate
429
+
430
+
431
+ SQUIM_OBJECTIVE = SquimObjectiveBundle(
432
+ "squim_objective_dns2020.pth",
433
+ _sample_rate=16000,
434
+ )
435
+ SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
436
+ :cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.
437
+
438
+ The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
439
+ The weights are under `Creative Commons Attribution 4.0 International License
440
+ <https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.
441
+
442
+ Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
443
+ """
mvsepless/models/bandit/core/metrics/snr.py ADDED
@@ -0,0 +1,127 @@
1
+ from typing import Any, Callable, Optional
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torchmetrics as tm
6
+ from torch._C import _LinAlgError
7
+ from torchmetrics import functional as tmF
8
+
9
+
10
+ class SafeSignalDistortionRatio(tm.SignalDistortionRatio):
11
+ def __init__(self, **kwargs) -> None:
12
+ super().__init__(**kwargs)
13
+
14
+ def update(self, *args, **kwargs) -> Any:
15
+ try:
16
+ super().update(*args, **kwargs)
17
+ except Exception:  # swallow failed updates (e.g., numerical errors inside SDR)
18
+ pass
19
+
20
+ def compute(self) -> Any:
21
+ if self.total == 0:
22
+ return torch.tensor(torch.nan)
23
+ return super().compute()
24
+
25
+
26
+ class BaseChunkMedianSignalRatio(tm.Metric):
27
+ def __init__(
28
+ self,
29
+ func: Callable,
30
+ window_size: int,
31
+ hop_size: Optional[int] = None,
32
+ zero_mean: bool = False,
33
+ ) -> None:
34
+ super().__init__()
35
+
36
+ # self.zero_mean = zero_mean
37
+ self.func = func
38
+ self.window_size = window_size
39
+ if hop_size is None:
40
+ hop_size = window_size
41
+ self.hop_size = hop_size
42
+
43
+ self.add_state("sum_snr", default=torch.tensor(0.0), dist_reduce_fx="sum")
44
+ self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
45
+
46
+ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
47
+
48
+ n_samples = target.shape[-1]
49
+
50
+ n_chunks = int(np.ceil((n_samples - self.window_size) / self.hop_size) + 1)
51
+
52
+ snr_chunk = []
53
+
54
+ for i in range(n_chunks):
55
+ start = i * self.hop_size
56
+
57
+ if n_samples - start < self.window_size:
58
+ continue
59
+
60
+ end = start + self.window_size
61
+
62
+ try:
63
+ chunk_snr = self.func(preds[..., start:end], target[..., start:end])
64
+
65
+ # print(preds.shape, chunk_snr.shape)
66
+
67
+ if torch.all(torch.isfinite(chunk_snr)):
68
+ snr_chunk.append(chunk_snr)
69
+ except _LinAlgError:
70
+ pass
71
+
72
+ if not snr_chunk:
+     return  # every chunk failed or was non-finite; nothing to accumulate
+ snr_chunk = torch.stack(snr_chunk, dim=-1)
73
+ snr_batch, _ = torch.nanmedian(snr_chunk, dim=-1)
74
+
75
+ self.sum_snr += snr_batch.sum()
76
+ self.total += snr_batch.numel()
77
+
78
+ def compute(self) -> Any:
79
+ return self.sum_snr / self.total
80
+
81
+
82
+ class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio):
83
+ def __init__(
84
+ self, window_size: int, hop_size: Optional[int] = None, zero_mean: bool = False
85
+ ) -> None:
86
+ super().__init__(
87
+ func=tmF.signal_noise_ratio,
88
+ window_size=window_size,
89
+ hop_size=hop_size,
90
+ zero_mean=zero_mean,
91
+ )
92
+
93
+
94
+ class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio):
95
+ def __init__(
96
+ self, window_size: int, hop_size: Optional[int] = None, zero_mean: bool = False
97
+ ) -> None:
98
+ super().__init__(
99
+ func=tmF.scale_invariant_signal_noise_ratio,
100
+ window_size=window_size,
101
+ hop_size=hop_size,
102
+ zero_mean=zero_mean,
103
+ )
104
+
105
+
106
+ class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio):
107
+ def __init__(
108
+ self, window_size: int, hop_size: Optional[int] = None, zero_mean: bool = False
109
+ ) -> None:
110
+ super().__init__(
111
+ func=tmF.signal_distortion_ratio,
112
+ window_size=window_size,
113
+ hop_size=hop_size,
114
+ zero_mean=zero_mean,
115
+ )
116
+
117
+
118
+ class ChunkMedianScaleInvariantSignalDistortionRatio(BaseChunkMedianSignalRatio):
119
+ def __init__(
120
+ self, window_size: int, hop_size: Optional[int] = None, zero_mean: bool = False
121
+ ) -> None:
122
+ super().__init__(
123
+ func=tmF.scale_invariant_signal_distortion_ratio,
124
+ window_size=window_size,
125
+ hop_size=hop_size,
126
+ zero_mean=zero_mean,
127
+ )
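A hypothetical sketch of the chunked-median metrics in use: torchmetrics' functional SNR is computed per window, the median is taken over windows, and the mean over the batch.

import torch

metric = ChunkMedianSignalNoiseRatio(window_size=44100)  # 1 s windows at 44.1 kHz
target = torch.randn(2, 441000)
preds = target + 0.1 * torch.randn_like(target)
metric.update(preds, target)
print(metric.compute())  # ≈ 20 dB for this noise level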
mvsepless/models/bandit/core/model/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .bsrnn.wrapper import (
2
+ MultiMaskMultiSourceBandSplitRNNSimple,
3
+ )
mvsepless/models/bandit/core/model/_spectral.py ADDED
@@ -0,0 +1,54 @@
1
+ from typing import Dict, Optional
2
+
3
+ import torch
4
+ import torchaudio as ta
5
+ from torch import nn
6
+
7
+
8
+ class _SpectralComponent(nn.Module):
9
+ def __init__(
10
+ self,
11
+ n_fft: int = 2048,
12
+ win_length: Optional[int] = 2048,
13
+ hop_length: int = 512,
14
+ window_fn: str = "hann_window",
15
+ wkwargs: Optional[Dict] = None,
16
+ power: Optional[int] = None,
17
+ center: bool = True,
18
+ normalized: bool = True,
19
+ pad_mode: str = "constant",
20
+ onesided: bool = True,
21
+ **kwargs,
22
+ ) -> None:
23
+ super().__init__()
24
+
25
+ assert power is None
26
+
27
+ window_fn = torch.__dict__[window_fn]
28
+
29
+ self.stft = ta.transforms.Spectrogram(
30
+ n_fft=n_fft,
31
+ win_length=win_length,
32
+ hop_length=hop_length,
33
+ pad_mode=pad_mode,
34
+ pad=0,
35
+ window_fn=window_fn,
36
+ wkwargs=wkwargs,
37
+ power=power,
38
+ normalized=normalized,
39
+ center=center,
40
+ onesided=onesided,
41
+ )
42
+
43
+ self.istft = ta.transforms.InverseSpectrogram(
44
+ n_fft=n_fft,
45
+ win_length=win_length,
46
+ hop_length=hop_length,
47
+ pad_mode=pad_mode,
48
+ pad=0,
49
+ window_fn=window_fn,
50
+ wkwargs=wkwargs,
51
+ normalized=normalized,
52
+ center=center,
53
+ onesided=onesided,
54
+ )
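A round-trip sketch (instantiating the internal base class directly, purely for illustration): with power=None the forward transform returns a complex STFT that InverseSpectrogram inverts almost exactly.

import torch

spec = _SpectralComponent(n_fft=2048, hop_length=512)
x = torch.randn(2, 44100)
X = spec.stft(x)                         # complex spectrogram, (2, 1025, n_frames)
y = spec.istft(X, length=x.shape[-1])    # back to the waveform
print(torch.allclose(x, y, atol=1e-5))   # True up to floating-point error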
mvsepless/models/bandit/core/model/bsrnn/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ from abc import ABC
2
+ from typing import Iterable, Mapping, Union
3
+
4
+ from torch import nn
5
+
6
+ from .bandsplit import BandSplitModule
7
+ from .tfmodel import (
8
+ SeqBandModellingModule,
9
+ TransformerTimeFreqModule,
10
+ )
11
+
12
+
13
+ class BandsplitCoreBase(nn.Module, ABC):
14
+ band_split: nn.Module
15
+ tf_model: nn.Module
16
+ mask_estim: Union[nn.Module, Mapping[str, nn.Module], Iterable[nn.Module]]
17
+
18
+ def __init__(self) -> None:
19
+ super().__init__()
20
+
21
+ @staticmethod
22
+ def mask(x, m):
23
+ return x * m
mvsepless/models/bandit/core/model/bsrnn/bandsplit.py ADDED
@@ -0,0 +1,135 @@
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from .utils import (
7
+ band_widths_from_specs,
8
+ check_no_gap,
9
+ check_no_overlap,
10
+ check_nonzero_bandwidth,
11
+ )
12
+
13
+
14
+ class NormFC(nn.Module):
15
+ def __init__(
16
+ self,
17
+ emb_dim: int,
18
+ bandwidth: int,
19
+ in_channel: int,
20
+ normalize_channel_independently: bool = False,
21
+ treat_channel_as_feature: bool = True,
22
+ ) -> None:
23
+ super().__init__()
24
+
25
+ self.treat_channel_as_feature = treat_channel_as_feature
26
+
27
+ if normalize_channel_independently:
28
+ raise NotImplementedError
29
+
30
+ reim = 2
31
+
32
+ self.norm = nn.LayerNorm(in_channel * bandwidth * reim)
33
+
34
+ fc_in = bandwidth * reim
35
+
36
+ if treat_channel_as_feature:
37
+ fc_in *= in_channel
38
+ else:
39
+ assert emb_dim % in_channel == 0
40
+ emb_dim = emb_dim // in_channel
41
+
42
+ self.fc = nn.Linear(fc_in, emb_dim)
43
+
44
+ def forward(self, xb):
45
+ # xb = (batch, n_time, in_chan, reim * band_width)
46
+
47
+ batch, n_time, in_chan, ribw = xb.shape
48
+ xb = self.norm(xb.reshape(batch, n_time, in_chan * ribw))
49
+ # (batch, n_time, in_chan * reim * band_width)
50
+
51
+ if not self.treat_channel_as_feature:
52
+ xb = xb.reshape(batch, n_time, in_chan, ribw)
53
+ # (batch, n_time, in_chan, reim * band_width)
54
+
55
+ zb = self.fc(xb)
56
+ # (batch, n_time, emb_dim)
57
+ # OR
58
+ # (batch, n_time, in_chan, emb_dim_per_chan)
59
+
60
+ if not self.treat_channel_as_feature:
61
+ batch, n_time, in_chan, emb_dim_per_chan = zb.shape
62
+ # (batch, n_time, in_chan, emb_dim_per_chan)
63
+ zb = zb.reshape((batch, n_time, in_chan * emb_dim_per_chan))
64
+
65
+ return zb # (batch, n_time, emb_dim)
66
+
67
+
68
+ class BandSplitModule(nn.Module):
69
+ def __init__(
70
+ self,
71
+ band_specs: List[Tuple[float, float]],
72
+ emb_dim: int,
73
+ in_channel: int,
74
+ require_no_overlap: bool = False,
75
+ require_no_gap: bool = True,
76
+ normalize_channel_independently: bool = False,
77
+ treat_channel_as_feature: bool = True,
78
+ ) -> None:
79
+ super().__init__()
80
+
81
+ check_nonzero_bandwidth(band_specs)
82
+
83
+ if require_no_gap:
84
+ check_no_gap(band_specs)
85
+
86
+ if require_no_overlap:
87
+ check_no_overlap(band_specs)
88
+
89
+ self.band_specs = band_specs
90
+ # list of [fstart, fend) in index.
91
+ # Note that fend is exclusive.
92
+ self.band_widths = band_widths_from_specs(band_specs)
93
+ self.n_bands = len(band_specs)
94
+ self.emb_dim = emb_dim
95
+
96
+ self.norm_fc_modules = nn.ModuleList(
97
+ [ # type: ignore
98
+ (
99
+ NormFC(
100
+ emb_dim=emb_dim,
101
+ bandwidth=bw,
102
+ in_channel=in_channel,
103
+ normalize_channel_independently=normalize_channel_independently,
104
+ treat_channel_as_feature=treat_channel_as_feature,
105
+ )
106
+ )
107
+ for bw in self.band_widths
108
+ ]
109
+ )
110
+
111
+ def forward(self, x: torch.Tensor):
112
+ # x = complex spectrogram (batch, in_chan, n_freq, n_time)
113
+
114
+ batch, in_chan, _, n_time = x.shape
115
+
116
+ z = torch.zeros(
117
+ size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
118
+ )
119
+
120
+ xr = torch.view_as_real(x) # batch, in_chan, n_freq, n_time, 2
121
+ xr = torch.permute(xr, (0, 3, 1, 4, 2)) # batch, n_time, in_chan, 2, n_freq
122
+ batch, n_time, in_chan, reim, band_width = xr.shape
123
+ for i, nfm in enumerate(self.norm_fc_modules):
124
+ # print(f"bandsplit/band{i:02d}")
125
+ fstart, fend = self.band_specs[i]
126
+ xb = xr[..., fstart:fend]
127
+ # (batch, n_time, in_chan, reim, band_width)
128
+ xb = torch.reshape(xb, (batch, n_time, in_chan, -1))
129
+ # (batch, n_time, in_chan, reim * band_width)
130
+ # z.append(nfm(xb)) # (batch, n_time, emb_dim)
131
+ z[:, i, :, :] = nfm(xb.contiguous())
132
+
133
+ # z = torch.stack(z, dim=1)
134
+
135
+ return z
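A hypothetical shape check for the module above. Despite the Tuple[float, float] hint, band_specs are used as integer bin indices, so integer pairs covering the spectrum without gaps are assumed here (the checks imported from .utils must pass):

import torch

band_specs = [(0, 16), (16, 32), (32, 65)]  # contiguous, gap-free bin ranges
module = BandSplitModule(band_specs=band_specs, emb_dim=128, in_channel=2)
x = torch.randn(4, 2, 65, 100, dtype=torch.complex64)  # (batch, chan, n_freq, n_time)
z = module(x)
print(z.shape)  # torch.Size([4, 3, 100, 128]) = (batch, n_band, n_time, emb_dim)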
mvsepless/models/bandit/core/model/bsrnn/core.py ADDED
@@ -0,0 +1,651 @@
1
+ from typing import Dict, List, Optional, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from . import BandsplitCoreBase
8
+ from .bandsplit import BandSplitModule
9
+ from .maskestim import (
10
+ MaskEstimationModule,
+ MultAddMaskEstimationModule,  # referenced when mult_add_mask=True; assumed to be defined in .maskestim
11
+ OverlappingMaskEstimationModule,
+ PatchingMaskEstimationModule,  # referenced by the patching core below; assumed to be defined in .maskestim
12
+ )
13
+ from .tfmodel import (
14
+ ConvolutionalTimeFreqModule,
15
+ SeqBandModellingModule,
16
+ TransformerTimeFreqModule,
17
+ )
18
+
19
+
20
+ class MultiMaskBandSplitCoreBase(BandsplitCoreBase):
21
+ def __init__(self) -> None:
22
+ super().__init__()
23
+
24
+ def forward(self, x, cond=None, compute_residual: bool = True):
25
+ # x = complex spectrogram (batch, in_chan, n_freq, n_time)
26
+ # print(x.shape)
27
+ batch, in_chan, n_freq, n_time = x.shape
28
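+ # fold the input channels into the batch dimension; the core operates on mono inputs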
+ x = torch.reshape(x, (-1, 1, n_freq, n_time))
29
+
30
+ z = self.band_split(x)  # (batch, n_band, n_time, emb_dim)
31
+
32
+ # if torch.any(torch.isnan(z)):
33
+ # raise ValueError("z nan")
34
+
35
+ # print(z)
36
+ q = self.tf_model(z)  # (batch, n_band, n_time, emb_dim)
37
+ # print(q)
38
+
39
+ # if torch.any(torch.isnan(q)):
40
+ # raise ValueError("q nan")
41
+
42
+ out = {}
43
+
44
+ for stem, mem in self.mask_estim.items():
45
+ m = mem(q, cond=cond)
46
+
47
+ # if torch.any(torch.isnan(m)):
48
+ # raise ValueError("m nan", stem)
49
+
50
+ s = self.mask(x, m)
51
+ s = torch.reshape(s, (batch, in_chan, n_freq, n_time))
52
+ out[stem] = s
53
+
54
+ return {"spectrogram": out}
55
+
56
+ def instantiate_mask_estim(
57
+ self,
58
+ in_channel: int,
59
+ stems: List[str],
60
+ band_specs: List[Tuple[float, float]],
61
+ emb_dim: int,
62
+ mlp_dim: int,
63
+ cond_dim: int,
64
+ hidden_activation: str,
65
+ hidden_activation_kwargs: Optional[Dict] = None,
66
+ complex_mask: bool = True,
67
+ overlapping_band: bool = False,
68
+ freq_weights: Optional[List[torch.Tensor]] = None,
69
+ n_freq: Optional[int] = None,
70
+ use_freq_weights: bool = True,
71
+ mult_add_mask: bool = False,
72
+ ):
73
+ if hidden_activation_kwargs is None:
74
+ hidden_activation_kwargs = {}
75
+
76
+ if "mne:+" in stems:
77
+ stems = [s for s in stems if s != "mne:+"]
78
+
79
+ if overlapping_band:
80
+ assert freq_weights is not None
81
+ assert n_freq is not None
82
+
83
+ if mult_add_mask:
84
+
85
+ self.mask_estim = nn.ModuleDict(
86
+ {
87
+ stem: MultAddMaskEstimationModule(
88
+ band_specs=band_specs,
89
+ freq_weights=freq_weights,
90
+ n_freq=n_freq,
91
+ emb_dim=emb_dim,
92
+ mlp_dim=mlp_dim,
93
+ in_channel=in_channel,
94
+ hidden_activation=hidden_activation,
95
+ hidden_activation_kwargs=hidden_activation_kwargs,
96
+ complex_mask=complex_mask,
97
+ use_freq_weights=use_freq_weights,
98
+ )
99
+ for stem in stems
100
+ }
101
+ )
102
+ else:
103
+ self.mask_estim = nn.ModuleDict(
104
+ {
105
+ stem: OverlappingMaskEstimationModule(
106
+ band_specs=band_specs,
107
+ freq_weights=freq_weights,
108
+ n_freq=n_freq,
109
+ emb_dim=emb_dim,
110
+ mlp_dim=mlp_dim,
111
+ in_channel=in_channel,
112
+ hidden_activation=hidden_activation,
113
+ hidden_activation_kwargs=hidden_activation_kwargs,
114
+ complex_mask=complex_mask,
115
+ use_freq_weights=use_freq_weights,
116
+ )
117
+ for stem in stems
118
+ }
119
+ )
120
+ else:
121
+ self.mask_estim = nn.ModuleDict(
122
+ {
123
+ stem: MaskEstimationModule(
124
+ band_specs=band_specs,
125
+ emb_dim=emb_dim,
126
+ mlp_dim=mlp_dim,
127
+ cond_dim=cond_dim,
128
+ in_channel=in_channel,
129
+ hidden_activation=hidden_activation,
130
+ hidden_activation_kwargs=hidden_activation_kwargs,
131
+ complex_mask=complex_mask,
132
+ )
133
+ for stem in stems
134
+ }
135
+ )
136
+
137
+ def instantiate_bandsplit(
138
+ self,
139
+ in_channel: int,
140
+ band_specs: List[Tuple[float, float]],
141
+ require_no_overlap: bool = False,
142
+ require_no_gap: bool = True,
143
+ normalize_channel_independently: bool = False,
144
+ treat_channel_as_feature: bool = True,
145
+ emb_dim: int = 128,
146
+ ):
147
+ self.band_split = BandSplitModule(
148
+ in_channel=in_channel,
149
+ band_specs=band_specs,
150
+ require_no_overlap=require_no_overlap,
151
+ require_no_gap=require_no_gap,
152
+ normalize_channel_independently=normalize_channel_independently,
153
+ treat_channel_as_feature=treat_channel_as_feature,
154
+ emb_dim=emb_dim,
155
+ )
156
+
157
+
158
+ class SingleMaskBandsplitCoreBase(BandsplitCoreBase):
159
+ def __init__(self, **kwargs) -> None:
160
+ super().__init__()
161
+
162
+ def forward(self, x):
163
+ # x = complex spectrogram (batch, in_chan, n_freq, n_time)
164
+ z = self.band_split(x)  # (batch, n_band, n_time, emb_dim)
165
+ q = self.tf_model(z)  # (batch, n_band, n_time, emb_dim)
166
+ m = self.mask_estim(q) # (batch, in_chan, n_freq, n_time)
167
+
168
+ s = self.mask(x, m)
169
+
170
+ return s
171
+
172
+
173
+ class SingleMaskBandsplitCoreRNN(
174
+ SingleMaskBandsplitCoreBase,
175
+ ):
176
+ def __init__(
177
+ self,
178
+ in_channel: int,
179
+ band_specs: List[Tuple[float, float]],
180
+ require_no_overlap: bool = False,
181
+ require_no_gap: bool = True,
182
+ normalize_channel_independently: bool = False,
183
+ treat_channel_as_feature: bool = True,
184
+ n_sqm_modules: int = 12,
185
+ emb_dim: int = 128,
186
+ rnn_dim: int = 256,
187
+ bidirectional: bool = True,
188
+ rnn_type: str = "LSTM",
189
+ mlp_dim: int = 512,
190
+ hidden_activation: str = "Tanh",
191
+ hidden_activation_kwargs: Optional[Dict] = None,
192
+ complex_mask: bool = True,
193
+ ) -> None:
194
+ super().__init__()
195
+ self.band_split = BandSplitModule(
196
+ in_channel=in_channel,
197
+ band_specs=band_specs,
198
+ require_no_overlap=require_no_overlap,
199
+ require_no_gap=require_no_gap,
200
+ normalize_channel_independently=normalize_channel_independently,
201
+ treat_channel_as_feature=treat_channel_as_feature,
202
+ emb_dim=emb_dim,
203
+ )
204
+ self.tf_model = SeqBandModellingModule(
205
+ n_modules=n_sqm_modules,
206
+ emb_dim=emb_dim,
207
+ rnn_dim=rnn_dim,
208
+ bidirectional=bidirectional,
209
+ rnn_type=rnn_type,
210
+ )
211
+ self.mask_estim = MaskEstimationModule(
212
+ in_channel=in_channel,
213
+ band_specs=band_specs,
214
+ emb_dim=emb_dim,
215
+ mlp_dim=mlp_dim,
216
+ hidden_activation=hidden_activation,
217
+ hidden_activation_kwargs=hidden_activation_kwargs,
218
+ complex_mask=complex_mask,
219
+ )
220
+
221
+
222
+ class SingleMaskBandsplitCoreTransformer(
223
+ SingleMaskBandsplitCoreBase,
224
+ ):
225
+ def __init__(
226
+ self,
227
+ in_channel: int,
228
+ band_specs: List[Tuple[float, float]],
229
+ require_no_overlap: bool = False,
230
+ require_no_gap: bool = True,
231
+ normalize_channel_independently: bool = False,
232
+ treat_channel_as_feature: bool = True,
233
+ n_sqm_modules: int = 12,
234
+ emb_dim: int = 128,
235
+ rnn_dim: int = 256,
236
+ bidirectional: bool = True,
237
+ tf_dropout: float = 0.0,
238
+ mlp_dim: int = 512,
239
+ hidden_activation: str = "Tanh",
240
+ hidden_activation_kwargs: Optional[Dict] = None,
241
+ complex_mask: bool = True,
242
+ ) -> None:
243
+ super().__init__()
244
+ self.band_split = BandSplitModule(
245
+ in_channel=in_channel,
246
+ band_specs=band_specs,
247
+ require_no_overlap=require_no_overlap,
248
+ require_no_gap=require_no_gap,
249
+ normalize_channel_independently=normalize_channel_independently,
250
+ treat_channel_as_feature=treat_channel_as_feature,
251
+ emb_dim=emb_dim,
252
+ )
253
+ self.tf_model = TransformerTimeFreqModule(
254
+ n_modules=n_sqm_modules,
255
+ emb_dim=emb_dim,
256
+ rnn_dim=rnn_dim,
257
+ bidirectional=bidirectional,
258
+ dropout=tf_dropout,
259
+ )
260
+ self.mask_estim = MaskEstimationModule(
261
+ in_channel=in_channel,
262
+ band_specs=band_specs,
263
+ emb_dim=emb_dim,
264
+ mlp_dim=mlp_dim,
265
+ hidden_activation=hidden_activation,
266
+ hidden_activation_kwargs=hidden_activation_kwargs,
267
+ complex_mask=complex_mask,
268
+ )
269
+
270
+
271
+ class MultiSourceMultiMaskBandSplitCoreRNN(MultiMaskBandSplitCoreBase):
272
+ def __init__(
273
+ self,
274
+ in_channel: int,
275
+ stems: List[str],
276
+ band_specs: List[Tuple[float, float]],
277
+ require_no_overlap: bool = False,
278
+ require_no_gap: bool = True,
279
+ normalize_channel_independently: bool = False,
280
+ treat_channel_as_feature: bool = True,
281
+ n_sqm_modules: int = 12,
282
+ emb_dim: int = 128,
283
+ rnn_dim: int = 256,
284
+ bidirectional: bool = True,
285
+ rnn_type: str = "LSTM",
286
+ mlp_dim: int = 512,
287
+ cond_dim: int = 0,
288
+ hidden_activation: str = "Tanh",
289
+ hidden_activation_kwargs: Optional[Dict] = None,
290
+ complex_mask: bool = True,
291
+ overlapping_band: bool = False,
292
+ freq_weights: Optional[List[torch.Tensor]] = None,
293
+ n_freq: Optional[int] = None,
294
+ use_freq_weights: bool = True,
295
+ mult_add_mask: bool = False,
296
+ ) -> None:
297
+
298
+ super().__init__()
299
+ self.instantiate_bandsplit(
300
+ in_channel=in_channel,
301
+ band_specs=band_specs,
302
+ require_no_overlap=require_no_overlap,
303
+ require_no_gap=require_no_gap,
304
+ normalize_channel_independently=normalize_channel_independently,
305
+ treat_channel_as_feature=treat_channel_as_feature,
306
+ emb_dim=emb_dim,
307
+ )
308
+
309
+ self.tf_model = SeqBandModellingModule(
310
+ n_modules=n_sqm_modules,
311
+ emb_dim=emb_dim,
312
+ rnn_dim=rnn_dim,
313
+ bidirectional=bidirectional,
314
+ rnn_type=rnn_type,
315
+ )
316
+
317
+ self.mult_add_mask = mult_add_mask
318
+
319
+ self.instantiate_mask_estim(
320
+ in_channel=in_channel,
321
+ stems=stems,
322
+ band_specs=band_specs,
323
+ emb_dim=emb_dim,
324
+ mlp_dim=mlp_dim,
325
+ cond_dim=cond_dim,
326
+ hidden_activation=hidden_activation,
327
+ hidden_activation_kwargs=hidden_activation_kwargs,
328
+ complex_mask=complex_mask,
329
+ overlapping_band=overlapping_band,
330
+ freq_weights=freq_weights,
331
+ n_freq=n_freq,
332
+ use_freq_weights=use_freq_weights,
333
+ mult_add_mask=mult_add_mask,
334
+ )
335
+
336
+ @staticmethod
337
+ def _mult_add_mask(x, m):
338
+
339
+ assert m.ndim == 5
340
+
341
+ mm = m[..., 0]
342
+ am = m[..., 1]
343
+
344
+ # print(mm.shape, am.shape, x.shape, m.shape)
345
+
346
+ return x * mm + am
347
+
348
+ def mask(self, x, m):
349
+ if self.mult_add_mask:
350
+
351
+ return self._mult_add_mask(x, m)
352
+ else:
353
+ return super().mask(x, m)
354
+
355
+
356
+ class MultiSourceMultiMaskBandSplitCoreTransformer(
357
+ MultiMaskBandSplitCoreBase,
358
+ ):
359
+ def __init__(
360
+ self,
361
+ in_channel: int,
362
+ stems: List[str],
363
+ band_specs: List[Tuple[float, float]],
364
+ require_no_overlap: bool = False,
365
+ require_no_gap: bool = True,
366
+ normalize_channel_independently: bool = False,
367
+ treat_channel_as_feature: bool = True,
368
+ n_sqm_modules: int = 12,
369
+ emb_dim: int = 128,
370
+ rnn_dim: int = 256,
371
+ bidirectional: bool = True,
372
+ tf_dropout: float = 0.0,
373
+ mlp_dim: int = 512,
374
+ hidden_activation: str = "Tanh",
375
+ hidden_activation_kwargs: Optional[Dict] = None,
376
+ complex_mask: bool = True,
377
+ overlapping_band: bool = False,
378
+ freq_weights: Optional[List[torch.Tensor]] = None,
379
+ n_freq: Optional[int] = None,
380
+ use_freq_weights: bool = True,
381
+ rnn_type: str = "LSTM",
382
+ cond_dim: int = 0,
383
+ mult_add_mask: bool = False,
384
+ ) -> None:
385
+ super().__init__()
386
+ self.instantiate_bandsplit(
387
+ in_channel=in_channel,
388
+ band_specs=band_specs,
389
+ require_no_overlap=require_no_overlap,
390
+ require_no_gap=require_no_gap,
391
+ normalize_channel_independently=normalize_channel_independently,
392
+ treat_channel_as_feature=treat_channel_as_feature,
393
+ emb_dim=emb_dim,
394
+ )
395
+ self.tf_model = TransformerTimeFreqModule(
396
+ n_modules=n_sqm_modules,
397
+ emb_dim=emb_dim,
398
+ rnn_dim=rnn_dim,
399
+ bidirectional=bidirectional,
400
+ dropout=tf_dropout,
401
+ )
402
+
403
+ self.instantiate_mask_estim(
404
+ in_channel=in_channel,
405
+ stems=stems,
406
+ band_specs=band_specs,
407
+ emb_dim=emb_dim,
408
+ mlp_dim=mlp_dim,
409
+ cond_dim=cond_dim,
410
+ hidden_activation=hidden_activation,
411
+ hidden_activation_kwargs=hidden_activation_kwargs,
412
+ complex_mask=complex_mask,
413
+ overlapping_band=overlapping_band,
414
+ freq_weights=freq_weights,
415
+ n_freq=n_freq,
416
+ use_freq_weights=use_freq_weights,
417
+ mult_add_mask=mult_add_mask,
418
+ )
419
+
420
+
421
+ class MultiSourceMultiMaskBandSplitCoreConv(
422
+ MultiMaskBandSplitCoreBase,
423
+ ):
424
+ def __init__(
425
+ self,
426
+ in_channel: int,
427
+ stems: List[str],
428
+ band_specs: List[Tuple[float, float]],
429
+ require_no_overlap: bool = False,
430
+ require_no_gap: bool = True,
431
+ normalize_channel_independently: bool = False,
432
+ treat_channel_as_feature: bool = True,
433
+ n_sqm_modules: int = 12,
434
+ emb_dim: int = 128,
435
+ rnn_dim: int = 256,
436
+ bidirectional: bool = True,
437
+ tf_dropout: float = 0.0,
438
+ mlp_dim: int = 512,
439
+ hidden_activation: str = "Tanh",
440
+ hidden_activation_kwargs: Optional[Dict] = None,
441
+ complex_mask: bool = True,
442
+ overlapping_band: bool = False,
443
+ freq_weights: Optional[List[torch.Tensor]] = None,
444
+ n_freq: Optional[int] = None,
445
+ use_freq_weights: bool = True,
446
+ rnn_type: str = "LSTM",
447
+ cond_dim: int = 0,
448
+ mult_add_mask: bool = False,
449
+ ) -> None:
450
+ super().__init__()
451
+ self.instantiate_bandsplit(
452
+ in_channel=in_channel,
453
+ band_specs=band_specs,
454
+ require_no_overlap=require_no_overlap,
455
+ require_no_gap=require_no_gap,
456
+ normalize_channel_independently=normalize_channel_independently,
457
+ treat_channel_as_feature=treat_channel_as_feature,
458
+ emb_dim=emb_dim,
459
+ )
460
+ self.tf_model = ConvolutionalTimeFreqModule(
461
+ n_modules=n_sqm_modules,
462
+ emb_dim=emb_dim,
463
+ rnn_dim=rnn_dim,
464
+ bidirectional=bidirectional,
465
+ dropout=tf_dropout,
466
+ )
467
+
468
+ self.instantiate_mask_estim(
469
+ in_channel=in_channel,
470
+ stems=stems,
471
+ band_specs=band_specs,
472
+ emb_dim=emb_dim,
473
+ mlp_dim=mlp_dim,
474
+ cond_dim=cond_dim,
475
+ hidden_activation=hidden_activation,
476
+ hidden_activation_kwargs=hidden_activation_kwargs,
477
+ complex_mask=complex_mask,
478
+ overlapping_band=overlapping_band,
479
+ freq_weights=freq_weights,
480
+ n_freq=n_freq,
481
+ use_freq_weights=use_freq_weights,
482
+ mult_add_mask=mult_add_mask,
483
+ )
484
+
485
+
486
+ class PatchingMaskBandsplitCoreBase(MultiMaskBandSplitCoreBase):
487
+ def __init__(self) -> None:
488
+ super().__init__()
489
+
490
+ def mask(self, x, m):
491
+ # x.shape = (batch, n_channel, n_freq, n_time)
492
+ # m.shape = (batch, n_channel, kernel_freq, kernel_time, n_freq, n_time)
+
+         _, n_channel, kernel_freq, kernel_time, n_freq, n_time = m.shape
+         padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2)
+
+         xf = F.unfold(
+             x,
+             kernel_size=(kernel_freq, kernel_time),
+             padding=padding,
+             stride=(1, 1),
+         )
+
+         xf = xf.view(
+             -1,
+             n_channel,
+             kernel_freq,
+             kernel_time,
+             n_freq,
+             n_time,
+         )
+
+         sf = xf * m
+
+         sf = sf.view(
+             -1,
+             n_channel * kernel_freq * kernel_time,
+             n_freq * n_time,
+         )
+
+         s = F.fold(
+             sf,
+             output_size=(n_freq, n_time),
+             kernel_size=(kernel_freq, kernel_time),
+             padding=padding,
+             stride=(1, 1),
+         ).view(
+             -1,
+             n_channel,
+             n_freq,
+             n_time,
+         )
+
+         return s
+
+     def old_mask(self, x, m):
+         # x.shape = (batch, n_channel, n_freq, n_time)
+         # m.shape = (kernel_freq, kernel_time, batch, n_channel, n_freq, n_time)
+
+         s = torch.zeros_like(x)
+
+         _, n_channel, n_freq, n_time = x.shape
+         kernel_freq, kernel_time, _, _, _, _ = m.shape
+
+         # print(x.shape, m.shape)
+
+         kernel_freq_half = (kernel_freq - 1) // 2
+         kernel_time_half = (kernel_time - 1) // 2
+
+         for ifreq in range(kernel_freq):
+             for itime in range(kernel_time):
+                 df, dt = kernel_freq_half - ifreq, kernel_time_half - itime
+                 x = x.roll(shifts=(df, dt), dims=(2, 3))
+
+                 # if df > 0:
+                 #     x[:, :, :df, :] = 0
+                 # elif df < 0:
+                 #     x[:, :, df:, :] = 0
+
+                 # if dt > 0:
+                 #     x[:, :, :, :dt] = 0
+                 # elif dt < 0:
+                 #     x[:, :, :, dt:] = 0
+
+                 fslice = slice(max(0, df), min(n_freq, n_freq + df))
+                 tslice = slice(max(0, dt), min(n_time, n_time + dt))
+
+                 s[:, :, fslice, tslice] += (
+                     x[:, :, fslice, tslice] * m[ifreq, itime, :, :, fslice, tslice]
+                 )
+
+         return s
+
+
+ class MultiSourceMultiPatchingMaskBandSplitCoreRNN(PatchingMaskBandsplitCoreBase):
+     def __init__(
+         self,
+         in_channel: int,
+         stems: List[str],
+         band_specs: List[Tuple[float, float]],
+         mask_kernel_freq: int,
+         mask_kernel_time: int,
+         conv_kernel_freq: int,
+         conv_kernel_time: int,
+         kernel_norm_mlp_version: int,
+         require_no_overlap: bool = False,
+         require_no_gap: bool = True,
+         normalize_channel_independently: bool = False,
+         treat_channel_as_feature: bool = True,
+         n_sqm_modules: int = 12,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         bidirectional: bool = True,
+         rnn_type: str = "LSTM",
+         mlp_dim: int = 512,
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Optional[Dict] = None,
+         complex_mask: bool = True,
+         overlapping_band: bool = False,
+         freq_weights: Optional[List[torch.Tensor]] = None,
+         n_freq: Optional[int] = None,
+     ) -> None:
+
+         super().__init__()
+         self.band_split = BandSplitModule(
+             in_channel=in_channel,
+             band_specs=band_specs,
+             require_no_overlap=require_no_overlap,
+             require_no_gap=require_no_gap,
+             normalize_channel_independently=normalize_channel_independently,
+             treat_channel_as_feature=treat_channel_as_feature,
+             emb_dim=emb_dim,
+         )
+
+         self.tf_model = SeqBandModellingModule(
+             n_modules=n_sqm_modules,
+             emb_dim=emb_dim,
+             rnn_dim=rnn_dim,
+             bidirectional=bidirectional,
+             rnn_type=rnn_type,
+         )
+
+         if hidden_activation_kwargs is None:
+             hidden_activation_kwargs = {}
+
+         if overlapping_band:
+             assert freq_weights is not None
+             assert n_freq is not None
+             self.mask_estim = nn.ModuleDict(
+                 {
+                     stem: PatchingMaskEstimationModule(
+                         band_specs=band_specs,
+                         freq_weights=freq_weights,
+                         n_freq=n_freq,
+                         emb_dim=emb_dim,
+                         mlp_dim=mlp_dim,
+                         in_channel=in_channel,
+                         hidden_activation=hidden_activation,
+                         hidden_activation_kwargs=hidden_activation_kwargs,
+                         complex_mask=complex_mask,
+                         mask_kernel_freq=mask_kernel_freq,
+                         mask_kernel_time=mask_kernel_time,
+                         conv_kernel_freq=conv_kernel_freq,
+                         conv_kernel_time=conv_kernel_time,
+                         kernel_norm_mlp_version=kernel_norm_mlp_version,
+                     )
+                     for stem in stems
+                 }
+             )
+         else:
+             raise NotImplementedError
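
The `mask` method above collapses `old_mask`'s roll-and-accumulate loop into a single unfold/fold pair. A quick standalone sanity check of that logic (a minimal sketch; `patch_mask` is a hypothetical name restating the method body, not part of the upload): a patch mask that is 1 only at the kernel centre must act as the identity.

import torch
import torch.nn.functional as F

def patch_mask(x, m):
    # x: (batch, n_channel, n_freq, n_time)
    # m: (batch, n_channel, kernel_freq, kernel_time, n_freq, n_time)
    _, n_channel, kernel_freq, kernel_time, n_freq, n_time = m.shape
    padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2)
    # extract one (kernel_freq x kernel_time) patch per TF position
    xf = F.unfold(x, kernel_size=(kernel_freq, kernel_time), padding=padding)
    xf = xf.view(-1, n_channel, kernel_freq, kernel_time, n_freq, n_time)
    # weight each patch entry, then overlap-add the patches back into place
    sf = (xf * m).view(-1, n_channel * kernel_freq * kernel_time, n_freq * n_time)
    return F.fold(
        sf,
        output_size=(n_freq, n_time),
        kernel_size=(kernel_freq, kernel_time),
        padding=padding,
    ).view(-1, n_channel, n_freq, n_time)

x = torch.randn(2, 2, 16, 24)
m = torch.zeros(2, 2, 3, 5, 16, 24)
m[:, :, 1, 2] = 1.0  # 1 at the kernel centre, 0 elsewhere
assert torch.allclose(patch_mask(x, m), x, atol=1e-6)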
mvsepless/models/bandit/core/model/bsrnn/maskestim.py ADDED
@@ -0,0 +1,351 @@
+ import warnings
+ from typing import Dict, List, Optional, Tuple, Type
+
+ import torch
+ from torch import nn
+ from torch.nn.modules import activation
+
+ from .utils import (
+     band_widths_from_specs,
+     check_no_gap,
+     check_no_overlap,
+     check_nonzero_bandwidth,
+ )
+
+
+ class BaseNormMLP(nn.Module):
+     def __init__(
+         self,
+         emb_dim: int,
+         mlp_dim: int,
+         bandwidth: int,
+         in_channel: Optional[int],
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs=None,
+         complex_mask: bool = True,
+     ):
+
+         super().__init__()
+         if hidden_activation_kwargs is None:
+             hidden_activation_kwargs = {}
+         self.hidden_activation_kwargs = hidden_activation_kwargs
+         self.norm = nn.LayerNorm(emb_dim)
+         self.hidden = torch.jit.script(
+             nn.Sequential(
+                 nn.Linear(in_features=emb_dim, out_features=mlp_dim),
+                 activation.__dict__[hidden_activation](**self.hidden_activation_kwargs),
+             )
+         )
+
+         self.bandwidth = bandwidth
+         self.in_channel = in_channel
+
+         self.complex_mask = complex_mask
+         self.reim = 2 if complex_mask else 1
+         self.glu_mult = 2
+
+
+ class NormMLP(BaseNormMLP):
+     def __init__(
+         self,
+         emb_dim: int,
+         mlp_dim: int,
+         bandwidth: int,
+         in_channel: Optional[int],
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs=None,
+         complex_mask: bool = True,
+     ) -> None:
+         super().__init__(
+             emb_dim=emb_dim,
+             mlp_dim=mlp_dim,
+             bandwidth=bandwidth,
+             in_channel=in_channel,
+             hidden_activation=hidden_activation,
+             hidden_activation_kwargs=hidden_activation_kwargs,
+             complex_mask=complex_mask,
+         )
+
+         self.output = torch.jit.script(
+             nn.Sequential(
+                 nn.Linear(
+                     in_features=mlp_dim,
+                     out_features=bandwidth * in_channel * self.reim * 2,
+                 ),
+                 nn.GLU(dim=-1),
+             )
+         )
+
+     def reshape_output(self, mb):
+         # print(mb.shape)
+         batch, n_time, _ = mb.shape
+         if self.complex_mask:
+             mb = mb.reshape(
+                 batch, n_time, self.in_channel, self.bandwidth, self.reim
+             ).contiguous()
+             # print(mb.shape)
+             mb = torch.view_as_complex(mb)  # (batch, n_time, in_channel, bandwidth)
+         else:
+             mb = mb.reshape(batch, n_time, self.in_channel, self.bandwidth)
+
+         mb = torch.permute(mb, (0, 2, 3, 1))  # (batch, in_channel, bandwidth, n_time)
+
+         return mb
+
+     def forward(self, qb):
+         # qb = (batch, n_time, emb_dim)
+
+         # if torch.any(torch.isnan(qb)):
+         #     raise ValueError("qb0")
+
+         qb = self.norm(qb)  # (batch, n_time, emb_dim)
+
+         # if torch.any(torch.isnan(qb)):
+         #     raise ValueError("qb1")
+
+         qb = self.hidden(qb)  # (batch, n_time, mlp_dim)
+         # if torch.any(torch.isnan(qb)):
+         #     raise ValueError("qb2")
+         mb = self.output(qb)  # (batch, n_time, bandwidth * in_channel * reim)
+         # if torch.any(torch.isnan(qb)):
+         #     raise ValueError("mb")
+         mb = self.reshape_output(mb)  # (batch, in_channel, bandwidth, n_time)
+
+         return mb
+
+
+ class MultAddNormMLP(NormMLP):
+     def __init__(
+         self,
+         emb_dim: int,
+         mlp_dim: int,
+         bandwidth: int,
+         in_channel: "int | None",
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs=None,
+         complex_mask: bool = True,
+     ) -> None:
+         super().__init__(
+             emb_dim,
+             mlp_dim,
+             bandwidth,
+             in_channel,
+             hidden_activation,
+             hidden_activation_kwargs,
+             complex_mask,
+         )
+
+         self.output2 = torch.jit.script(
+             nn.Sequential(
+                 nn.Linear(
+                     in_features=mlp_dim,
+                     out_features=bandwidth * in_channel * self.reim * 2,
+                 ),
+                 nn.GLU(dim=-1),
+             )
+         )
+
+     def forward(self, qb):
+
+         qb = self.norm(qb)  # (batch, n_time, emb_dim)
+         qb = self.hidden(qb)  # (batch, n_time, mlp_dim)
+         mmb = self.output(qb)  # (batch, n_time, bandwidth * in_channel * reim)
+         mmb = self.reshape_output(mmb)  # (batch, in_channel, bandwidth, n_time)
+         amb = self.output2(qb)  # (batch, n_time, bandwidth * in_channel * reim)
+         amb = self.reshape_output(amb)  # (batch, in_channel, bandwidth, n_time)
+
+         return mmb, amb
+
+
+ class MaskEstimationModuleSuperBase(nn.Module):
+     pass
+
+
+ class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
+     def __init__(
+         self,
+         band_specs: List[Tuple[float, float]],
+         emb_dim: int,
+         mlp_dim: int,
+         in_channel: Optional[int],
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Dict = None,
+         complex_mask: bool = True,
+         norm_mlp_cls: Type[nn.Module] = NormMLP,
+         norm_mlp_kwargs: Dict = None,
+     ) -> None:
+         super().__init__()
+
+         self.band_widths = band_widths_from_specs(band_specs)
+         self.n_bands = len(band_specs)
+
+         if hidden_activation_kwargs is None:
+             hidden_activation_kwargs = {}
+
+         if norm_mlp_kwargs is None:
+             norm_mlp_kwargs = {}
+
+         self.norm_mlp = nn.ModuleList(
+             [
+                 (
+                     norm_mlp_cls(
+                         bandwidth=self.band_widths[b],
+                         emb_dim=emb_dim,
+                         mlp_dim=mlp_dim,
+                         in_channel=in_channel,
+                         hidden_activation=hidden_activation,
+                         hidden_activation_kwargs=hidden_activation_kwargs,
+                         complex_mask=complex_mask,
+                         **norm_mlp_kwargs,
+                     )
+                 )
+                 for b in range(self.n_bands)
+             ]
+         )
+
+     def compute_masks(self, q):
+         batch, n_bands, n_time, emb_dim = q.shape
+
+         masks = []
+
+         for b, nmlp in enumerate(self.norm_mlp):
+             # print(f"maskestim/{b:02d}")
+             qb = q[:, b, :, :]
+             mb = nmlp(qb)
+             masks.append(mb)
+
+         return masks
+
+
+ class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
+     def __init__(
+         self,
+         in_channel: int,
+         band_specs: List[Tuple[float, float]],
+         freq_weights: List[torch.Tensor],
+         n_freq: int,
+         emb_dim: int,
+         mlp_dim: int,
+         cond_dim: int = 0,
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Dict = None,
+         complex_mask: bool = True,
+         norm_mlp_cls: Type[nn.Module] = NormMLP,
+         norm_mlp_kwargs: Dict = None,
+         use_freq_weights: bool = True,
+     ) -> None:
+         check_nonzero_bandwidth(band_specs)
+         check_no_gap(band_specs)
+
+         # if cond_dim > 0:
+         #     raise NotImplementedError
+
+         super().__init__(
+             band_specs=band_specs,
+             emb_dim=emb_dim + cond_dim,
+             mlp_dim=mlp_dim,
+             in_channel=in_channel,
+             hidden_activation=hidden_activation,
+             hidden_activation_kwargs=hidden_activation_kwargs,
+             complex_mask=complex_mask,
+             norm_mlp_cls=norm_mlp_cls,
+             norm_mlp_kwargs=norm_mlp_kwargs,
+         )
+
+         self.n_freq = n_freq
+         self.band_specs = band_specs
+         self.in_channel = in_channel
+
+         if freq_weights is not None:
+             for i, fw in enumerate(freq_weights):
+                 self.register_buffer(f"freq_weights/{i}", fw)
+
+             self.use_freq_weights = use_freq_weights
+         else:
+             self.use_freq_weights = False
+
+         self.cond_dim = cond_dim
+
+     def forward(self, q, cond=None):
+         # q = (batch, n_bands, n_time, emb_dim)
+
+         batch, n_bands, n_time, emb_dim = q.shape
+
+         if cond is not None:
+             if cond.ndim == 2:
+                 cond = cond[:, None, None, :].expand(-1, n_bands, n_time, -1)
+             elif cond.ndim == 3:
+                 assert cond.shape[1] == n_time
+             else:
+                 raise ValueError(f"Invalid cond shape: {cond.shape}")
+
+             q = torch.cat([q, cond], dim=-1)
+         elif self.cond_dim > 0:
+             cond = torch.ones(
+                 (batch, n_bands, n_time, self.cond_dim),
+                 device=q.device,
+                 dtype=q.dtype,
+             )
+             q = torch.cat([q, cond], dim=-1)
+         else:
+             pass
+
+         mask_list = self.compute_masks(
+             q
+         )  # [n_bands * (batch, in_channel, bandwidth, n_time)]
+
+         masks = torch.zeros(
+             (batch, self.in_channel, self.n_freq, n_time),
+             device=q.device,
+             dtype=mask_list[0].dtype,
+         )
+
+         for im, mask in enumerate(mask_list):
+             fstart, fend = self.band_specs[im]
+             if self.use_freq_weights:
+                 fw = self.get_buffer(f"freq_weights/{im}")[:, None]
+                 mask = mask * fw
+             masks[:, :, fstart:fend, :] += mask
+
+         return masks
+
+
+ class MaskEstimationModule(OverlappingMaskEstimationModule):
+     def __init__(
+         self,
+         band_specs: List[Tuple[float, float]],
+         emb_dim: int,
+         mlp_dim: int,
+         in_channel: Optional[int],
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Dict = None,
+         complex_mask: bool = True,
+         **kwargs,
+     ) -> None:
+         check_nonzero_bandwidth(band_specs)
+         check_no_gap(band_specs)
+         check_no_overlap(band_specs)
+         super().__init__(
+             in_channel=in_channel,
+             band_specs=band_specs,
+             freq_weights=None,
+             n_freq=None,
+             emb_dim=emb_dim,
+             mlp_dim=mlp_dim,
+             hidden_activation=hidden_activation,
+             hidden_activation_kwargs=hidden_activation_kwargs,
+             complex_mask=complex_mask,
+         )
+
+     def forward(self, q, cond=None):
+         # q = (batch, n_bands, n_time, emb_dim)
+
+         masks = self.compute_masks(
+             q
+         )  # [n_bands * (batch, in_channel, bandwidth, n_time)]
+
+         # TODO: currently this requires band specs to have no gap and no overlap
+         masks = torch.concat(masks, dim=2)  # (batch, in_channel, n_freq, n_time)
+
+         return masks
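
A minimal shape check for the mask estimator above (the import path is assumed from the upload layout): with complex_mask=True, the GLU output is reshaped into a complex-valued mask of shape (batch, in_channel, bandwidth, n_time).

import torch
from mvsepless.models.bandit.core.model.bsrnn.maskestim import NormMLP

mlp = NormMLP(emb_dim=128, mlp_dim=512, bandwidth=10, in_channel=2)
qb = torch.randn(4, 50, 128)  # (batch, n_time, emb_dim)
mb = mlp(qb)
print(mb.shape, mb.dtype)  # torch.Size([4, 2, 10, 50]) torch.complex64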
mvsepless/models/bandit/core/model/bsrnn/tfmodel.py ADDED
@@ -0,0 +1,320 @@
+ import warnings
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torch.nn.modules import rnn
+
+ import torch.backends.cuda
+
+
+ class TimeFrequencyModellingModule(nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+
+
+ class ResidualRNN(nn.Module):
+     def __init__(
+         self,
+         emb_dim: int,
+         rnn_dim: int,
+         bidirectional: bool = True,
+         rnn_type: str = "LSTM",
+         use_batch_trick: bool = True,
+         use_layer_norm: bool = True,
+     ) -> None:
+         # n_group is the size of the 2nd dim
+         super().__init__()
+
+         self.use_layer_norm = use_layer_norm
+         if use_layer_norm:
+             self.norm = nn.LayerNorm(emb_dim)
+         else:
+             self.norm = nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim)
+
+         self.rnn = rnn.__dict__[rnn_type](
+             input_size=emb_dim,
+             hidden_size=rnn_dim,
+             num_layers=1,
+             batch_first=True,
+             bidirectional=bidirectional,
+         )
+
+         self.fc = nn.Linear(
+             in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim
+         )
+
+         self.use_batch_trick = use_batch_trick
+         if not self.use_batch_trick:
+             warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")
+
+     def forward(self, z):
+         # z = (batch, n_uncrossed, n_across, emb_dim)
+
+         z0 = torch.clone(z)
+
+         # print(z.device)
+
+         if self.use_layer_norm:
+             z = self.norm(z)  # (batch, n_uncrossed, n_across, emb_dim)
+         else:
+             z = torch.permute(
+                 z, (0, 3, 1, 2)
+             )  # (batch, emb_dim, n_uncrossed, n_across)
+
+             z = self.norm(z)  # (batch, emb_dim, n_uncrossed, n_across)
+
+             z = torch.permute(
+                 z, (0, 2, 3, 1)
+             )  # (batch, n_uncrossed, n_across, emb_dim)
+
+         batch, n_uncrossed, n_across, emb_dim = z.shape
+
+         if self.use_batch_trick:
+             z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
+
+             z = self.rnn(z.contiguous())[
+                 0
+             ]  # (batch * n_uncrossed, n_across, dir_rnn_dim)
+
+             z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
+             # (batch, n_uncrossed, n_across, dir_rnn_dim)
+         else:
+             # Note: this is EXTREMELY SLOW
+             zlist = []
+             for i in range(n_uncrossed):
+                 zi = self.rnn(z[:, i, :, :])[0]  # (batch, n_across, emb_dim)
+                 zlist.append(zi)
+
+             z = torch.stack(zlist, dim=1)  # (batch, n_uncrossed, n_across, dir_rnn_dim)
+
+         z = self.fc(z)  # (batch, n_uncrossed, n_across, emb_dim)
+
+         z = z + z0
+
+         return z
+
+
+ class SeqBandModellingModule(TimeFrequencyModellingModule):
+     def __init__(
+         self,
+         n_modules: int = 12,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         bidirectional: bool = True,
+         rnn_type: str = "LSTM",
+         parallel_mode=False,
+     ) -> None:
+         super().__init__()
+         self.seqband = nn.ModuleList([])
+
+         if parallel_mode:
+             for _ in range(n_modules):
+                 self.seqband.append(
+                     nn.ModuleList(
+                         [
+                             ResidualRNN(
+                                 emb_dim=emb_dim,
+                                 rnn_dim=rnn_dim,
+                                 bidirectional=bidirectional,
+                                 rnn_type=rnn_type,
+                             ),
+                             ResidualRNN(
+                                 emb_dim=emb_dim,
+                                 rnn_dim=rnn_dim,
+                                 bidirectional=bidirectional,
+                                 rnn_type=rnn_type,
+                             ),
+                         ]
+                     )
+                 )
+         else:
+
+             for _ in range(2 * n_modules):
+                 self.seqband.append(
+                     ResidualRNN(
+                         emb_dim=emb_dim,
+                         rnn_dim=rnn_dim,
+                         bidirectional=bidirectional,
+                         rnn_type=rnn_type,
+                     )
+                 )
+
+         self.parallel_mode = parallel_mode
+
+     def forward(self, z):
+         # z = (batch, n_bands, n_time, emb_dim)
+
+         if self.parallel_mode:
+             for sbm_pair in self.seqband:
+                 # z: (batch, n_bands, n_time, emb_dim)
+                 sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
+                 zt = sbm_t(z)  # (batch, n_bands, n_time, emb_dim)
+                 zf = sbm_f(z.transpose(1, 2))  # (batch, n_time, n_bands, emb_dim)
+                 z = zt + zf.transpose(1, 2)
+         else:
+             for sbm in self.seqband:
+                 z = sbm(z)
+                 z = z.transpose(1, 2)
+
+                 # (batch, n_bands, n_time, emb_dim)
+                 # --> (batch, n_time, n_bands, emb_dim)
+                 # OR
+                 # (batch, n_time, n_bands, emb_dim)
+                 # --> (batch, n_bands, n_time, emb_dim)
+
+         q = z
+         return q  # (batch, n_bands, n_time, emb_dim)
+
+
+ class ResidualTransformer(nn.Module):
+     def __init__(
+         self,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         bidirectional: bool = True,
+         dropout: float = 0.0,
+     ) -> None:
+         # n_group is the size of the 2nd dim
+         super().__init__()
+
+         self.tf = nn.TransformerEncoderLayer(
+             d_model=emb_dim, nhead=4, dim_feedforward=rnn_dim, batch_first=True
+         )
+
+         self.is_causal = not bidirectional
+         self.dropout = dropout
+
+     def forward(self, z):
+         batch, n_uncrossed, n_across, emb_dim = z.shape
+         z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
+         z = self.tf(
+             z, is_causal=self.is_causal
+         )  # (batch, n_uncrossed, n_across, emb_dim)
+         z = torch.reshape(z, (batch, n_uncrossed, n_across, emb_dim))
+
+         return z
+
+
+ class TransformerTimeFreqModule(TimeFrequencyModellingModule):
+     def __init__(
+         self,
+         n_modules: int = 12,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         bidirectional: bool = True,
+         dropout: float = 0.0,
+     ) -> None:
+         super().__init__()
+         self.norm = nn.LayerNorm(emb_dim)
+         self.seqband = nn.ModuleList([])
+
+         for _ in range(2 * n_modules):
+             self.seqband.append(
+                 ResidualTransformer(
+                     emb_dim=emb_dim,
+                     rnn_dim=rnn_dim,
+                     bidirectional=bidirectional,
+                     dropout=dropout,
+                 )
+             )
+
+     def forward(self, z):
+         # z = (batch, n_bands, n_time, emb_dim)
+         z = self.norm(z)  # (batch, n_bands, n_time, emb_dim)
+
+         for sbm in self.seqband:
+             z = sbm(z)
+             z = z.transpose(1, 2)
+
+             # (batch, n_bands, n_time, emb_dim)
+             # --> (batch, n_time, n_bands, emb_dim)
+             # OR
+             # (batch, n_time, n_bands, emb_dim)
+             # --> (batch, n_bands, n_time, emb_dim)
+
+         q = z
+         return q  # (batch, n_bands, n_time, emb_dim)
+
+
+ class ResidualConvolution(nn.Module):
+     def __init__(
+         self,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         bidirectional: bool = True,
+         dropout: float = 0.0,
+     ) -> None:
+         # n_group is the size of the 2nd dim
+         super().__init__()
+         self.norm = nn.InstanceNorm2d(emb_dim, affine=True)
+
+         self.conv = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=emb_dim,
+                 out_channels=rnn_dim,
+                 kernel_size=(3, 3),
+                 padding="same",
+                 stride=(1, 1),
+             ),
+             nn.Tanhshrink(),
+         )
+
+         self.is_causal = not bidirectional
+         self.dropout = dropout
+
+         self.fc = nn.Conv2d(
+             in_channels=rnn_dim,
+             out_channels=emb_dim,
+             kernel_size=(1, 1),
+             padding="same",
+             stride=(1, 1),
+         )
+
+     def forward(self, z):
+         # z = (batch, emb_dim, n_bands, n_time); this module operates
+         # channel-first (ConvolutionalTimeFreqModule permutes around it)
+
+         z0 = torch.clone(z)
+
+         z = self.norm(z)  # (batch, emb_dim, n_bands, n_time)
+         z = self.conv(z)  # (batch, rnn_dim, n_bands, n_time)
+         z = self.fc(z)  # (batch, emb_dim, n_bands, n_time)
+         z = z + z0
+
+         return z
+
+
+ class ConvolutionalTimeFreqModule(TimeFrequencyModellingModule):
+     def __init__(
+         self,
+         n_modules: int = 12,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         bidirectional: bool = True,
+         dropout: float = 0.0,
+     ) -> None:
+         super().__init__()
+         self.seqband = torch.jit.script(
+             nn.Sequential(
+                 *[
+                     ResidualConvolution(
+                         emb_dim=emb_dim,
+                         rnn_dim=rnn_dim,
+                         bidirectional=bidirectional,
+                         dropout=dropout,
+                     )
+                     for _ in range(2 * n_modules)
+                 ]
+             )
+         )
+
+     def forward(self, z):
+         # z = (batch, n_bands, n_time, emb_dim)
+
+         z = torch.permute(z, (0, 3, 1, 2))  # (batch, emb_dim, n_bands, n_time)
+
+         z = self.seqband(z)  # (batch, emb_dim, n_bands, n_time)
+
+         z = torch.permute(z, (0, 2, 3, 1))  # (batch, n_bands, n_time, emb_dim)
+
+         return z
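
A shape sanity check for the sequential band-modelling stack above (import path assumed from the upload layout): each of the 2 * n_modules ResidualRNN passes preserves the tensor shape and is followed by a transpose of dims 1 and 2, so the even number of transposes returns the original (batch, n_bands, n_time, emb_dim) layout.

import torch
from mvsepless.models.bandit.core.model.bsrnn.tfmodel import SeqBandModellingModule

tf_model = SeqBandModellingModule(n_modules=2, emb_dim=32, rnn_dim=64)
z = torch.randn(2, 8, 100, 32)  # (batch, n_bands, n_time, emb_dim)
q = tf_model(z)
print(q.shape)  # torch.Size([2, 8, 100, 32])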
mvsepless/models/bandit/core/model/bsrnn/utils.py ADDED
@@ -0,0 +1,525 @@
+ import os
+ from abc import abstractmethod
+ from typing import Any, Callable
+
+ import numpy as np
+ import torch
+ from librosa import hz_to_midi, midi_to_hz
+ from torch import Tensor
+ from torchaudio import functional as taF
+ from spafe.fbanks import bark_fbanks
+ from spafe.utils.converters import erb2hz, hz2bark, hz2erb
+ from torchaudio.functional.functional import _create_triangular_filterbank
+
+
+ def band_widths_from_specs(band_specs):
+     return [e - i for i, e in band_specs]
+
+
+ def check_nonzero_bandwidth(band_specs):
+     # pprint(band_specs)
+     for fstart, fend in band_specs:
+         if fend - fstart <= 0:
+             raise ValueError("Bands cannot be zero-width")
+
+
+ def check_no_overlap(band_specs):
+     fend_prev = -1
+     for fstart_curr, fend_curr in band_specs:
+         if fstart_curr <= fend_prev:
+             raise ValueError("Bands cannot overlap")
+         fend_prev = fend_curr  # advance, otherwise only the first band is checked
+
+
+ def check_no_gap(band_specs):
+     fstart, _ = band_specs[0]
+     assert fstart == 0
+
+     fend_prev = -1
+     for fstart_curr, fend_curr in band_specs:
+         if fstart_curr - fend_prev > 1:
+             raise ValueError("Bands cannot leave gap")
+         fend_prev = fend_curr
+
+
+ class BandsplitSpecification:
+     def __init__(self, nfft: int, fs: int) -> None:
+         self.fs = fs
+         self.nfft = nfft
+         self.nyquist = fs / 2
+         self.max_index = nfft // 2 + 1
+
+         self.split500 = self.hertz_to_index(500)
+         self.split1k = self.hertz_to_index(1000)
+         self.split2k = self.hertz_to_index(2000)
+         self.split4k = self.hertz_to_index(4000)
+         self.split8k = self.hertz_to_index(8000)
+         self.split16k = self.hertz_to_index(16000)
+         self.split20k = self.hertz_to_index(20000)
+
+         self.above20k = [(self.split20k, self.max_index)]
+         self.above16k = [(self.split16k, self.split20k)] + self.above20k
+
+     def index_to_hertz(self, index: int):
+         return index * self.fs / self.nfft
+
+     def hertz_to_index(self, hz: float, round: bool = True):
+         index = hz * self.nfft / self.fs
+
+         if round:
+             index = int(np.round(index))
+
+         return index
+
+     def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
+         band_specs = []
+         lower = start_index
+
+         while lower < end_index:
+             upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
+             upper = min(upper, end_index)
+
+             band_specs.append((lower, upper))
+             lower = upper
+
+         return band_specs
+
+     @abstractmethod
+     def get_band_specs(self):
+         raise NotImplementedError
+
+
+ class VocalBandsplitSpecification(BandsplitSpecification):
+     def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
+         super().__init__(nfft=nfft, fs=fs)
+
+         self.version = version
+
+     def get_band_specs(self):
+         return getattr(self, f"version{self.version}")()
+
+     def version1(self):
+         # not a @property: get_band_specs() invokes this via getattr(...)(),
+         # so it must be callable like version2..version7
+         return self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.max_index, bandwidth_hz=1000
+         )
+
+     def version2(self):
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split16k, bandwidth_hz=1000
+         )
+         below20k = self.get_band_specs_with_bandwidth(
+             start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+         )
+
+         return below16k + below20k + self.above20k
+
+     def version3(self):
+         below8k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split8k, bandwidth_hz=1000
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+         )
+
+         return below8k + below16k + self.above16k
+
+     def version4(self):
+         below1k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split1k, bandwidth_hz=100
+         )
+         below8k = self.get_band_specs_with_bandwidth(
+             start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+         )
+
+         return below1k + below8k + below16k + self.above16k
+
+     def version5(self):
+         below1k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split1k, bandwidth_hz=100
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
+         )
+         below20k = self.get_band_specs_with_bandwidth(
+             start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+         )
+         return below1k + below16k + below20k + self.above20k
+
+     def version6(self):
+         below1k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split1k, bandwidth_hz=100
+         )
+         below4k = self.get_band_specs_with_bandwidth(
+             start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
+         )
+         below8k = self.get_band_specs_with_bandwidth(
+             start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+         )
+         return below1k + below4k + below8k + below16k + self.above16k
+
+     def version7(self):
+         below1k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split1k, bandwidth_hz=100
+         )
+         below4k = self.get_band_specs_with_bandwidth(
+             start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
+         )
+         below8k = self.get_band_specs_with_bandwidth(
+             start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
+         )
+         below20k = self.get_band_specs_with_bandwidth(
+             start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+         )
+         return below1k + below4k + below8k + below16k + below20k + self.above20k
+
+
+ class OtherBandsplitSpecification(VocalBandsplitSpecification):
+     def __init__(self, nfft: int, fs: int) -> None:
+         super().__init__(nfft=nfft, fs=fs, version="7")
+
+
+ class BassBandsplitSpecification(BandsplitSpecification):
+     def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
+         super().__init__(nfft=nfft, fs=fs)
+
+     def get_band_specs(self):
+         below500 = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split500, bandwidth_hz=50
+         )
+         below1k = self.get_band_specs_with_bandwidth(
+             start_index=self.split500, end_index=self.split1k, bandwidth_hz=100
+         )
+         below4k = self.get_band_specs_with_bandwidth(
+             start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
+         )
+         below8k = self.get_band_specs_with_bandwidth(
+             start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+         )
+         above16k = [(self.split16k, self.max_index)]
+
+         return below500 + below1k + below4k + below8k + below16k + above16k
+
+
+ class DrumBandsplitSpecification(BandsplitSpecification):
+     def __init__(self, nfft: int, fs: int) -> None:
+         super().__init__(nfft=nfft, fs=fs)
+
+     def get_band_specs(self):
+         below1k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split1k, bandwidth_hz=50
+         )
+         below2k = self.get_band_specs_with_bandwidth(
+             start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100
+         )
+         below4k = self.get_band_specs_with_bandwidth(
+             start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250
+         )
+         below8k = self.get_band_specs_with_bandwidth(
+             start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
+         )
+         above16k = [(self.split16k, self.max_index)]
+
+         return below1k + below2k + below4k + below8k + below16k + above16k
+
+
+ class PerceptualBandsplitSpecification(BandsplitSpecification):
+     def __init__(
+         self,
+         nfft: int,
+         fs: int,
+         fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
+         n_bands: int,
+         f_min: float = 0.0,
+         f_max: float = None,
+     ) -> None:
+         super().__init__(nfft=nfft, fs=fs)
+         self.n_bands = n_bands
+         if f_max is None:
+             f_max = fs / 2
+
+         self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)
+
+         weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True)  # (1, n_freqs)
+         normalized_mel_fb = self.filterbank / weight_per_bin  # (n_mels, n_freqs)
+
+         freq_weights = []
+         band_specs = []
+         for i in range(self.n_bands):
+             active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
+             if isinstance(active_bins, int):
+                 active_bins = (active_bins, active_bins)
+             if len(active_bins) == 0:
+                 continue
+             start_index = active_bins[0]
+             end_index = active_bins[-1] + 1
+             band_specs.append((start_index, end_index))
+             freq_weights.append(normalized_mel_fb[i, start_index:end_index])
+
+         self.freq_weights = freq_weights
+         self.band_specs = band_specs
+
+     def get_band_specs(self):
+         return self.band_specs
+
+     def get_freq_weights(self):
+         return self.freq_weights
+
+     def save_to_file(self, dir_path: str) -> None:
+
+         os.makedirs(dir_path, exist_ok=True)
+
+         import pickle
+
+         with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
+             pickle.dump(
+                 {
+                     "band_specs": self.band_specs,
+                     "freq_weights": self.freq_weights,
+                     "filterbank": self.filterbank,
+                 },
+                 f,
+             )
+
+
+ def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
+     fb = taF.melscale_fbanks(
+         n_mels=n_bands,
+         sample_rate=fs,
+         f_min=f_min,
+         f_max=f_max,
+         n_freqs=n_freqs,
+     ).T
+
+     fb[0, 0] = 1.0
+
+     return fb
+
+
+ class MelBandsplitSpecification(PerceptualBandsplitSpecification):
+     def __init__(
+         self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+     ) -> None:
+         super().__init__(
+             fbank_fn=mel_filterbank,
+             nfft=nfft,
+             fs=fs,
+             n_bands=n_bands,
+             f_min=f_min,
+             f_max=f_max,
+         )
+
+
+ def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
+
+     nfft = 2 * (n_freqs - 1)
+     df = fs / nfft
+     # init freqs
+     f_max = f_max or fs / 2
+     f_min = f_min or 0
+     f_min = fs / nfft
+
+     n_octaves = np.log2(f_max / f_min)
+     n_octaves_per_band = n_octaves / n_bands
+     bandwidth_mult = np.power(2.0, n_octaves_per_band)
+
+     low_midi = max(0, hz_to_midi(f_min))
+     high_midi = hz_to_midi(f_max)
+     midi_points = np.linspace(low_midi, high_midi, n_bands)
+     hz_pts = midi_to_hz(midi_points)
+
+     low_pts = hz_pts / bandwidth_mult
+     high_pts = hz_pts * bandwidth_mult
+
+     low_bins = np.floor(low_pts / df).astype(int)
+     high_bins = np.ceil(high_pts / df).astype(int)
+
+     fb = np.zeros((n_bands, n_freqs))
+
+     for i in range(n_bands):
+         fb[i, low_bins[i] : high_bins[i] + 1] = 1.0
+
+     fb[0, : low_bins[0]] = 1.0
+     fb[-1, high_bins[-1] + 1 :] = 1.0
+
+     return torch.as_tensor(fb)
+
+
+ class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
+     def __init__(
+         self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+     ) -> None:
+         super().__init__(
+             fbank_fn=musical_filterbank,
+             nfft=nfft,
+             fs=fs,
+             n_bands=n_bands,
+             f_min=f_min,
+             f_max=f_max,
+         )
+
+
+ def bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
+     nfft = 2 * (n_freqs - 1)
+     fb, _ = bark_fbanks.bark_filter_banks(
+         nfilts=n_bands,
+         nfft=nfft,
+         fs=fs,
+         low_freq=f_min,
+         high_freq=f_max,
+         scale="constant",
+     )
+
+     return torch.as_tensor(fb)
+
+
+ class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
+     def __init__(
+         self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+     ) -> None:
+         super().__init__(
+             fbank_fn=bark_filterbank,
+             nfft=nfft,
+             fs=fs,
+             n_bands=n_bands,
+             f_min=f_min,
+             f_max=f_max,
+         )
+
+
+ def triangular_bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
+
+     all_freqs = torch.linspace(0, fs // 2, n_freqs)
+
+     # calculate mel freq bins
+     m_min = hz2bark(f_min)
+     m_max = hz2bark(f_max)
+
+     m_pts = torch.linspace(m_min, m_max, n_bands + 2)
+     f_pts = 600 * torch.sinh(m_pts / 6)
+
+     # create filterbank
+     fb = _create_triangular_filterbank(all_freqs, f_pts)
+
+     fb = fb.T
+
+     first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
+     first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
+
+     fb[first_active_band, :first_active_bin] = 1.0
+
+     return fb
+
+
+ class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
+     def __init__(
+         self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+     ) -> None:
+         super().__init__(
+             fbank_fn=triangular_bark_filterbank,
+             nfft=nfft,
+             fs=fs,
+             n_bands=n_bands,
+             f_min=f_min,
+             f_max=f_max,
+         )
+
+
+ def minibark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
+     fb = bark_filterbank(n_bands, fs, f_min, f_max, n_freqs)
+
+     fb[fb < np.sqrt(0.5)] = 0.0
+
+     return fb
+
+
+ class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
+     def __init__(
+         self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+     ) -> None:
+         super().__init__(
+             fbank_fn=minibark_filterbank,
+             nfft=nfft,
+             fs=fs,
+             n_bands=n_bands,
+             f_min=f_min,
+             f_max=f_max,
+         )
+
+
+ def erb_filterbank(
+     n_bands: int,
+     fs: int,
+     f_min: float,
+     f_max: float,
+     n_freqs: int,
+ ) -> Tensor:
+     # freq bins
+     A = (1000 * np.log(10)) / (24.7 * 4.37)
+     all_freqs = torch.linspace(0, fs // 2, n_freqs)
+
+     # calculate mel freq bins
+     m_min = hz2erb(f_min)
+     m_max = hz2erb(f_max)
+
+     m_pts = torch.linspace(m_min, m_max, n_bands + 2)
+     f_pts = (torch.pow(10, (m_pts / A)) - 1) / 0.00437
+
+     # create filterbank
+     fb = _create_triangular_filterbank(all_freqs, f_pts)
+
+     fb = fb.T
+
+     first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
+     first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
+
+     fb[first_active_band, :first_active_bin] = 1.0
+
+     return fb
+
+
+ class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
+     def __init__(
+         self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+     ) -> None:
+         super().__init__(
+             fbank_fn=erb_filterbank,
+             nfft=nfft,
+             fs=fs,
+             n_bands=n_bands,
+             f_min=f_min,
+             f_max=f_max,
+         )
+
+
+ if __name__ == "__main__":
+     import pandas as pd
+
+     band_defs = []
+
+     for bands in [VocalBandsplitSpecification]:
+         band_name = bands.__name__.replace("BandsplitSpecification", "")
+
+         mbs = bands(nfft=2048, fs=44100).get_band_specs()
+
+         for i, (f_min, f_max) in enumerate(mbs):
+             band_defs.append(
+                 {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
+             )
+
+     df = pd.DataFrame(band_defs)
+     df.to_csv("vox7bands.csv", index=False)
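
A quick look at the resulting band layout (a minimal sketch; the import path is assumed from the upload layout): with nfft=2048 and fs=44100 each STFT bin spans roughly 21.5 Hz, so the 100 Hz subbands below 1 kHz come out about 5 bins wide.

from mvsepless.models.bandit.core.model.bsrnn.utils import VocalBandsplitSpecification

spec = VocalBandsplitSpecification(nfft=2048, fs=44100, version="7")
bands = spec.get_band_specs()
print(len(bands), bands[:3])  # e.g. (0, 5), (5, 10), (10, 15) for the first bands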
mvsepless/models/bandit/core/model/bsrnn/wrapper.py ADDED
@@ -0,0 +1,829 @@
+ from pprint import pprint
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import torch
+ from torch import nn
+
+ from .._spectral import _SpectralComponent
+ from .utils import (
+     BarkBandsplitSpecification,
+     BassBandsplitSpecification,
+     DrumBandsplitSpecification,
+     EquivalentRectangularBandsplitSpecification,
+     MelBandsplitSpecification,
+     MusicalBandsplitSpecification,
+     OtherBandsplitSpecification,
+     TriangularBarkBandsplitSpecification,
+     VocalBandsplitSpecification,
+ )
+ from .core import (
+     MultiSourceMultiMaskBandSplitCoreConv,
+     MultiSourceMultiMaskBandSplitCoreRNN,
+     MultiSourceMultiMaskBandSplitCoreTransformer,
+     MultiSourceMultiPatchingMaskBandSplitCoreRNN,
+     SingleMaskBandsplitCoreRNN,
+     SingleMaskBandsplitCoreTransformer,
+ )
+
+ import pytorch_lightning as pl
+
+
+ def get_band_specs(band_specs, n_fft, fs, n_bands=None):
+     if band_specs in ["dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"]:
+         bsm = VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs()
+         freq_weights = None
+         overlapping_band = False
+     elif "tribark" in band_specs:
+         assert n_bands is not None
+         specs = TriangularBarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
+         bsm = specs.get_band_specs()
+         freq_weights = specs.get_freq_weights()
+         overlapping_band = True
+     elif "bark" in band_specs:
+         assert n_bands is not None
+         specs = BarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
+         bsm = specs.get_band_specs()
+         freq_weights = specs.get_freq_weights()
+         overlapping_band = True
+     elif "erb" in band_specs:
+         assert n_bands is not None
+         specs = EquivalentRectangularBandsplitSpecification(
+             nfft=n_fft, fs=fs, n_bands=n_bands
+         )
+         bsm = specs.get_band_specs()
+         freq_weights = specs.get_freq_weights()
+         overlapping_band = True
+     elif "musical" in band_specs:
+         assert n_bands is not None
+         specs = MusicalBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
+         bsm = specs.get_band_specs()
+         freq_weights = specs.get_freq_weights()
+         overlapping_band = True
+     elif band_specs == "dnr:mel" or "mel" in band_specs:
+         assert n_bands is not None
+         specs = MelBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
+         bsm = specs.get_band_specs()
+         freq_weights = specs.get_freq_weights()
+         overlapping_band = True
+     else:
+         raise NameError
+
+     return bsm, freq_weights, overlapping_band
+
+
+ def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None):
+     if band_specs_map == "musdb:all":
+         bsm = {
+             "vocals": VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
+             "drums": DrumBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
+             "bass": BassBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
+             "other": OtherBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
+         }
+         freq_weights = None
+         overlapping_band = False
+     elif band_specs_map == "dnr:vox7":
+         bsm_, freq_weights, overlapping_band = get_band_specs(
+             "dnr:speech", n_fft, fs, n_bands
+         )
+         bsm = {"speech": bsm_, "music": bsm_, "effects": bsm_}
+     elif "dnr:vox7:" in band_specs_map:
+         stem = band_specs_map.split(":")[-1]
+         bsm_, freq_weights, overlapping_band = get_band_specs(
+             "dnr:speech", n_fft, fs, n_bands
+         )
+         bsm = {stem: bsm_}
+     else:
+         raise NameError
+
+     return bsm, freq_weights, overlapping_band
+
+
+ class BandSplitWrapperBase(pl.LightningModule):
+     bsrnn: nn.Module
+
+     def __init__(self, **kwargs):
+         super().__init__()
+
+
+ class SingleMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
+     def __init__(
+         self,
+         band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
+         fs: int = 44100,
+         n_fft: int = 2048,
+         win_length: Optional[int] = 2048,
+         hop_length: int = 512,
+         window_fn: str = "hann_window",
+         wkwargs: Optional[Dict] = None,
+         power: Optional[int] = None,
+         center: bool = True,
+         normalized: bool = True,
+         pad_mode: str = "constant",
+         onesided: bool = True,
+         n_bands: int = None,
+     ) -> None:
+         super().__init__(
+             n_fft=n_fft,
+             win_length=win_length,
+             hop_length=hop_length,
+             window_fn=window_fn,
+             wkwargs=wkwargs,
+             power=power,
+             center=center,
+             normalized=normalized,
+             pad_mode=pad_mode,
+             onesided=onesided,
+         )
+
+         if isinstance(band_specs_map, str):
+             self.band_specs_map, self.freq_weights, self.overlapping_band = (
+                 get_band_specs_map(band_specs_map, n_fft, fs, n_bands=n_bands)
+             )
+
+         self.stems = list(self.band_specs_map.keys())
+
+     def forward(self, batch):
+         audio = batch["audio"]
+
+         with torch.no_grad():
+             batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}
+
+         X = batch["spectrogram"]["mixture"]
+         length = batch["audio"]["mixture"].shape[-1]
+
+         output = {"spectrogram": {}, "audio": {}}
+
+         for stem, bsrnn in self.bsrnn.items():
+             S = bsrnn(X)
+             s = self.istft(S, length)
+             output["spectrogram"][stem] = S
+             output["audio"][stem] = s
+
+         return batch, output
+
+
+ class MultiMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
+     def __init__(
+         self,
+         stems: List[str],
+         band_specs: Union[str, List[Tuple[float, float]]],
+         fs: int = 44100,
+         n_fft: int = 2048,
+         win_length: Optional[int] = 2048,
+         hop_length: int = 512,
+         window_fn: str = "hann_window",
+         wkwargs: Optional[Dict] = None,
+         power: Optional[int] = None,
+         center: bool = True,
+         normalized: bool = True,
+         pad_mode: str = "constant",
+         onesided: bool = True,
+         n_bands: int = None,
+     ) -> None:
+         super().__init__(
+             n_fft=n_fft,
+             win_length=win_length,
+             hop_length=hop_length,
+             window_fn=window_fn,
+             wkwargs=wkwargs,
+             power=power,
+             center=center,
+             normalized=normalized,
+             pad_mode=pad_mode,
+             onesided=onesided,
+         )
+
+         if isinstance(band_specs, str):
+             self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
+                 band_specs, n_fft, fs, n_bands
+             )
+
+         self.stems = stems
+
+     def forward(self, batch):
+         # with torch.no_grad():
+         audio = batch["audio"]
+         cond = batch.get("condition", None)
+         with torch.no_grad():
+             batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}
+
+         X = batch["spectrogram"]["mixture"]
+         length = batch["audio"]["mixture"].shape[-1]
+
+         output = self.bsrnn(X, cond=cond)
+         output["audio"] = {}
+
+         for stem, S in output["spectrogram"].items():
+             s = self.istft(S, length)
+             output["audio"][stem] = s
+
+         return batch, output
+
+
+ class MultiMaskMultiSourceBandSplitBaseSimple(BandSplitWrapperBase, _SpectralComponent):
+     def __init__(
+         self,
+         stems: List[str],
+         band_specs: Union[str, List[Tuple[float, float]]],
+         fs: int = 44100,
+         n_fft: int = 2048,
+         win_length: Optional[int] = 2048,
+         hop_length: int = 512,
+         window_fn: str = "hann_window",
+         wkwargs: Optional[Dict] = None,
+         power: Optional[int] = None,
+         center: bool = True,
+         normalized: bool = True,
+         pad_mode: str = "constant",
+         onesided: bool = True,
+         n_bands: int = None,
+     ) -> None:
+         super().__init__(
+             n_fft=n_fft,
+             win_length=win_length,
+             hop_length=hop_length,
+             window_fn=window_fn,
+             wkwargs=wkwargs,
+             power=power,
+             center=center,
+             normalized=normalized,
+             pad_mode=pad_mode,
+             onesided=onesided,
+         )
+
+         if isinstance(band_specs, str):
+             self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
+                 band_specs, n_fft, fs, n_bands
+             )
+
+         self.stems = stems
+
+     def forward(self, batch):
+         with torch.no_grad():
+             X = self.stft(batch)
+             length = batch.shape[-1]
+         output = self.bsrnn(X, cond=None)
+         res = []
+         for stem, S in output["spectrogram"].items():
+             s = self.istft(S, length)
+             res.append(s)
+         res = torch.stack(res, dim=1)
+         return res
+
+
+ class SingleMaskMultiSourceBandSplitRNN(SingleMaskMultiSourceBandSplitBase):
+     def __init__(
+         self,
+         in_channel: int,
+         band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
+         fs: int = 44100,
+         require_no_overlap: bool = False,
+         require_no_gap: bool = True,
+         normalize_channel_independently: bool = False,
+         treat_channel_as_feature: bool = True,
+         n_sqm_modules: int = 12,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         bidirectional: bool = True,
+         rnn_type: str = "LSTM",
+         mlp_dim: int = 512,
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Optional[Dict] = None,
+         complex_mask: bool = True,
+         n_fft: int = 2048,
+         win_length: Optional[int] = 2048,
+         hop_length: int = 512,
+         window_fn: str = "hann_window",
+         wkwargs: Optional[Dict] = None,
+         power: Optional[int] = None,
+         center: bool = True,
+         normalized: bool = True,
+         pad_mode: str = "constant",
+         onesided: bool = True,
+     ) -> None:
+         super().__init__(
+             band_specs_map=band_specs_map,
+             fs=fs,
+             n_fft=n_fft,
+             win_length=win_length,
+             hop_length=hop_length,
+             window_fn=window_fn,
+             wkwargs=wkwargs,
+             power=power,
+             center=center,
+             normalized=normalized,
+             pad_mode=pad_mode,
+             onesided=onesided,
+         )
+
+         self.bsrnn = nn.ModuleDict(
+             {
+                 src: SingleMaskBandsplitCoreRNN(
+                     band_specs=specs,
+                     in_channel=in_channel,
+                     require_no_overlap=require_no_overlap,
+                     require_no_gap=require_no_gap,
+                     normalize_channel_independently=normalize_channel_independently,
+                     treat_channel_as_feature=treat_channel_as_feature,
+                     n_sqm_modules=n_sqm_modules,
+                     emb_dim=emb_dim,
+                     rnn_dim=rnn_dim,
+                     bidirectional=bidirectional,
+                     rnn_type=rnn_type,
+                     mlp_dim=mlp_dim,
+                     hidden_activation=hidden_activation,
+                     hidden_activation_kwargs=hidden_activation_kwargs,
+                     complex_mask=complex_mask,
+                 )
+                 for src, specs in self.band_specs_map.items()
+             }
+         )
+
+
+ class SingleMaskMultiSourceBandSplitTransformer(SingleMaskMultiSourceBandSplitBase):
+     def __init__(
+         self,
+         in_channel: int,
+         band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
+         fs: int = 44100,
+         require_no_overlap: bool = False,
+         require_no_gap: bool = True,
+         normalize_channel_independently: bool = False,
+         treat_channel_as_feature: bool = True,
+         n_sqm_modules: int = 12,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         bidirectional: bool = True,
+         tf_dropout: float = 0.0,
+         mlp_dim: int = 512,
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Optional[Dict] = None,
+         complex_mask: bool = True,
+         n_fft: int = 2048,
+         win_length: Optional[int] = 2048,
+         hop_length: int = 512,
+         window_fn: str = "hann_window",
+         wkwargs: Optional[Dict] = None,
+         power: Optional[int] = None,
+         center: bool = True,
+         normalized: bool = True,
+         pad_mode: str = "constant",
+         onesided: bool = True,
+     ) -> None:
+         super().__init__(
+             band_specs_map=band_specs_map,
+             fs=fs,
+             n_fft=n_fft,
+             win_length=win_length,
+             hop_length=hop_length,
+             window_fn=window_fn,
+             wkwargs=wkwargs,
+             power=power,
+             center=center,
+             normalized=normalized,
+             pad_mode=pad_mode,
+             onesided=onesided,
+         )
+
+         self.bsrnn = nn.ModuleDict(
+             {
+                 src: SingleMaskBandsplitCoreTransformer(
+                     band_specs=specs,
+                     in_channel=in_channel,
+                     require_no_overlap=require_no_overlap,
+                     require_no_gap=require_no_gap,
+                     normalize_channel_independently=normalize_channel_independently,
+                     treat_channel_as_feature=treat_channel_as_feature,
+                     n_sqm_modules=n_sqm_modules,
+                     emb_dim=emb_dim,
+                     rnn_dim=rnn_dim,
+                     bidirectional=bidirectional,
+                     tf_dropout=tf_dropout,
+                     mlp_dim=mlp_dim,
+                     hidden_activation=hidden_activation,
+                     hidden_activation_kwargs=hidden_activation_kwargs,
+                     complex_mask=complex_mask,
+                 )
+                 for src, specs in self.band_specs_map.items()
+             }
+         )
+
+
+ class MultiMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
+     def __init__(
+         self,
+         in_channel: int,
+         stems: List[str],
+         band_specs: Union[str, List[Tuple[float, float]]],
+         fs: int = 44100,
+         require_no_overlap: bool = False,
+         require_no_gap: bool = True,
+         normalize_channel_independently: bool = False,
+         treat_channel_as_feature: bool = True,
+         n_sqm_modules: int = 12,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         cond_dim: int = 0,
+         bidirectional: bool = True,
+         rnn_type: str = "LSTM",
+         mlp_dim: int = 512,
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Optional[Dict] = None,
+         complex_mask: bool = True,
+         n_fft: int = 2048,
+         win_length: Optional[int] = 2048,
+         hop_length: int = 512,
+         window_fn: str = "hann_window",
+         wkwargs: Optional[Dict] = None,
+         power: Optional[int] = None,
+         center: bool = True,
+         normalized: bool = True,
+         pad_mode: str = "constant",
+         onesided: bool = True,
+         n_bands: int = None,
+         use_freq_weights: bool = True,
+         normalize_input: bool = False,
+         mult_add_mask: bool = False,
+         freeze_encoder: bool = False,
+     ) -> None:
+         super().__init__(
+             stems=stems,
+             band_specs=band_specs,
+             fs=fs,
+             n_fft=n_fft,
+             win_length=win_length,
+             hop_length=hop_length,
+             window_fn=window_fn,
+             wkwargs=wkwargs,
+             power=power,
+             center=center,
+             normalized=normalized,
+             pad_mode=pad_mode,
+             onesided=onesided,
+             n_bands=n_bands,
+         )
+
+         self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
+             stems=stems,
+             band_specs=self.band_specs,
+             in_channel=in_channel,
+             require_no_overlap=require_no_overlap,
+             require_no_gap=require_no_gap,
+             normalize_channel_independently=normalize_channel_independently,
+             treat_channel_as_feature=treat_channel_as_feature,
+             n_sqm_modules=n_sqm_modules,
+             emb_dim=emb_dim,
+             rnn_dim=rnn_dim,
+             bidirectional=bidirectional,
+             rnn_type=rnn_type,
+             mlp_dim=mlp_dim,
+             cond_dim=cond_dim,
+             hidden_activation=hidden_activation,
+             hidden_activation_kwargs=hidden_activation_kwargs,
+             complex_mask=complex_mask,
+             overlapping_band=self.overlapping_band,
+             freq_weights=self.freq_weights,
+             n_freq=n_fft // 2 + 1,
+             use_freq_weights=use_freq_weights,
+             mult_add_mask=mult_add_mask,
+         )
+
+         self.normalize_input = normalize_input
+         self.cond_dim = cond_dim
+
+         if freeze_encoder:
+             for param in self.bsrnn.band_split.parameters():
+                 param.requires_grad = False
+
+             for param in self.bsrnn.tf_model.parameters():
+                 param.requires_grad = False
+
+
+ class MultiMaskMultiSourceBandSplitRNNSimple(MultiMaskMultiSourceBandSplitBaseSimple):
+     def __init__(
+         self,
+         in_channel: int,
+         stems: List[str],
+         band_specs: Union[str, List[Tuple[float, float]]],
+         fs: int = 44100,
+         require_no_overlap: bool = False,
+         require_no_gap: bool = True,
+         normalize_channel_independently: bool = False,
+         treat_channel_as_feature: bool = True,
+         n_sqm_modules: int = 12,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         cond_dim: int = 0,
+         bidirectional: bool = True,
+         rnn_type: str = "LSTM",
+         mlp_dim: int = 512,
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Optional[Dict] = None,
+         complex_mask: bool = True,
+         n_fft: int = 2048,
+         win_length: Optional[int] = 2048,
+         hop_length: int = 512,
+         window_fn: str = "hann_window",
+         wkwargs: Optional[Dict] = None,
+         power: Optional[int] = None,
+         center: bool = True,
+         normalized: bool = True,
+         pad_mode: str = "constant",
+         onesided: bool = True,
+         n_bands: int = None,
+         use_freq_weights: bool = True,
+         normalize_input: bool = False,
+         mult_add_mask: bool = False,
+         freeze_encoder: bool = False,
+     ) -> None:
+         super().__init__(
+             stems=stems,
+             band_specs=band_specs,
+             fs=fs,
+             n_fft=n_fft,
+             win_length=win_length,
+             hop_length=hop_length,
+             window_fn=window_fn,
+             wkwargs=wkwargs,
+             power=power,
+             center=center,
+             normalized=normalized,
+             pad_mode=pad_mode,
+             onesided=onesided,
+             n_bands=n_bands,
+         )
+
+         self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
+             stems=stems,
+             band_specs=self.band_specs,
+             in_channel=in_channel,
+             require_no_overlap=require_no_overlap,
+             require_no_gap=require_no_gap,
+             normalize_channel_independently=normalize_channel_independently,
+             treat_channel_as_feature=treat_channel_as_feature,
+             n_sqm_modules=n_sqm_modules,
+             emb_dim=emb_dim,
+             rnn_dim=rnn_dim,
+             bidirectional=bidirectional,
+             rnn_type=rnn_type,
+             mlp_dim=mlp_dim,
+             cond_dim=cond_dim,
+             hidden_activation=hidden_activation,
+             hidden_activation_kwargs=hidden_activation_kwargs,
+             complex_mask=complex_mask,
+             overlapping_band=self.overlapping_band,
+             freq_weights=self.freq_weights,
+             n_freq=n_fft // 2 + 1,
+             use_freq_weights=use_freq_weights,
+             mult_add_mask=mult_add_mask,
+         )
+
+         self.normalize_input = normalize_input
+         self.cond_dim = cond_dim
+
+         if freeze_encoder:
+             for param in self.bsrnn.band_split.parameters():
+                 param.requires_grad = False
+
+             for param in self.bsrnn.tf_model.parameters():
+                 param.requires_grad = False
+
+
+ class MultiMaskMultiSourceBandSplitTransformer(MultiMaskMultiSourceBandSplitBase):
+     def __init__(
+         self,
+         in_channel: int,
+         stems: List[str],
+         band_specs: Union[str, List[Tuple[float, float]]],
+         fs: int = 44100,
+         require_no_overlap: bool = False,
+         require_no_gap: bool = True,
+         normalize_channel_independently: bool = False,
+         treat_channel_as_feature: bool = True,
+         n_sqm_modules: int = 12,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         cond_dim: int = 0,
+         bidirectional: bool = True,
+         rnn_type: str = "LSTM",
+         mlp_dim: int = 512,
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Optional[Dict] = None,
+         complex_mask: bool = True,
+         n_fft: int = 2048,
+         win_length: Optional[int] = 2048,
+         hop_length: int = 512,
+         window_fn: str = "hann_window",
+         wkwargs: Optional[Dict] = None,
+         power: Optional[int] = None,
+         center: bool = True,
+         normalized: bool = True,
+         pad_mode: str = "constant",
+         onesided: bool = True,
+         n_bands: int = None,
+         use_freq_weights: bool = True,
+         normalize_input: bool = False,
+         mult_add_mask: bool = False,
+     ) -> None:
+         super().__init__(
+             stems=stems,
+             band_specs=band_specs,
+             fs=fs,
+             n_fft=n_fft,
+             win_length=win_length,
+             hop_length=hop_length,
+             window_fn=window_fn,
+             wkwargs=wkwargs,
+             power=power,
+             center=center,
+             normalized=normalized,
+             pad_mode=pad_mode,
+             onesided=onesided,
+             n_bands=n_bands,
+         )
+
+         self.bsrnn = MultiSourceMultiMaskBandSplitCoreTransformer(
+             stems=stems,
+             band_specs=self.band_specs,
+             in_channel=in_channel,
+             require_no_overlap=require_no_overlap,
+             require_no_gap=require_no_gap,
+             normalize_channel_independently=normalize_channel_independently,
+             treat_channel_as_feature=treat_channel_as_feature,
+             n_sqm_modules=n_sqm_modules,
+             emb_dim=emb_dim,
+             rnn_dim=rnn_dim,
+             bidirectional=bidirectional,
+             rnn_type=rnn_type,
+             mlp_dim=mlp_dim,
+             cond_dim=cond_dim,
+             hidden_activation=hidden_activation,
+             hidden_activation_kwargs=hidden_activation_kwargs,
+             complex_mask=complex_mask,
+             overlapping_band=self.overlapping_band,
+             freq_weights=self.freq_weights,
+             n_freq=n_fft // 2 + 1,
+             use_freq_weights=use_freq_weights,
+             mult_add_mask=mult_add_mask,
668
+ )
669
+
670
+
671
+ class MultiMaskMultiSourceBandSplitConv(MultiMaskMultiSourceBandSplitBase):
672
+ def __init__(
673
+ self,
674
+ in_channel: int,
675
+ stems: List[str],
676
+ band_specs: Union[str, List[Tuple[float, float]]],
677
+ fs: int = 44100,
678
+ require_no_overlap: bool = False,
679
+ require_no_gap: bool = True,
680
+ normalize_channel_independently: bool = False,
681
+ treat_channel_as_feature: bool = True,
682
+ n_sqm_modules: int = 12,
683
+ emb_dim: int = 128,
684
+ rnn_dim: int = 256,
685
+ cond_dim: int = 0,
686
+ bidirectional: bool = True,
687
+ rnn_type: str = "LSTM",
688
+ mlp_dim: int = 512,
689
+ hidden_activation: str = "Tanh",
690
+ hidden_activation_kwargs: Optional[Dict] = None,
691
+ complex_mask: bool = True,
692
+ n_fft: int = 2048,
693
+ win_length: Optional[int] = 2048,
694
+ hop_length: int = 512,
695
+ window_fn: str = "hann_window",
696
+ wkwargs: Optional[Dict] = None,
697
+ power: Optional[int] = None,
698
+ center: bool = True,
699
+ normalized: bool = True,
700
+ pad_mode: str = "constant",
701
+ onesided: bool = True,
702
+ n_bands: int = None,
703
+ use_freq_weights: bool = True,
704
+ normalize_input: bool = False,
705
+ mult_add_mask: bool = False,
706
+ ) -> None:
707
+ super().__init__(
708
+ stems=stems,
709
+ band_specs=band_specs,
710
+ fs=fs,
711
+ n_fft=n_fft,
712
+ win_length=win_length,
713
+ hop_length=hop_length,
714
+ window_fn=window_fn,
715
+ wkwargs=wkwargs,
716
+ power=power,
717
+ center=center,
718
+ normalized=normalized,
719
+ pad_mode=pad_mode,
720
+ onesided=onesided,
721
+ n_bands=n_bands,
722
+ )
723
+
724
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreConv(
725
+ stems=stems,
726
+ band_specs=self.band_specs,
727
+ in_channel=in_channel,
728
+ require_no_overlap=require_no_overlap,
729
+ require_no_gap=require_no_gap,
730
+ normalize_channel_independently=normalize_channel_independently,
731
+ treat_channel_as_feature=treat_channel_as_feature,
732
+ n_sqm_modules=n_sqm_modules,
733
+ emb_dim=emb_dim,
734
+ rnn_dim=rnn_dim,
735
+ bidirectional=bidirectional,
736
+ rnn_type=rnn_type,
737
+ mlp_dim=mlp_dim,
738
+ cond_dim=cond_dim,
739
+ hidden_activation=hidden_activation,
740
+ hidden_activation_kwargs=hidden_activation_kwargs,
741
+ complex_mask=complex_mask,
742
+ overlapping_band=self.overlapping_band,
743
+ freq_weights=self.freq_weights,
744
+ n_freq=n_fft // 2 + 1,
745
+ use_freq_weights=use_freq_weights,
746
+ mult_add_mask=mult_add_mask,
747
+ )
748
+
749
+
750
+ class PatchingMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
751
+ def __init__(
752
+ self,
753
+ in_channel: int,
754
+ stems: List[str],
755
+ band_specs: Union[str, List[Tuple[float, float]]],
756
+ kernel_norm_mlp_version: int = 1,
757
+ mask_kernel_freq: int = 3,
758
+ mask_kernel_time: int = 3,
759
+ conv_kernel_freq: int = 1,
760
+ conv_kernel_time: int = 1,
761
+ fs: int = 44100,
762
+ require_no_overlap: bool = False,
763
+ require_no_gap: bool = True,
764
+ normalize_channel_independently: bool = False,
765
+ treat_channel_as_feature: bool = True,
766
+ n_sqm_modules: int = 12,
767
+ emb_dim: int = 128,
768
+ rnn_dim: int = 256,
769
+ bidirectional: bool = True,
770
+ rnn_type: str = "LSTM",
771
+ mlp_dim: int = 512,
772
+ hidden_activation: str = "Tanh",
773
+ hidden_activation_kwargs: Optional[Dict] = None,
774
+ complex_mask: bool = True,
775
+ n_fft: int = 2048,
776
+ win_length: Optional[int] = 2048,
777
+ hop_length: int = 512,
778
+ window_fn: str = "hann_window",
779
+ wkwargs: Optional[Dict] = None,
780
+ power: Optional[int] = None,
781
+ center: bool = True,
782
+ normalized: bool = True,
783
+ pad_mode: str = "constant",
784
+ onesided: bool = True,
785
+ n_bands: int = None,
786
+ ) -> None:
787
+ super().__init__(
788
+ stems=stems,
789
+ band_specs=band_specs,
790
+ fs=fs,
791
+ n_fft=n_fft,
792
+ win_length=win_length,
793
+ hop_length=hop_length,
794
+ window_fn=window_fn,
795
+ wkwargs=wkwargs,
796
+ power=power,
797
+ center=center,
798
+ normalized=normalized,
799
+ pad_mode=pad_mode,
800
+ onesided=onesided,
801
+ n_bands=n_bands,
802
+ )
803
+
804
+ self.bsrnn = MultiSourceMultiPatchingMaskBandSplitCoreRNN(
805
+ stems=stems,
806
+ band_specs=self.band_specs,
807
+ in_channel=in_channel,
808
+ require_no_overlap=require_no_overlap,
809
+ require_no_gap=require_no_gap,
810
+ normalize_channel_independently=normalize_channel_independently,
811
+ treat_channel_as_feature=treat_channel_as_feature,
812
+ n_sqm_modules=n_sqm_modules,
813
+ emb_dim=emb_dim,
814
+ rnn_dim=rnn_dim,
815
+ bidirectional=bidirectional,
816
+ rnn_type=rnn_type,
817
+ mlp_dim=mlp_dim,
818
+ hidden_activation=hidden_activation,
819
+ hidden_activation_kwargs=hidden_activation_kwargs,
820
+ complex_mask=complex_mask,
821
+ overlapping_band=self.overlapping_band,
822
+ freq_weights=self.freq_weights,
823
+ n_freq=n_fft // 2 + 1,
824
+ mask_kernel_freq=mask_kernel_freq,
825
+ mask_kernel_time=mask_kernel_time,
826
+ conv_kernel_freq=conv_kernel_freq,
827
+ conv_kernel_time=conv_kernel_time,
828
+ kernel_norm_mlp_version=kernel_norm_mlp_version,
829
+ )
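Editorial note: the wrapper variants above differ only in which core they instantiate (RNN, Transformer, Conv, or patching-mask RNN); the STFT plumbing and the freeze_encoder switch are shared. A minimal construction sketch follows, using the import path implied by model_from_config.py further down; the band_specs value is a placeholder, since the accepted identifiers live in the base class outside this excerpt.

# Sketch only; "musical" is a placeholder band-spec identifier.
from mvsepless.models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple

model = MultiMaskMultiSourceBandSplitRNNSimple(
    in_channel=2,
    stems=["speech", "music", "effects"],
    band_specs="musical",  # placeholder; see the band-spec helpers in this repo
    n_bands=64,
    freeze_encoder=True,   # band_split and tf_model drop out of the trainable set
)
print(sum(p.numel() for p in model.parameters() if p.requires_grad))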
mvsepless/models/bandit/core/utils/__init__.py ADDED
File without changes
mvsepless/models/bandit/core/utils/audio.py ADDED
@@ -0,0 +1,412 @@
+ from collections import defaultdict
+
+ from tqdm.auto import tqdm
+ from typing import Callable, Dict, List, Optional, Tuple
+
+ import numpy as np
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+
+ @torch.jit.script
+ def merge(
+     combined: torch.Tensor,
+     original_batch_size: int,
+     n_channel: int,
+     n_chunks: int,
+     chunk_size: int,
+ ):
+     combined = torch.reshape(
+         combined, (original_batch_size, n_chunks, n_channel, chunk_size)
+     )
+     combined = torch.permute(combined, (0, 2, 3, 1)).reshape(
+         original_batch_size * n_channel, chunk_size, n_chunks
+     )
+
+     return combined
+
+
+ @torch.jit.script
+ def unfold(
+     padded_audio: torch.Tensor,
+     original_batch_size: int,
+     n_channel: int,
+     chunk_size: int,
+     hop_size: int,
+ ) -> torch.Tensor:
+
+     unfolded_input = F.unfold(
+         padded_audio[:, :, None, :], kernel_size=(1, chunk_size), stride=(1, hop_size)
+     )
+
+     _, _, n_chunks = unfolded_input.shape
+     unfolded_input = unfolded_input.view(
+         original_batch_size, n_channel, chunk_size, n_chunks
+     )
+     unfolded_input = torch.permute(unfolded_input, (0, 3, 1, 2)).reshape(
+         original_batch_size * n_chunks, n_channel, chunk_size
+     )
+
+     return unfolded_input
+
+
+ @torch.jit.script
+ # @torch.compile
+ def merge_chunks_all(
+     combined: torch.Tensor,
+     original_batch_size: int,
+     n_channel: int,
+     n_samples: int,
+     n_padded_samples: int,
+     n_chunks: int,
+     chunk_size: int,
+     hop_size: int,
+     edge_frame_pad_sizes: Tuple[int, int],
+     standard_window: torch.Tensor,
+     first_window: torch.Tensor,
+     last_window: torch.Tensor,
+ ):
+     combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size)
+
+     combined = combined * standard_window[:, None].to(combined.device)
+
+     combined = F.fold(
+         combined.to(torch.float32),
+         output_size=(1, n_padded_samples),
+         kernel_size=(1, chunk_size),
+         stride=(1, hop_size),
+     )
+
+     combined = combined.view(original_batch_size, n_channel, n_padded_samples)
+
+     pad_front, pad_back = edge_frame_pad_sizes
+     combined = combined[..., pad_front:-pad_back]
+
+     combined = combined[..., :n_samples]
+
+     return combined
+
+
+ # @torch.jit.script
+ def merge_chunks_edge(
+     combined: torch.Tensor,
+     original_batch_size: int,
+     n_channel: int,
+     n_samples: int,
+     n_padded_samples: int,
+     n_chunks: int,
+     chunk_size: int,
+     hop_size: int,
+     edge_frame_pad_sizes: Tuple[int, int],
+     standard_window: torch.Tensor,
+     first_window: torch.Tensor,
+     last_window: torch.Tensor,
+ ):
+     combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size)
+
+     combined[..., 0] = combined[..., 0] * first_window
+     combined[..., -1] = combined[..., -1] * last_window
+     combined[..., 1:-1] = combined[..., 1:-1] * standard_window[:, None]
+
+     combined = F.fold(
+         combined,
+         output_size=(1, n_padded_samples),
+         kernel_size=(1, chunk_size),
+         stride=(1, hop_size),
+     )
+
+     combined = combined.view(original_batch_size, n_channel, n_padded_samples)
+
+     combined = combined[..., :n_samples]
+
+     return combined
+
+
+ class BaseFader(nn.Module):
+     def __init__(
+         self,
+         chunk_size_second: float,
+         hop_size_second: float,
+         fs: int,
+         fade_edge_frames: bool,
+         batch_size: int,
+     ) -> None:
+         super().__init__()
+
+         self.chunk_size = int(chunk_size_second * fs)
+         self.hop_size = int(hop_size_second * fs)
+         self.overlap_size = self.chunk_size - self.hop_size
+         self.fade_edge_frames = fade_edge_frames
+         self.batch_size = batch_size
+
+     # @torch.jit.script
+     def prepare(self, audio):
+
+         if self.fade_edge_frames:
+             audio = F.pad(audio, self.edge_frame_pad_sizes, mode="reflect")
+
+         n_samples = audio.shape[-1]
+         n_chunks = int(np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1)
+
+         padded_size = (n_chunks - 1) * self.hop_size + self.chunk_size
+         pad_size = padded_size - n_samples
+
+         padded_audio = F.pad(audio, (0, pad_size))
+
+         return padded_audio, n_chunks
+
+     def forward(
+         self,
+         audio: torch.Tensor,
+         model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
+     ):
+
+         original_dtype = audio.dtype
+         original_device = audio.device
+
+         audio = audio.to("cpu")
+
+         original_batch_size, n_channel, n_samples = audio.shape
+         padded_audio, n_chunks = self.prepare(audio)
+         del audio
+         n_padded_samples = padded_audio.shape[-1]
+
+         if n_channel > 1:
+             padded_audio = padded_audio.view(
+                 original_batch_size * n_channel, 1, n_padded_samples
+             )
+
+         unfolded_input = unfold(
+             padded_audio, original_batch_size, n_channel, self.chunk_size, self.hop_size
+         )
+
+         n_total_chunks, n_channel, chunk_size = unfolded_input.shape
+
+         n_batch = np.ceil(n_total_chunks / self.batch_size).astype(int)
+
+         chunks_in = [
+             unfolded_input[b * self.batch_size : (b + 1) * self.batch_size, ...].clone()
+             for b in range(n_batch)
+         ]
+
+         all_chunks_out = defaultdict(
+             lambda: torch.zeros_like(unfolded_input, device="cpu")
+         )
+
+         # for b, cin in enumerate(tqdm(chunks_in)):
+         for b, cin in enumerate(chunks_in):
+             if torch.allclose(cin, torch.tensor(0.0)):
+                 del cin
+                 continue
+
+             chunks_out = model_fn(cin.to(original_device))
+             del cin
+             for s, c in chunks_out.items():
+                 all_chunks_out[s][
+                     b * self.batch_size : (b + 1) * self.batch_size, ...
+                 ] = c.cpu()
+             del chunks_out
+
+         del unfolded_input
+         del padded_audio
+
+         if self.fade_edge_frames:
+             fn = merge_chunks_all
+         else:
+             fn = merge_chunks_edge
+         outputs = {}
+
+         torch.cuda.empty_cache()
+
+         for s, c in all_chunks_out.items():
+             combined: torch.Tensor = fn(
+                 c,
+                 original_batch_size,
+                 n_channel,
+                 n_samples,
+                 n_padded_samples,
+                 n_chunks,
+                 self.chunk_size,
+                 self.hop_size,
+                 self.edge_frame_pad_sizes,
+                 self.standard_window,
+                 # nn.Module keeps Parameters/buffers out of __dict__, so a plain
+                 # __dict__.get() always hit the fallback; getattr() finds them.
+                 getattr(self, "first_window", self.standard_window),
+                 getattr(self, "last_window", self.standard_window),
+             )
+
+             outputs[s] = combined.to(dtype=original_dtype, device=original_device)
+
+         return {"audio": outputs}
+
+     #
+     # def old_forward(
+     #         self,
+     #         audio: torch.Tensor,
+     #         model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
+     # ):
+     #
+     #     n_samples = audio.shape[-1]
+     #     original_batch_size = audio.shape[0]
+     #
+     #     padded_audio, n_chunks = self.prepare(audio)
+     #
+     #     ndim = padded_audio.ndim
+     #     broadcaster = [1 for _ in range(ndim - 1)] + [self.chunk_size]
+     #
+     #     outputs = defaultdict(
+     #         lambda: torch.zeros_like(
+     #             padded_audio, device=audio.device, dtype=torch.float64
+     #         )
+     #     )
+     #
+     #     all_chunks_out = []
+     #     len_chunks_in = []
+     #
+     #     batch_size_ = int(self.batch_size // original_batch_size)
+     #     for b in range(int(np.ceil(n_chunks / batch_size_))):
+     #         chunks_in = []
+     #         for j in range(batch_size_):
+     #             i = b * batch_size_ + j
+     #             if i == n_chunks:
+     #                 break
+     #
+     #             start = i * hop_size
+     #             end = start + self.chunk_size
+     #             chunk_in = padded_audio[..., start:end]
+     #             chunks_in.append(chunk_in)
+     #
+     #         chunks_in = torch.concat(chunks_in, dim=0)
+     #         chunks_out = model_fn(chunks_in)
+     #         all_chunks_out.append(chunks_out)
+     #         len_chunks_in.append(len(chunks_in))
+     #
+     #     for b, (chunks_out, lci) in enumerate(
+     #         zip(all_chunks_out, len_chunks_in)
+     #     ):
+     #         for stem in chunks_out:
+     #             for j in range(lci // original_batch_size):
+     #                 i = b * batch_size_ + j
+     #
+     #                 if self.fade_edge_frames:
+     #                     window = self.standard_window
+     #                 else:
+     #                     if i == 0:
+     #                         window = self.first_window
+     #                     elif i == n_chunks - 1:
+     #                         window = self.last_window
+     #                     else:
+     #                         window = self.standard_window
+     #
+     #                 start = i * hop_size
+     #                 end = start + self.chunk_size
+     #
+     #                 chunk_out = chunks_out[stem][
+     #                     j * original_batch_size: (j + 1) * original_batch_size, ...
+     #                 ]
+     #                 contrib = window.view(*broadcaster) * chunk_out
+     #                 outputs[stem][..., start:end] = (
+     #                     outputs[stem][..., start:end] + contrib
+     #                 )
+     #
+     #     if self.fade_edge_frames:
+     #         pad_front, pad_back = self.edge_frame_pad_sizes
+     #         outputs = {k: v[..., pad_front:-pad_back] for k, v in outputs.items()}
+     #
+     #     outputs = {k: v[..., :n_samples].to(audio.dtype) for k, v in outputs.items()}
+     #
+     #     return {
+     #         "audio": outputs
+     #     }
+
+
+ class LinearFader(BaseFader):
+     def __init__(
+         self,
+         chunk_size_second: float,
+         hop_size_second: float,
+         fs: int,
+         fade_edge_frames: bool = False,
+         batch_size: int = 1,
+     ) -> None:
+
+         assert hop_size_second >= chunk_size_second / 2
+
+         super().__init__(
+             chunk_size_second=chunk_size_second,
+             hop_size_second=hop_size_second,
+             fs=fs,
+             fade_edge_frames=fade_edge_frames,
+             batch_size=batch_size,
+         )
+
+         in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1]
+         out_fade = torch.linspace(1.0, 0.0, self.overlap_size + 1)[1:]
+         center_ones = torch.ones(self.chunk_size - 2 * self.overlap_size)
+         inout_ones = torch.ones(self.overlap_size)
+
+         # using nn.Parameters allows lightning to take care of devices for us
+         self.register_buffer(
+             "standard_window", torch.concat([in_fade, center_ones, out_fade])
+         )
+
+         self.fade_edge_frames = fade_edge_frames
+         # plural name: BaseFader.prepare/forward read edge_frame_pad_sizes
+         self.edge_frame_pad_sizes = (self.overlap_size, self.overlap_size)
+
+         if not self.fade_edge_frames:
+             self.first_window = nn.Parameter(
+                 torch.concat([inout_ones, center_ones, out_fade]), requires_grad=False
+             )
+             self.last_window = nn.Parameter(
+                 torch.concat([in_fade, center_ones, inout_ones]), requires_grad=False
+             )
+
+
+ class OverlapAddFader(BaseFader):
+     def __init__(
+         self,
+         window_type: str,
+         chunk_size_second: float,
+         hop_size_second: float,
+         fs: int,
+         batch_size: int = 1,
+     ) -> None:
+         assert (chunk_size_second / hop_size_second) % 2 == 0
+         assert int(chunk_size_second * fs) % 2 == 0
+
+         super().__init__(
+             chunk_size_second=chunk_size_second,
+             hop_size_second=hop_size_second,
+             fs=fs,
+             fade_edge_frames=True,
+             batch_size=batch_size,
+         )
+
+         self.hop_multiplier = self.chunk_size / (2 * self.hop_size)
+         # print(f"hop multiplier: {self.hop_multiplier}")
+
+         self.edge_frame_pad_sizes = (2 * self.overlap_size, 2 * self.overlap_size)
+
+         self.register_buffer(
+             "standard_window",
+             torch.windows.__dict__[window_type](
+                 self.chunk_size,
+                 sym=False,  # dtype=torch.float64
+             )
+             / self.hop_multiplier,
+         )
+
+
+ if __name__ == "__main__":
+     import torchaudio as ta
+
+     fs = 44100
+     ola = OverlapAddFader("hann", 6.0, 1.0, fs, batch_size=16)
+     audio_, _ = ta.load(
+         "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too Much/vocals.wav"
+     )
+     audio_ = audio_[None, ...]
+     out = ola(audio_, lambda x: {"stem": x})["audio"]["stem"]
+     print(torch.allclose(out, audio_))
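The two asserts at the top of OverlapAddFader constrain the chunking: the chunk/hop ratio must be an even integer and the chunk must hold an even number of samples, so the scaled windows overlap-add to a constant. A quick editorial sketch of which pairs pass:

# Mirrors the OverlapAddFader asserts above.
fs = 44100
for chunk_s, hop_s in [(6.0, 1.0), (6.0, 1.5), (6.0, 4.0)]:
    ratio_ok = (chunk_s / hop_s) % 2 == 0
    samples_ok = int(chunk_s * fs) % 2 == 0
    print(chunk_s, hop_s, ratio_ok and samples_ok)
# (6.0, 1.0) -> True (ratio 6), (6.0, 1.5) -> True (ratio 4), (6.0, 4.0) -> False (ratio 1.5)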
mvsepless/models/bandit/model_from_config.py ADDED
@@ -0,0 +1,26 @@
+ import torch
+
+ import yaml
+ from ml_collections import ConfigDict
+
+ torch.set_float32_matmul_precision("medium")
+
+
+ def get_model(
+     config_path,
+     weights_path,
+     device,
+ ):
+     from .core.model import MultiMaskMultiSourceBandSplitRNNSimple
+
+     with open(config_path) as f:
+         config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
+
+     model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)
+     # Load the checkpoint from the given weights path (the original referenced
+     # an undefined `code_path` and a hard-coded filename here).
+     d = torch.load(weights_path, map_location=device)
+     model.load_state_dict(d)
+     model.to(device)
+     return model, config
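A minimal usage sketch for get_model (editorial; both paths are placeholders):

import torch

from mvsepless.models.bandit.model_from_config import get_model

device = "cuda" if torch.cuda.is_available() else "cpu"
model, config = get_model("config.yaml", "checkpoint.chpt", device)  # placeholder paths
model.eval()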
mvsepless/models/bandit_v2/bandit.py ADDED
@@ -0,0 +1,363 @@
+ from typing import Dict, List, Optional
+
+ import torch
+ import torchaudio as ta
+ from torch import nn
+ import pytorch_lightning as pl
+
+ from .bandsplit import BandSplitModule
+ from .maskestim import OverlappingMaskEstimationModule
+ from .tfmodel import SeqBandModellingModule
+ from .utils import MusicalBandsplitSpecification
+
+
+ class BaseEndToEndModule(pl.LightningModule):
+     def __init__(
+         self,
+     ) -> None:
+         super().__init__()
+
+
+ class BaseBandit(BaseEndToEndModule):
+     def __init__(
+         self,
+         in_channels: int,
+         fs: int,
+         band_type: str = "musical",
+         n_bands: int = 64,
+         require_no_overlap: bool = False,
+         require_no_gap: bool = True,
+         normalize_channel_independently: bool = False,
+         treat_channel_as_feature: bool = True,
+         n_sqm_modules: int = 12,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         bidirectional: bool = True,
+         rnn_type: str = "LSTM",
+         n_fft: int = 2048,
+         win_length: Optional[int] = 2048,
+         hop_length: int = 512,
+         window_fn: str = "hann_window",
+         wkwargs: Optional[Dict] = None,
+         power: Optional[int] = None,
+         center: bool = True,
+         normalized: bool = True,
+         pad_mode: str = "constant",
+         onesided: bool = True,
+     ):
+         super().__init__()
+
+         self.in_channels = in_channels
+
+         self.instantiate_spectral(
+             n_fft=n_fft,
+             win_length=win_length,
+             hop_length=hop_length,
+             window_fn=window_fn,
+             wkwargs=wkwargs,
+             power=power,
+             normalized=normalized,
+             center=center,
+             pad_mode=pad_mode,
+             onesided=onesided,
+         )
+
+         self.instantiate_bandsplit(
+             in_channels=in_channels,
+             band_type=band_type,
+             n_bands=n_bands,
+             require_no_overlap=require_no_overlap,
+             require_no_gap=require_no_gap,
+             normalize_channel_independently=normalize_channel_independently,
+             treat_channel_as_feature=treat_channel_as_feature,
+             emb_dim=emb_dim,
+             n_fft=n_fft,
+             fs=fs,
+         )
+
+         self.instantiate_tf_modelling(
+             n_sqm_modules=n_sqm_modules,
+             emb_dim=emb_dim,
+             rnn_dim=rnn_dim,
+             bidirectional=bidirectional,
+             rnn_type=rnn_type,
+         )
+
+     def instantiate_spectral(
+         self,
+         n_fft: int = 2048,
+         win_length: Optional[int] = 2048,
+         hop_length: int = 512,
+         window_fn: str = "hann_window",
+         wkwargs: Optional[Dict] = None,
+         power: Optional[int] = None,
+         normalized: bool = True,
+         center: bool = True,
+         pad_mode: str = "constant",
+         onesided: bool = True,
+     ):
+         assert power is None
+
+         window_fn = torch.__dict__[window_fn]
+
+         self.stft = ta.transforms.Spectrogram(
+             n_fft=n_fft,
+             win_length=win_length,
+             hop_length=hop_length,
+             pad_mode=pad_mode,
+             pad=0,
+             window_fn=window_fn,
+             wkwargs=wkwargs,
+             power=power,
+             normalized=normalized,
+             center=center,
+             onesided=onesided,
+         )
+
+         self.istft = ta.transforms.InverseSpectrogram(
+             n_fft=n_fft,
+             win_length=win_length,
+             hop_length=hop_length,
+             pad_mode=pad_mode,
+             pad=0,
+             window_fn=window_fn,
+             wkwargs=wkwargs,
+             normalized=normalized,
+             center=center,
+             onesided=onesided,
+         )
+
+     def instantiate_bandsplit(
+         self,
+         in_channels: int,
+         band_type: str = "musical",
+         n_bands: int = 64,
+         require_no_overlap: bool = False,
+         require_no_gap: bool = True,
+         normalize_channel_independently: bool = False,
+         treat_channel_as_feature: bool = True,
+         emb_dim: int = 128,
+         n_fft: int = 2048,
+         fs: int = 44100,
+     ):
+         assert band_type == "musical"
+
+         self.band_specs = MusicalBandsplitSpecification(
+             nfft=n_fft, fs=fs, n_bands=n_bands
+         )
+
+         self.band_split = BandSplitModule(
+             in_channels=in_channels,
+             band_specs=self.band_specs.get_band_specs(),
+             require_no_overlap=require_no_overlap,
+             require_no_gap=require_no_gap,
+             normalize_channel_independently=normalize_channel_independently,
+             treat_channel_as_feature=treat_channel_as_feature,
+             emb_dim=emb_dim,
+         )
+
+     def instantiate_tf_modelling(
+         self,
+         n_sqm_modules: int = 12,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         bidirectional: bool = True,
+         rnn_type: str = "LSTM",
+     ):
+         try:
+             self.tf_model = torch.compile(
+                 SeqBandModellingModule(
+                     n_modules=n_sqm_modules,
+                     emb_dim=emb_dim,
+                     rnn_dim=rnn_dim,
+                     bidirectional=bidirectional,
+                     rnn_type=rnn_type,
+                 ),
+                 disable=True,
+             )
+         except Exception:
+             self.tf_model = SeqBandModellingModule(
+                 n_modules=n_sqm_modules,
+                 emb_dim=emb_dim,
+                 rnn_dim=rnn_dim,
+                 bidirectional=bidirectional,
+                 rnn_type=rnn_type,
+             )
+
+     def mask(self, x, m):
+         return x * m
+
+     def forward(self, batch, mode="train"):
+         # The core model expects mono input; stereo is handled by folding the
+         # channels into the batch and processing each channel independently.
+         init_shape = batch.shape
+         if not isinstance(batch, dict):
+             mono = batch.view(-1, 1, batch.shape[-1])
+             batch = {"mixture": {"audio": mono}}
+
+         with torch.no_grad():
+             mixture = batch["mixture"]["audio"]
+
+             x = self.stft(mixture)
+             batch["mixture"]["spectrogram"] = x
+
+             if "sources" in batch.keys():
+                 for stem in batch["sources"].keys():
+                     s = batch["sources"][stem]["audio"]
+                     s = self.stft(s)
+                     batch["sources"][stem]["spectrogram"] = s
+
+         batch = self.separate(batch)
+
+         b = []
+         for s in self.stems:
+             # Restore the original stereo layout for each stem
+             r = batch["estimates"][s]["audio"].view(
+                 -1, init_shape[1], init_shape[2]
+             )
+             b.append(r)
+         # Return a single stacked tensor rather than per-stem dicts
+         batch = torch.stack(b, dim=1)
+         return batch
+
+     def encode(self, batch):
+         x = batch["mixture"]["spectrogram"]
+         length = batch["mixture"]["audio"].shape[-1]
+
+         z = self.band_split(x)  # (batch, emb_dim, n_band, n_time)
+         q = self.tf_model(z)  # (batch, emb_dim, n_band, n_time)
+
+         return x, q, length
+
+     def separate(self, batch):
+         raise NotImplementedError
+
+
+ class Bandit(BaseBandit):
+     def __init__(
+         self,
+         in_channels: int,
+         stems: List[str],
+         band_type: str = "musical",
+         n_bands: int = 64,
+         require_no_overlap: bool = False,
+         require_no_gap: bool = True,
+         normalize_channel_independently: bool = False,
+         treat_channel_as_feature: bool = True,
+         n_sqm_modules: int = 12,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         bidirectional: bool = True,
+         rnn_type: str = "LSTM",
+         mlp_dim: int = 512,
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Dict | None = None,
+         complex_mask: bool = True,
+         use_freq_weights: bool = True,
+         n_fft: int = 2048,
+         win_length: int | None = 2048,
+         hop_length: int = 512,
+         window_fn: str = "hann_window",
+         wkwargs: Dict | None = None,
+         power: int | None = None,
+         center: bool = True,
+         normalized: bool = True,
+         pad_mode: str = "constant",
+         onesided: bool = True,
+         fs: int = 44100,
+         stft_precisions="32",
+         bandsplit_precisions="bf16",
+         tf_model_precisions="bf16",
+         mask_estim_precisions="bf16",
+     ):
+         super().__init__(
+             in_channels=in_channels,
+             band_type=band_type,
+             n_bands=n_bands,
+             require_no_overlap=require_no_overlap,
+             require_no_gap=require_no_gap,
+             normalize_channel_independently=normalize_channel_independently,
+             treat_channel_as_feature=treat_channel_as_feature,
+             n_sqm_modules=n_sqm_modules,
+             emb_dim=emb_dim,
+             rnn_dim=rnn_dim,
+             bidirectional=bidirectional,
+             rnn_type=rnn_type,
+             n_fft=n_fft,
+             win_length=win_length,
+             hop_length=hop_length,
+             window_fn=window_fn,
+             wkwargs=wkwargs,
+             power=power,
+             center=center,
+             normalized=normalized,
+             pad_mode=pad_mode,
+             onesided=onesided,
+             fs=fs,
+         )
+
+         self.stems = stems
+
+         self.instantiate_mask_estim(
+             in_channels=in_channels,
+             stems=stems,
+             emb_dim=emb_dim,
+             mlp_dim=mlp_dim,
+             hidden_activation=hidden_activation,
+             hidden_activation_kwargs=hidden_activation_kwargs,
+             complex_mask=complex_mask,
+             n_freq=n_fft // 2 + 1,
+             use_freq_weights=use_freq_weights,
+         )
+
+     def instantiate_mask_estim(
+         self,
+         in_channels: int,
+         stems: List[str],
+         emb_dim: int,
+         mlp_dim: int,
+         hidden_activation: str,
+         hidden_activation_kwargs: Optional[Dict] = None,
+         complex_mask: bool = True,
+         n_freq: Optional[int] = None,
+         use_freq_weights: bool = False,
+     ):
+         if hidden_activation_kwargs is None:
+             hidden_activation_kwargs = {}
+
+         assert n_freq is not None
+
+         self.mask_estim = nn.ModuleDict(
+             {
+                 stem: OverlappingMaskEstimationModule(
+                     band_specs=self.band_specs.get_band_specs(),
+                     freq_weights=self.band_specs.get_freq_weights(),
+                     n_freq=n_freq,
+                     emb_dim=emb_dim,
+                     mlp_dim=mlp_dim,
+                     in_channels=in_channels,
+                     hidden_activation=hidden_activation,
+                     hidden_activation_kwargs=hidden_activation_kwargs,
+                     complex_mask=complex_mask,
+                     use_freq_weights=use_freq_weights,
+                 )
+                 for stem in stems
+             }
+         )
+
+     def separate(self, batch):
+         batch["estimates"] = {}
+
+         x, q, length = self.encode(batch)
+
+         for stem, mem in self.mask_estim.items():
+             m = mem(q)
+
+             s = self.mask(x, m.to(x.dtype))
+             s = torch.reshape(s, x.shape)
+             batch["estimates"][stem] = {
+                 "audio": self.istft(s, length),
+                 "spectrogram": s,
+             }
+
+         return batch
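For tensor input, Bandit.forward above gives a simple shape contract: (batch, channels, samples) in and (batch, n_stems, channels, samples) out, since channels are folded into the batch for the mono core and unfolded again. A hedged editorial sketch, assuming the default musical band split from utils.py:

import torch

from mvsepless.models.bandit_v2.bandit import Bandit

stems = ["speech", "music", "effects"]
model = Bandit(in_channels=1, stems=stems, fs=44100).eval()
x = torch.randn(2, 2, 44100)       # (batch, stereo channels, samples)
with torch.no_grad():
    y = model(x)
print(y.shape)                      # expected: torch.Size([2, 3, 2, 44100])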
mvsepless/models/bandit_v2/bandsplit.py ADDED
@@ -0,0 +1,130 @@
+ from typing import List, Tuple
+
+ import torch
+ from torch import nn
+ from torch.utils.checkpoint import checkpoint_sequential
+
+ from .utils import (
+     band_widths_from_specs,
+     check_no_gap,
+     check_no_overlap,
+     check_nonzero_bandwidth,
+ )
+
+
+ class NormFC(nn.Module):
+     def __init__(
+         self,
+         emb_dim: int,
+         bandwidth: int,
+         in_channels: int,
+         normalize_channel_independently: bool = False,
+         treat_channel_as_feature: bool = True,
+     ) -> None:
+         super().__init__()
+
+         if not treat_channel_as_feature:
+             raise NotImplementedError
+
+         self.treat_channel_as_feature = treat_channel_as_feature
+
+         if normalize_channel_independently:
+             raise NotImplementedError
+
+         reim = 2
+
+         norm = nn.LayerNorm(in_channels * bandwidth * reim)
+
+         fc_in = bandwidth * reim
+
+         if treat_channel_as_feature:
+             fc_in *= in_channels
+         else:
+             assert emb_dim % in_channels == 0
+             emb_dim = emb_dim // in_channels
+
+         fc = nn.Linear(fc_in, emb_dim)
+
+         self.combined = nn.Sequential(norm, fc)
+
+     def forward(self, xb):
+         return checkpoint_sequential(self.combined, 1, xb, use_reentrant=False)
+
+
+ class BandSplitModule(nn.Module):
+     def __init__(
+         self,
+         band_specs: List[Tuple[float, float]],
+         emb_dim: int,
+         in_channels: int,
+         require_no_overlap: bool = False,
+         require_no_gap: bool = True,
+         normalize_channel_independently: bool = False,
+         treat_channel_as_feature: bool = True,
+     ) -> None:
+         super().__init__()
+
+         check_nonzero_bandwidth(band_specs)
+
+         if require_no_gap:
+             check_no_gap(band_specs)
+
+         if require_no_overlap:
+             check_no_overlap(band_specs)
+
+         self.band_specs = band_specs
+         # list of [fstart, fend) in index.
+         # Note that fend is exclusive.
+         self.band_widths = band_widths_from_specs(band_specs)
+         self.n_bands = len(band_specs)
+         self.emb_dim = emb_dim
+
+         try:
+             self.norm_fc_modules = nn.ModuleList(
+                 [  # type: ignore
+                     torch.compile(
+                         NormFC(
+                             emb_dim=emb_dim,
+                             bandwidth=bw,
+                             in_channels=in_channels,
+                             normalize_channel_independently=normalize_channel_independently,
+                             treat_channel_as_feature=treat_channel_as_feature,
+                         ),
+                         disable=True,
+                     )
+                     for bw in self.band_widths
+                 ]
+             )
+         except Exception:
+             self.norm_fc_modules = nn.ModuleList(
+                 [  # type: ignore
+                     NormFC(
+                         emb_dim=emb_dim,
+                         bandwidth=bw,
+                         in_channels=in_channels,
+                         normalize_channel_independently=normalize_channel_independently,
+                         treat_channel_as_feature=treat_channel_as_feature,
+                     )
+                     for bw in self.band_widths
+                 ]
+             )
+
+     def forward(self, x: torch.Tensor):
+         # x = complex spectrogram (batch, in_chan, n_freq, n_time)
+
+         batch, in_chan, n_freq, n_time = x.shape
+
+         z = torch.zeros(
+             size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
+         )
+
+         x = torch.permute(x, (0, 3, 1, 2)).contiguous()
+
+         for i, nfm in enumerate(self.norm_fc_modules):
+             fstart, fend = self.band_specs[i]
+             xb = x[:, :, :, fstart:fend]
+             xb = torch.view_as_real(xb)
+             xb = torch.reshape(xb, (batch, n_time, -1))
+             z[:, i, :, :] = nfm(xb)
+
+         return z
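Shape-wise, BandSplitModule maps a complex spectrogram (batch, in_channels, n_freq, n_time) to per-band embeddings (batch, n_bands, n_time, emb_dim). A small editorial sketch with toy gap-free bands:

import torch

from mvsepless.models.bandit_v2.bandsplit import BandSplitModule

band_specs = [(0, 256), (256, 512), (512, 1025)]   # toy bands covering n_freq = 1025
module = BandSplitModule(band_specs=band_specs, emb_dim=128, in_channels=2)
x = torch.randn(4, 2, 1025, 100, dtype=torch.complex64)
z = module(x)
print(z.shape)   # torch.Size([4, 3, 100, 128])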
mvsepless/models/bandit_v2/film.py ADDED
@@ -0,0 +1,23 @@
+ from torch import nn
+ import torch
+
+
+ class FiLM(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x, gamma, beta):
+         return gamma * x + beta
+
+
+ class BTFBroadcastedFiLM(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.film = FiLM()
+
+     def forward(self, x, gamma, beta):
+         gamma = gamma[None, None, None, :]
+         beta = beta[None, None, None, :]
+
+         return self.film(x, gamma, beta)
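FiLM here is the plain affine modulation y = gamma * x + beta; the BTF variant broadcasts per-feature gamma and beta over a (batch, n_bands, n_time, emb_dim) tensor. A tiny editorial check:

import torch

from mvsepless.models.bandit_v2.film import BTFBroadcastedFiLM

film = BTFBroadcastedFiLM()
x = torch.randn(2, 64, 100, 128)
gamma, beta = torch.full((128,), 2.0), torch.zeros(128)
y = film(x, gamma, beta)
print(torch.allclose(y, 2.0 * x))   # True: pure per-feature scaling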
mvsepless/models/bandit_v2/maskestim.py ADDED
@@ -0,0 +1,281 @@
+ from typing import Dict, List, Optional, Tuple, Type
+
+ import torch
+ from torch import nn
+ from torch.nn.modules import activation
+ from torch.utils.checkpoint import checkpoint_sequential
+
+ from .utils import (
+     band_widths_from_specs,
+     check_no_gap,
+     check_no_overlap,
+     check_nonzero_bandwidth,
+ )
+
+
+ class BaseNormMLP(nn.Module):
+     def __init__(
+         self,
+         emb_dim: int,
+         mlp_dim: int,
+         bandwidth: int,
+         in_channels: Optional[int],
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs=None,
+         complex_mask: bool = True,
+     ):
+         super().__init__()
+         if hidden_activation_kwargs is None:
+             hidden_activation_kwargs = {}
+         self.hidden_activation_kwargs = hidden_activation_kwargs
+         self.norm = nn.LayerNorm(emb_dim)
+         self.hidden = nn.Sequential(
+             nn.Linear(in_features=emb_dim, out_features=mlp_dim),
+             activation.__dict__[hidden_activation](**self.hidden_activation_kwargs),
+         )
+
+         self.bandwidth = bandwidth
+         self.in_channels = in_channels
+
+         self.complex_mask = complex_mask
+         self.reim = 2 if complex_mask else 1
+         self.glu_mult = 2
+
+
+ class NormMLP(BaseNormMLP):
+     def __init__(
+         self,
+         emb_dim: int,
+         mlp_dim: int,
+         bandwidth: int,
+         in_channels: Optional[int],
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs=None,
+         complex_mask: bool = True,
+     ) -> None:
+         super().__init__(
+             emb_dim=emb_dim,
+             mlp_dim=mlp_dim,
+             bandwidth=bandwidth,
+             in_channels=in_channels,
+             hidden_activation=hidden_activation,
+             hidden_activation_kwargs=hidden_activation_kwargs,
+             complex_mask=complex_mask,
+         )
+
+         self.output = nn.Sequential(
+             nn.Linear(
+                 in_features=mlp_dim,
+                 out_features=bandwidth * in_channels * self.reim * 2,
+             ),
+             nn.GLU(dim=-1),
+         )
+
+         try:
+             self.combined = torch.compile(
+                 nn.Sequential(self.norm, self.hidden, self.output), disable=True
+             )
+         except Exception:
+             self.combined = nn.Sequential(self.norm, self.hidden, self.output)
+
+     def reshape_output(self, mb):
+         # print(mb.shape)
+         batch, n_time, _ = mb.shape
+         if self.complex_mask:
+             mb = mb.reshape(
+                 batch, n_time, self.in_channels, self.bandwidth, self.reim
+             ).contiguous()
+             # print(mb.shape)
+             mb = torch.view_as_complex(mb)  # (batch, n_time, in_channels, bandwidth)
+         else:
+             mb = mb.reshape(batch, n_time, self.in_channels, self.bandwidth)
+
+         mb = torch.permute(mb, (0, 2, 3, 1))  # (batch, in_channels, bandwidth, n_time)
+
+         return mb
+
+     def forward(self, qb):
+         # qb = (batch, n_time, emb_dim)
+         # qb = self.norm(qb)  # (batch, n_time, emb_dim)
+         # qb = self.hidden(qb)  # (batch, n_time, mlp_dim)
+         # mb = self.output(qb)  # (batch, n_time, bandwidth * in_channels * reim)
+
+         mb = checkpoint_sequential(self.combined, 2, qb, use_reentrant=False)
+         mb = self.reshape_output(mb)  # (batch, in_channels, bandwidth, n_time)
+
+         return mb
+
+
+ class MaskEstimationModuleSuperBase(nn.Module):
+     pass
+
+
+ class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
+     def __init__(
+         self,
+         band_specs: List[Tuple[float, float]],
+         emb_dim: int,
+         mlp_dim: int,
+         in_channels: Optional[int],
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Dict = None,
+         complex_mask: bool = True,
+         norm_mlp_cls: Type[nn.Module] = NormMLP,
+         norm_mlp_kwargs: Dict = None,
+     ) -> None:
+         super().__init__()
+
+         self.band_widths = band_widths_from_specs(band_specs)
+         self.n_bands = len(band_specs)
+
+         if hidden_activation_kwargs is None:
+             hidden_activation_kwargs = {}
+
+         if norm_mlp_kwargs is None:
+             norm_mlp_kwargs = {}
+
+         self.norm_mlp = nn.ModuleList(
+             [
+                 norm_mlp_cls(
+                     bandwidth=self.band_widths[b],
+                     emb_dim=emb_dim,
+                     mlp_dim=mlp_dim,
+                     in_channels=in_channels,
+                     hidden_activation=hidden_activation,
+                     hidden_activation_kwargs=hidden_activation_kwargs,
+                     complex_mask=complex_mask,
+                     **norm_mlp_kwargs,
+                 )
+                 for b in range(self.n_bands)
+             ]
+         )
+
+     def compute_masks(self, q):
+         batch, n_bands, n_time, emb_dim = q.shape
+
+         masks = []
+
+         for b, nmlp in enumerate(self.norm_mlp):
+             # print(f"maskestim/{b:02d}")
+             qb = q[:, b, :, :]
+             mb = nmlp(qb)
+             masks.append(mb)
+
+         return masks
+
+     def compute_mask(self, q, b):
+         batch, n_bands, n_time, emb_dim = q.shape
+         qb = q[:, b, :, :]
+         mb = self.norm_mlp[b](qb)
+         return mb
+
+
+ class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
+     def __init__(
+         self,
+         in_channels: int,
+         band_specs: List[Tuple[float, float]],
+         freq_weights: List[torch.Tensor],
+         n_freq: int,
+         emb_dim: int,
+         mlp_dim: int,
+         cond_dim: int = 0,
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Dict = None,
+         complex_mask: bool = True,
+         norm_mlp_cls: Type[nn.Module] = NormMLP,
+         norm_mlp_kwargs: Dict = None,
+         use_freq_weights: bool = False,
+     ) -> None:
+         check_nonzero_bandwidth(band_specs)
+         check_no_gap(band_specs)
+
+         if cond_dim > 0:
+             raise NotImplementedError
+
+         super().__init__(
+             band_specs=band_specs,
+             emb_dim=emb_dim + cond_dim,
+             mlp_dim=mlp_dim,
+             in_channels=in_channels,
+             hidden_activation=hidden_activation,
+             hidden_activation_kwargs=hidden_activation_kwargs,
+             complex_mask=complex_mask,
+             norm_mlp_cls=norm_mlp_cls,
+             norm_mlp_kwargs=norm_mlp_kwargs,
+         )
+
+         self.n_freq = n_freq
+         self.band_specs = band_specs
+         self.in_channels = in_channels
+
+         if freq_weights is not None and use_freq_weights:
+             for i, fw in enumerate(freq_weights):
+                 self.register_buffer(f"freq_weights/{i}", fw)
+
+             self.use_freq_weights = use_freq_weights
+         else:
+             self.use_freq_weights = False
+
+     def forward(self, q):
+         # q = (batch, n_bands, n_time, emb_dim)
+
+         batch, n_bands, n_time, emb_dim = q.shape
+
+         masks = torch.zeros(
+             (batch, self.in_channels, self.n_freq, n_time),
+             device=q.device,
+             dtype=torch.complex64,
+         )
+
+         for im in range(n_bands):
+             fstart, fend = self.band_specs[im]
+
+             mask = self.compute_mask(q, im)
+
+             if self.use_freq_weights:
+                 fw = self.get_buffer(f"freq_weights/{im}")[:, None]
+                 mask = mask * fw
+             masks[:, :, fstart:fend, :] += mask
+
+         return masks
+
+
+ class MaskEstimationModule(OverlappingMaskEstimationModule):
+     def __init__(
+         self,
+         band_specs: List[Tuple[float, float]],
+         emb_dim: int,
+         mlp_dim: int,
+         in_channels: Optional[int],
+         hidden_activation: str = "Tanh",
+         hidden_activation_kwargs: Dict = None,
+         complex_mask: bool = True,
+         **kwargs,
+     ) -> None:
+         check_nonzero_bandwidth(band_specs)
+         check_no_gap(band_specs)
+         check_no_overlap(band_specs)
+         super().__init__(
+             in_channels=in_channels,
+             band_specs=band_specs,
+             freq_weights=None,
+             n_freq=None,
+             emb_dim=emb_dim,
+             mlp_dim=mlp_dim,
+             hidden_activation=hidden_activation,
+             hidden_activation_kwargs=hidden_activation_kwargs,
+             complex_mask=complex_mask,
+         )
+
+     def forward(self, q, cond=None):
+         # q = (batch, n_bands, n_time, emb_dim)
+
+         masks = self.compute_masks(
+             q
+         )  # [n_bands * (batch, in_channels, bandwidth, n_time)]
+
+         # TODO: currently this requires band specs to have no gap and no overlap
+         masks = torch.concat(masks, dim=2)  # (batch, in_channels, n_freq, n_time)
+
+         return masks
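OverlappingMaskEstimationModule accumulates per-band masks back onto the full frequency axis, so overlapping bands are legal as long as they leave no gap. An editorial shape sketch:

import torch

from mvsepless.models.bandit_v2.maskestim import OverlappingMaskEstimationModule

band_specs = [(0, 600), (400, 1025)]    # overlap is fine; gaps are not
module = OverlappingMaskEstimationModule(
    in_channels=2, band_specs=band_specs, freq_weights=None, n_freq=1025,
    emb_dim=128, mlp_dim=256,
)
q = torch.randn(4, 2, 100, 128)         # (batch, n_bands, n_time, emb_dim)
m = module(q)
print(m.shape, m.dtype)                 # torch.Size([4, 2, 1025, 100]) torch.complex64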
mvsepless/models/bandit_v2/tfmodel.py ADDED
@@ -0,0 +1,145 @@
+ import warnings
+
+ import torch
+ import torch.backends.cuda
+ from torch import nn
+ from torch.nn.modules import rnn
+ from torch.utils.checkpoint import checkpoint_sequential
+
+
+ class TimeFrequencyModellingModule(nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+
+
+ class ResidualRNN(nn.Module):
+     def __init__(
+         self,
+         emb_dim: int,
+         rnn_dim: int,
+         bidirectional: bool = True,
+         rnn_type: str = "LSTM",
+         use_batch_trick: bool = True,
+         use_layer_norm: bool = True,
+     ) -> None:
+         # n_group is the size of the 2nd dim
+         super().__init__()
+
+         assert use_layer_norm
+         assert use_batch_trick
+
+         self.use_layer_norm = use_layer_norm
+         self.norm = nn.LayerNorm(emb_dim)
+         self.rnn = rnn.__dict__[rnn_type](
+             input_size=emb_dim,
+             hidden_size=rnn_dim,
+             num_layers=1,
+             batch_first=True,
+             bidirectional=bidirectional,
+         )
+
+         self.fc = nn.Linear(
+             in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim
+         )
+
+         self.use_batch_trick = use_batch_trick
+         if not self.use_batch_trick:
+             warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")
+
+     def forward(self, z):
+         # z = (batch, n_uncrossed, n_across, emb_dim)
+
+         z0 = torch.clone(z)
+         z = self.norm(z)
+
+         batch, n_uncrossed, n_across, emb_dim = z.shape
+         z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
+         z = self.rnn(z)[0]
+         z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
+
+         z = self.fc(z)  # (batch, n_uncrossed, n_across, emb_dim)
+
+         z = z + z0
+
+         return z
+
+
+ class Transpose(nn.Module):
+     def __init__(self, dim0: int, dim1: int) -> None:
+         super().__init__()
+         self.dim0 = dim0
+         self.dim1 = dim1
+
+     def forward(self, z):
+         return z.transpose(self.dim0, self.dim1)
+
+
+ class SeqBandModellingModule(TimeFrequencyModellingModule):
+     def __init__(
+         self,
+         n_modules: int = 12,
+         emb_dim: int = 128,
+         rnn_dim: int = 256,
+         bidirectional: bool = True,
+         rnn_type: str = "LSTM",
+         parallel_mode=False,
+     ) -> None:
+         super().__init__()
+
+         self.n_modules = n_modules
+
+         if parallel_mode:
+             self.seqband = nn.ModuleList([])
+             for _ in range(n_modules):
+                 self.seqband.append(
+                     nn.ModuleList(
+                         [
+                             ResidualRNN(
+                                 emb_dim=emb_dim,
+                                 rnn_dim=rnn_dim,
+                                 bidirectional=bidirectional,
+                                 rnn_type=rnn_type,
+                             ),
+                             ResidualRNN(
+                                 emb_dim=emb_dim,
+                                 rnn_dim=rnn_dim,
+                                 bidirectional=bidirectional,
+                                 rnn_type=rnn_type,
+                             ),
+                         ]
+                     )
+                 )
+         else:
+             seqband = []
+             for _ in range(2 * n_modules):
+                 seqband += [
+                     ResidualRNN(
+                         emb_dim=emb_dim,
+                         rnn_dim=rnn_dim,
+                         bidirectional=bidirectional,
+                         rnn_type=rnn_type,
+                     ),
+                     Transpose(1, 2),
+                 ]
+
+             self.seqband = nn.Sequential(*seqband)
+
+         self.parallel_mode = parallel_mode
+
+     def forward(self, z):
+         # z = (batch, n_bands, n_time, emb_dim)
+
+         if self.parallel_mode:
+             for sbm_pair in self.seqband:
+                 # z: (batch, n_bands, n_time, emb_dim)
+                 sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
+                 zt = sbm_t(z)  # (batch, n_bands, n_time, emb_dim)
+                 zf = sbm_f(z.transpose(1, 2))  # (batch, n_time, n_bands, emb_dim)
+                 z = zt + zf.transpose(1, 2)
+         else:
+             z = checkpoint_sequential(
+                 self.seqband, self.n_modules, z, use_reentrant=False
+             )
+
+         q = z
+         return q  # (batch, n_bands, n_time, emb_dim)
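In the default (non-parallel) mode the stack alternates RNN-over-time and RNN-over-bands, with Transpose(1, 2) between them; the even count of transposes restores the input layout, so the module is shape-preserving. An editorial sketch:

import torch

from mvsepless.models.bandit_v2.tfmodel import SeqBandModellingModule

tf_model = SeqBandModellingModule(n_modules=2, emb_dim=128, rnn_dim=64)
z = torch.randn(2, 8, 50, 128)      # (batch, n_bands, n_time, emb_dim)
q = tf_model(z)
print(q.shape)                       # torch.Size([2, 8, 50, 128])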
mvsepless/models/bandit_v2/utils.py ADDED
@@ -0,0 +1,523 @@
1
+ import os
2
+ from abc import abstractmethod
3
+ from typing import Callable
4
+
5
+ import numpy as np
6
+ import torch
7
+ from librosa import hz_to_midi, midi_to_hz
8
+ from torchaudio import functional as taF
9
+
10
+ # from spafe.fbanks import bark_fbanks
11
+ # from spafe.utils.converters import erb2hz, hz2bark, hz2erb
12
+
13
+
14
+ def band_widths_from_specs(band_specs):
15
+ return [e - i for i, e in band_specs]
16
+
17
+
18
+ def check_nonzero_bandwidth(band_specs):
19
+ # pprint(band_specs)
20
+ for fstart, fend in band_specs:
21
+ if fend - fstart <= 0:
22
+ raise ValueError("Bands cannot be zero-width")
23
+
24
+
25
+ def check_no_overlap(band_specs):
26
+ fend_prev = -1
27
+ for fstart_curr, fend_curr in band_specs:
28
+ if fstart_curr <= fend_prev:
29
+ raise ValueError("Bands cannot overlap")
30
+
31
+
32
+ def check_no_gap(band_specs):
33
+ fstart, _ = band_specs[0]
34
+ assert fstart == 0
35
+
36
+ fend_prev = -1
37
+ for fstart_curr, fend_curr in band_specs:
38
+ if fstart_curr - fend_prev > 1:
39
+ raise ValueError("Bands cannot leave gap")
40
+ fend_prev = fend_curr
41
+
42
+
43
+ class BandsplitSpecification:
44
+ def __init__(self, nfft: int, fs: int) -> None:
45
+ self.fs = fs
46
+ self.nfft = nfft
47
+ self.nyquist = fs / 2
48
+ self.max_index = nfft // 2 + 1
49
+
50
+ self.split500 = self.hertz_to_index(500)
51
+ self.split1k = self.hertz_to_index(1000)
52
+ self.split2k = self.hertz_to_index(2000)
53
+ self.split4k = self.hertz_to_index(4000)
54
+ self.split8k = self.hertz_to_index(8000)
55
+ self.split16k = self.hertz_to_index(16000)
56
+ self.split20k = self.hertz_to_index(20000)
57
+
58
+ self.above20k = [(self.split20k, self.max_index)]
59
+ self.above16k = [(self.split16k, self.split20k)] + self.above20k
60
+
61
+ def index_to_hertz(self, index: int):
62
+ return index * self.fs / self.nfft
63
+
64
+ def hertz_to_index(self, hz: float, round: bool = True):
65
+ index = hz * self.nfft / self.fs
66
+
67
+ if round:
68
+ index = int(np.round(index))
69
+
70
+ return index
71
+
72
+ def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
73
+ band_specs = []
74
+ lower = start_index
75
+
76
+ while lower < end_index:
77
+ upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
78
+ upper = min(upper, end_index)
79
+
80
+ band_specs.append((lower, upper))
81
+ lower = upper
82
+
83
+ return band_specs
84
+
85
+ @abstractmethod
86
+ def get_band_specs(self):
87
+ raise NotImplementedError
88
+
89
+
90
+ class VocalBandsplitSpecification(BandsplitSpecification):
91
+ def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
92
+ super().__init__(nfft=nfft, fs=fs)
93
+
94
+ self.version = version
95
+
96
+ def get_band_specs(self):
97
+ return getattr(self, f"version{self.version}")()
98
+
99
+ @property
100
+ def version1(self):
101
+ return self.get_band_specs_with_bandwidth(
102
+ start_index=0, end_index=self.max_index, bandwidth_hz=1000
103
+ )
104
+
105
+     def version2(self):
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split16k, bandwidth_hz=1000
+         )
+         below20k = self.get_band_specs_with_bandwidth(
+             start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+         )
+
+         return below16k + below20k + self.above20k
+
+     def version3(self):
+         below8k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split8k, bandwidth_hz=1000
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+         )
+
+         return below8k + below16k + self.above16k
+
+     def version4(self):
+         below1k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split1k, bandwidth_hz=100
+         )
+         below8k = self.get_band_specs_with_bandwidth(
+             start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+         )
+
+         return below1k + below8k + below16k + self.above16k
+
+     def version5(self):
+         below1k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split1k, bandwidth_hz=100
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
+         )
+         below20k = self.get_band_specs_with_bandwidth(
+             start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+         )
+         return below1k + below16k + below20k + self.above20k
+
+     def version6(self):
+         below1k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split1k, bandwidth_hz=100
+         )
+         below4k = self.get_band_specs_with_bandwidth(
+             start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
+         )
+         below8k = self.get_band_specs_with_bandwidth(
+             start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+         )
+         return below1k + below4k + below8k + below16k + self.above16k
+
+     def version7(self):
+         below1k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split1k, bandwidth_hz=100
+         )
+         below4k = self.get_band_specs_with_bandwidth(
+             start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
+         )
+         below8k = self.get_band_specs_with_bandwidth(
+             start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
+         )
+         below20k = self.get_band_specs_with_bandwidth(
+             start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+         )
+         return below1k + below4k + below8k + below16k + below20k + self.above20k
+
+
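+ # Illustrative sketch (parameter values assumed): the version string selects one
+ # of the splits above, e.g.
+ #     specs = VocalBandsplitSpecification(nfft=2048, fs=44100, version="7").get_band_specs()
+ # returns the v7 list of (start_bin, end_bin) pairs, finest (100 Hz wide) below
+ # 1 kHz and coarsening toward Nyquist.
+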
+ class OtherBandsplitSpecification(VocalBandsplitSpecification):
+     def __init__(self, nfft: int, fs: int) -> None:
+         super().__init__(nfft=nfft, fs=fs, version="7")
+
+
+ class BassBandsplitSpecification(BandsplitSpecification):
+     # `version` is unused here; it presumably mirrors the vocal signature so
+     # callers can construct all stem specs uniformly
+     def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
+         super().__init__(nfft=nfft, fs=fs)
+
+     def get_band_specs(self):
+         below500 = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split500, bandwidth_hz=50
+         )
+         below1k = self.get_band_specs_with_bandwidth(
+             start_index=self.split500, end_index=self.split1k, bandwidth_hz=100
+         )
+         below4k = self.get_band_specs_with_bandwidth(
+             start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
+         )
+         below8k = self.get_band_specs_with_bandwidth(
+             start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+         )
+         above16k = [(self.split16k, self.max_index)]
+
+         return below500 + below1k + below4k + below8k + below16k + above16k
+
+
+ class DrumBandsplitSpecification(BandsplitSpecification):
+     def __init__(self, nfft: int, fs: int) -> None:
+         super().__init__(nfft=nfft, fs=fs)
+
+     def get_band_specs(self):
+         below1k = self.get_band_specs_with_bandwidth(
+             start_index=0, end_index=self.split1k, bandwidth_hz=50
+         )
+         below2k = self.get_band_specs_with_bandwidth(
+             start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100
+         )
+         below4k = self.get_band_specs_with_bandwidth(
+             start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250
+         )
+         below8k = self.get_band_specs_with_bandwidth(
+             start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
+         )
+         below16k = self.get_band_specs_with_bandwidth(
+             start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
+         )
+         above16k = [(self.split16k, self.max_index)]
+
+         return below1k + below2k + below4k + below8k + below16k + above16k
+
+
+ class PerceptualBandsplitSpecification(BandsplitSpecification):
+     def __init__(
+         self,
+         nfft: int,
+         fs: int,
+         fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
+         n_bands: int,
+         f_min: float = 0.0,
+         f_max: float = None,
+     ) -> None:
+         super().__init__(nfft=nfft, fs=fs)
+         self.n_bands = n_bands
+         if f_max is None:
+             f_max = fs / 2
+
+         self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)
+
+         weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True)  # (1, n_freqs)
+         normalized_mel_fb = self.filterbank / weight_per_bin  # (n_mels, n_freqs)
+
+         freq_weights = []
+         band_specs = []
+         for i in range(self.n_bands):
+             active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
+             # squeeze() collapses a single-bin filter to a scalar, so tolist()
+             # returns a bare int; normalize it back to a pair
+             if isinstance(active_bins, int):
+                 active_bins = (active_bins, active_bins)
+             # filters with no active bins contribute no band
+             if len(active_bins) == 0:
+                 continue
+             start_index = active_bins[0]
+             end_index = active_bins[-1] + 1
+             band_specs.append((start_index, end_index))
+             freq_weights.append(normalized_mel_fb[i, start_index:end_index])
+
+         self.freq_weights = freq_weights
+         self.band_specs = band_specs
+
+     def get_band_specs(self):
+         return self.band_specs
+
+     def get_freq_weights(self):
+         return self.freq_weights
+
+     def save_to_file(self, dir_path: str) -> None:
+         os.makedirs(dir_path, exist_ok=True)
+
+         import pickle
+
+         with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
+             pickle.dump(
+                 {
+                     "band_specs": self.band_specs,
+                     "freq_weights": self.freq_weights,
+                     "filterbank": self.filterbank,
+                 },
+                 f,
+             )
+
+
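+ # Illustrative sketch (filename and keys as written by save_to_file above): a
+ # saved spec can be restored with
+ #     with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "rb") as f:
+ #         spec = pickle.load(f)  # dict with band_specs / freq_weights / filterbank
+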
+ def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
+     fb = taF.melscale_fbanks(
+         n_mels=n_bands,
+         sample_rate=fs,
+         f_min=f_min,
+         f_max=f_max,
+         n_freqs=n_freqs,
+     ).T
+
+     # the mel filters leave the DC bin with zero weight; give it to the first
+     # band so every frequency bin is covered
+     fb[0, 0] = 1.0
+
+     return fb
+
+
+ class MelBandsplitSpecification(PerceptualBandsplitSpecification):
+     def __init__(
+         self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+     ) -> None:
+         super().__init__(
+             fbank_fn=mel_filterbank,
+             nfft=nfft,
+             fs=fs,
+             n_bands=n_bands,
+             f_min=f_min,
+             f_max=f_max,
+         )
+
+
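+ # Illustrative sketch (parameter values assumed): overlapping mel bands with
+ # per-bin recombination weights, e.g.
+ #     mel = MelBandsplitSpecification(nfft=2048, fs=44100, n_bands=64)
+ #     specs, weights = mel.get_band_specs(), mel.get_freq_weights()
+ # neighboring specs overlap, and at each covered bin the normalized weights
+ # across bands sum to 1.
+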
+ def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
+     # `scale` is currently unused; kept for parity with the other filterbank helpers
+     nfft = 2 * (n_freqs - 1)
+     df = fs / nfft
+     # init freqs; f_min is pinned to the first non-DC bin so the octave count
+     # below stays finite (any incoming f_min is ignored)
+     f_max = f_max or fs / 2
+     f_min = fs / nfft
+
+     n_octaves = np.log2(f_max / f_min)
+     n_octaves_per_band = n_octaves / n_bands
+     bandwidth_mult = np.power(2.0, n_octaves_per_band)
+
+     low_midi = max(0, hz_to_midi(f_min))
+     high_midi = hz_to_midi(f_max)
+     midi_points = np.linspace(low_midi, high_midi, n_bands)
+     hz_pts = midi_to_hz(midi_points)
+
+     low_pts = hz_pts / bandwidth_mult
+     high_pts = hz_pts * bandwidth_mult
+
+     low_bins = np.floor(low_pts / df).astype(int)
+     high_bins = np.ceil(high_pts / df).astype(int)
+
+     fb = np.zeros((n_bands, n_freqs))
+
+     for i in range(n_bands):
+         fb[i, low_bins[i] : high_bins[i] + 1] = 1.0
+
+     # extend the outermost bands so the full spectrum is covered
+     fb[0, : low_bins[0]] = 1.0
+     fb[-1, high_bins[-1] + 1 :] = 1.0
+
+     return torch.as_tensor(fb)
+
+
+ class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
+     def __init__(
+         self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+     ) -> None:
+         super().__init__(
+             fbank_fn=musical_filterbank,
+             nfft=nfft,
+             fs=fs,
+             n_bands=n_bands,
+             f_min=f_min,
+             f_max=f_max,
+         )
+
+
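+ # Illustrative sketch (parameter values assumed): the musical split places band
+ # centers uniformly in MIDI (log-frequency) space, so
+ #     MusicalBandsplitSpecification(nfft=2048, fs=44100, n_bands=64).get_band_specs()
+ # yields constant-Q-like bands of fixed width in octaves: narrow in Hz at the
+ # bottom of the spectrum and wide near Nyquist.
+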
+ # def bark_filterbank(
+ #     n_bands, fs, f_min, f_max, n_freqs
+ # ):
+ #     nfft = 2 * (n_freqs - 1)
+ #     fb, _ = bark_fbanks.bark_filter_banks(
+ #         nfilts=n_bands,
+ #         nfft=nfft,
+ #         fs=fs,
+ #         low_freq=f_min,
+ #         high_freq=f_max,
+ #         scale="constant"
+ #     )
+
+ #     return torch.as_tensor(fb)
+
+ # class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
+ #     def __init__(
+ #         self,
+ #         nfft: int,
+ #         fs: int,
+ #         n_bands: int,
+ #         f_min: float = 0.0,
+ #         f_max: float = None
+ #     ) -> None:
+ #         super().__init__(fbank_fn=bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+
+
+ # def triangular_bark_filterbank(
+ #     n_bands, fs, f_min, f_max, n_freqs
+ # ):
+
+ #     all_freqs = torch.linspace(0, fs // 2, n_freqs)
+
+ #     # calculate bark freq bins
+ #     m_min = hz2bark(f_min)
+ #     m_max = hz2bark(f_max)
+
+ #     m_pts = torch.linspace(m_min, m_max, n_bands + 2)
+ #     f_pts = 600 * torch.sinh(m_pts / 6)
+
+ #     # create filterbank
+ #     fb = _create_triangular_filterbank(all_freqs, f_pts)
+
+ #     fb = fb.T
+
+ #     first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
+ #     first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
+
+ #     fb[first_active_band, :first_active_bin] = 1.0
+
+ #     return fb
+
+ # class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
+ #     def __init__(
+ #         self,
+ #         nfft: int,
+ #         fs: int,
+ #         n_bands: int,
+ #         f_min: float = 0.0,
+ #         f_max: float = None
+ #     ) -> None:
+ #         super().__init__(fbank_fn=triangular_bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+
+
+ # def minibark_filterbank(
+ #     n_bands, fs, f_min, f_max, n_freqs
+ # ):
+ #     fb = bark_filterbank(
+ #         n_bands,
+ #         fs,
+ #         f_min,
+ #         f_max,
+ #         n_freqs
+ #     )
+
+ #     fb[fb < np.sqrt(0.5)] = 0.0
+
+ #     return fb
+
+ # class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
+ #     def __init__(
+ #         self,
+ #         nfft: int,
+ #         fs: int,
+ #         n_bands: int,
+ #         f_min: float = 0.0,
+ #         f_max: float = None
+ #     ) -> None:
+ #         super().__init__(fbank_fn=minibark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+
+
+ # def erb_filterbank(
+ #     n_bands: int,
+ #     fs: int,
+ #     f_min: float,
+ #     f_max: float,
+ #     n_freqs: int,
+ # ) -> Tensor:
+ #     # freq bins
+ #     A = (1000 * np.log(10)) / (24.7 * 4.37)
+ #     all_freqs = torch.linspace(0, fs // 2, n_freqs)
+
+ #     # calculate erb freq bins
+ #     m_min = hz2erb(f_min)
+ #     m_max = hz2erb(f_max)
+
+ #     m_pts = torch.linspace(m_min, m_max, n_bands + 2)
+ #     f_pts = (torch.pow(10, (m_pts / A)) - 1) / 0.00437
+
+ #     # create filterbank
+ #     fb = _create_triangular_filterbank(all_freqs, f_pts)
+
+ #     fb = fb.T
+
+
+ #     first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
+ #     first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
+
+ #     fb[first_active_band, :first_active_bin] = 1.0
+
+ #     return fb
+
+
+ # class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
+ #     def __init__(
+ #         self,
+ #         nfft: int,
+ #         fs: int,
+ #         n_bands: int,
+ #         f_min: float = 0.0,
+ #         f_max: float = None
+ #     ) -> None:
+ #         super().__init__(fbank_fn=erb_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+
+ if __name__ == "__main__":
+     import pandas as pd
+
+     band_defs = []
+
+     for bands in [VocalBandsplitSpecification]:
+         band_name = bands.__name__.replace("BandsplitSpecification", "")
+
+         mbs = bands(nfft=2048, fs=44100).get_band_specs()
+
+         # note: the band specs are FFT bin indices, not frequencies in Hz,
+         # despite the f_min/f_max column names
+         for i, (f_min, f_max) in enumerate(mbs):
+             band_defs.append(
+                 {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
+             )
+
+     df = pd.DataFrame(band_defs)
+     df.to_csv("vox7bands.csv", index=False)
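+
+
+ # Quick sanity check (an illustrative sketch, parameter values assumed, separate
+ # from the CSV export above): the mel split should cover the full bin range
+ # [0, nfft // 2 + 1), with neighboring bands overlapping.
+ if __name__ == "__main__":
+     mel = MelBandsplitSpecification(nfft=2048, fs=44100, n_bands=64)
+     specs = mel.get_band_specs()
+     print(f"{len(specs)} mel bands; first: {specs[0]}, last: {specs[-1]}")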