diff --git a/mvsepless/__init__.py b/mvsepless/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb0f890adfeb87d73fd58f17cb259619b24617ce
--- /dev/null
+++ b/mvsepless/__init__.py
@@ -0,0 +1,1579 @@
+import os
+import sys
+import shutil
+import logging
+import zipfile
+import importlib.util
+from pathlib import Path
+logging.basicConfig(level=logging.WARNING)
+
+script_dir = os.path.dirname(os.path.abspath(__file__))
+os.chdir(script_dir)
+
+if not __package__:
+ from model_manager import MvseplessModelManager
+ from audio import Audio, Inverter
+ from namer import Namer
+ from vbach_infer import vbach_inference, model_manager as VbachModel
+ from downloader import dw_file, dw_yt_dlp
+ from ensemble import ensemble_audio_files
+else:
+ from .model_manager import MvseplessModelManager
+ from .audio import Audio, Inverter
+ from .namer import Namer
+ from .vbach_infer import vbach_inference, model_manager as VbachModel
+ from .downloader import dw_file, dw_yt_dlp
+ from .ensemble import ensemble_audio_files
+from typing import Any, Literal
+import gradio as gr
+import pandas as pd
+import subprocess
+import json
+import threading
+import queue
+import time
+import argparse
+from datetime import datetime
+import tempfile
+import ast
+
+class MVSEPLESS:
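+    """Shared backend helpers (audio I/O, inversion, file naming and model managers) reused by every tab."""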
+ audio = Audio()
+ inverter = Inverter()
+ namer = Namer()
+ model_manager = MvseplessModelManager()
+ vbach_model_manager = VbachModel
+
+class Separator(MVSEPLESS):
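+    """Source separation: wraps the ``python -m infer`` subprocess and provides the separation tab UI."""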
+
+ class OutputReader:
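+        """Parses the JSON progress lines ("reading", "processing", "writing", "done", "error") printed by the infer subprocess."""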
+ def __init__(self, debug=False):
+ self.debug = debug
+
+ def parse_json_line(self, line):
+ try:
+ return json.loads(line)
+ except json.JSONDecodeError:
+ return None
+
+ def reaction_line(self, line, progress, add_text):
+ _add_text = ""
+ if add_text != "" or add_text is not None:
+ _add_text = f"| {add_text}"
+
+ data = self.parse_json_line(line)
+ if data is None:
+ return None
+ elif "reading" in data:
+ progress(0.05, desc=f"Чтение файла {_add_text}")
+ print("Чтение файла")
+ return None
+ elif "processing" in data:
+ progress_a = data["processing"]
+ processed = progress_a.get("processed", 0)
+ total = progress_a.get("total", 1)
+                # Guard against division by zero
+ if total > 0:
+                    progress_ratio = min(0.89, 0.05 + (processed / total * 0.85))  # leave headroom for the writing stage
+ percent = int((processed / total) * 100)
+ progress(progress_ratio, desc=f"Обработано: {percent}% {_add_text}")
+ print(f"\rОбработано: {percent}%", end="")
+ return None
+ elif "writing" in data:
+ progress(0.9, desc="Запись результатов")
+ print(f"\rЗапись в файл {data['writing']}", end="")
+ return None
+ elif "done" in data:
+ progress(1.0, desc=f"Завершено {_add_text}")
+ print("\rЗавершено", end="\n")
+ return data["done"]
+ elif "error" in data:
+ raise Exception(data["error"])
+
+ def read_stream_to_queue(self, stream, queue_obj, stream_name):
+ """Чтение потока вывода подпроцесса и запись в очередь"""
+ try:
+ for line in iter(stream.readline, ''):
+ line = line.strip()
+ if line:
+ if self.debug:
+ print(f"[{stream_name}] {line}") # Отладочный вывод
+ queue_obj.put(line)
+ stream.close()
+ except Exception as e:
+ print(f"Error reading {stream_name}: {e}")
+
+ output_reader = OutputReader()
+
+    def separator_model_loader(self, model_type: str, model_name: str, mdx_denoise: bool, vr_aggr: int, progress) -> tuple[int, str, str]:
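+        """Resolve the model id, download its config and checkpoint into the cache and apply config tweaks (MDX denoise, VR aggressiveness)."""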
+
+ if model_type in [
+ "mel_band_roformer",
+ "bs_roformer",
+ "mdx23c",
+ "mdxnet",
+ "vr",
+ "scnet",
+ "htdemucs",
+ "bandit",
+ "bandit_v2",
+ ]:
+ info = self.model_manager.models_info[model_type].get(model_name, None)
+ if not info:
+ raise ValueError(f"Модель {model_name} не найдена для типа {model_type}")
+
+ id = self.model_manager.get_id(model_type, model_name)
+ conf, ckpt = self.model_manager.download_model(
+ self.model_manager.models_cache_dir,
+ model_name,
+ model_type,
+ info["checkpoint_url"],
+ info["config_url"],
+ )
+ if model_type != "htdemucs":
+ self.model_manager.conf_editor(conf, mdx_denoise, vr_aggr, model_type)
+
+ return id, conf, ckpt
+
+ else:
+ raise ValueError("Неподдерживаемый тип модели")
+
+ def separator_base(
+ self,
+ input_file: str,
+ output_dir: str,
+ model_type: Literal[
+ "mel_band_roformer",
+ "bs_roformer",
+ "mdx23c",
+ "mdxnet",
+ "scnet",
+ "htdemucs",
+ "bandit",
+ "bandit_v2",
+ ] = "mel_band_roformer",
+ model_name: str = "Mel-Band-Roformer_Vocals_kimberley_jensen",
+ ext_inst: bool = True,
+ output_format: Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] = "mp3",
+ output_bitrate: str = "320k",
+ template: str = "NAME_(STEM)_MODEL",
+ selected_stems: list = None,
+ ckpt: str = None,
+ conf: str = None,
+ id: int = None,
+        progress: Any = gr.Progress(track_tqdm=True),
+ add_text_progress: str = ""
+ ) -> list[tuple[str, str]]:
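+        """Run a single separation in a ``python -m infer`` subprocess and relay its JSON progress to the Gradio progress bar."""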
+
+ if model_type in [
+ "mel_band_roformer",
+ "bs_roformer",
+ "mdx23c",
+ "mdxnet",
+ "vr",
+ "scnet",
+ "htdemucs",
+ "bandit",
+ "bandit_v2",
+ ]:
+
+ cmd = [
+            sys.executable,
+ "-m",
+ "infer",
+ "--input", input_file,
+ "--store_dir", output_dir,
+ "--model_type", model_type,
+ "--model_name", model_name,
+ "--model_id", str(id),
+ "--config_path", conf,
+ "--start_check_point", ckpt,
+ "--output_format", output_format,
+ "--output_bitrate", str(output_bitrate),
+ "--template", template
+ ]
+ if ext_inst:
+ cmd.append("--extract_instrumental")
+ if selected_stems:
+ cmd.append("--selected_instruments")
+ cmd.extend(selected_stems)
+
+ try:
+ process = subprocess.Popen(
+ cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True,
+ bufsize=1,
+ universal_newlines=True,
+ encoding='utf-8',
+ errors='replace'
+ )
+
+                # Queues for stdout and stderr lines
+ stdout_queue = queue.Queue()
+ stderr_queue = queue.Queue()
+
+                # Reader threads that pump stdout/stderr into the queues
+ stdout_thread = threading.Thread(
+ target=self.output_reader.read_stream_to_queue,
+ args=(process.stdout, stdout_queue, "stdout")
+ )
+ stderr_thread = threading.Thread(
+ target=self.output_reader.read_stream_to_queue,
+ args=(process.stderr, stderr_queue, "stderr")
+ )
+
+ stdout_thread.daemon = True
+ stderr_thread.daemon = True
+
+ stdout_thread.start()
+ stderr_thread.start()
+
+ results = {'output': None, 'error': None}
+ process_completed = False
+
+                # Main message-processing loop
+ while not process_completed:
+                    # Check whether the subprocess has exited
+ if process.poll() is not None:
+ process_completed = True
+
+                    # Handle stdout messages
+ try:
+ stdout_line = stdout_queue.get_nowait()
+ result = self.output_reader.reaction_line(stdout_line, progress, add_text_progress)
+ if result is not None:
+ results['output'] = result
+ break
+ except queue.Empty:
+ pass
+
+                    # Handle stderr messages
+ try:
+ stderr_line = stderr_queue.get_nowait()
+ result = self.output_reader.reaction_line(stderr_line, progress, add_text_progress)
+ if result is not None:
+ results['output'] = result
+ break
+ except queue.Empty:
+ pass
+
+                    # Process still running: wait briefly before the next poll
+ if not process_completed:
+ time.sleep(0.1)
+
+                # Drain any messages still queued after the process has exited
+                for _ in range(10):  # poll a few times in case output is still arriving
+ try:
+ stdout_line = stdout_queue.get_nowait()
+ result = self.output_reader.reaction_line(stdout_line, progress, add_text_progress)
+ if result is not None:
+ results['output'] = result
+ break
+ except queue.Empty:
+ pass
+
+ try:
+ stderr_line = stderr_queue.get_nowait()
+ result = self.output_reader.reaction_line(stderr_line, progress, add_text_progress)
+ if result is not None:
+ results['output'] = result
+ break
+ except queue.Empty:
+ pass
+
+ time.sleep(0.1)
+
+                # Evaluate the results
+ if results.get('error'):
+ raise Exception(results['error'])
+
+ if results.get('output'):
+ return results['output']
+
+                # The subprocess exited with a non-zero return code
+ if process.returncode != 0:
+                    # Collect the most recent error messages
+ error_messages = []
+ try:
+ while True:
+ error_msg = stderr_queue.get_nowait()
+ error_messages.append(error_msg)
+ except queue.Empty:
+ pass
+
+ error_text = "\n".join(error_messages[-5:]) # Последние 5 сообщений
+ raise Exception(f"Процесс завершился с ошибкой. Код возврата: {process.returncode}. Сообщения об ошибках:\n{error_text}")
+
+ except Exception as e:
+ raise e
+ finally:
+                # Make sure the subprocess is terminated
+ try:
+ if process.poll() is None:
+ process.terminate()
+ process.wait(timeout=5)
+                except Exception:
+                    try:
+                        process.kill()
+                    except Exception:
+                        pass
+ else:
+ raise ValueError("Неподдерживаемый тип модели")
+
+ def separate(
+ self,
+ input: str | list = None,
+ output_dir: str = None,
+ model_type: Literal[
+ "mel_band_roformer",
+ "bs_roformer",
+ "mdx23c",
+ "mdxnet",
+ "vr",
+ "scnet",
+ "htdemucs",
+ "bandit",
+ "bandit_v2",
+ ] = "mel_band_roformer",
+ model_name: str = "Mel-Band-Roformer_Vocals_kimberley_jensen",
+ ext_inst: bool = True,
+ output_format: Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] = "mp3",
+ output_bitrate: str = "320k",
+ template: str = "NAME_(STEM)_MODEL",
+ selected_stems: list = None,
+ add_settings: dict = {"mdx_denoise": False, "vr_aggr": 5, "add_single_sep_text_progress": None},
+        progress: Any = gr.Progress(track_tqdm=True)
+    ) -> list[tuple[str, str]] | list[tuple[str, list[tuple[str, str]]]]:
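+        """Validate the inputs, resolve the model and run the separation.
+
+        A single input path returns the stem list directly; a list of paths returns
+        (basename, stems) entries, one per file. Illustrative call (paths are placeholders):
+
+            Separator().separate(input="song.mp3", output_dir="out")
+        """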
+
+ progress(0, desc="Начало обработки")
+
+        # Parameter validation
+ if output_format not in self.audio.output_formats:
+ output_format = "flac"
+
+ if output_dir is None:
+ output_dir = os.getcwd()
+
+ if output_dir:
+ output_dir = os.path.abspath(output_dir)
+
+ if selected_stems is None:
+ selected_stems = []
+
+ if not input:
+ raise ValueError("Входной файл не указан")
+
+ if "STEM" not in template and template is not None:
+ template = template + "_STEM_"
+ if not template:
+ template = "mvsepless_NAME_(STEM)"
+
+ os.makedirs(output_dir, exist_ok=True)
+
+ mdx_denoise = add_settings.get("mdx_denoise", False)
+
+ vr_aggr = add_settings.get("vr_aggr", 5)
+
+ add_progress_text_custom = add_settings.get("add_single_sep_text_progress", "")
+
+ id, conf, ckpt = self.separator_model_loader(model_type, model_name, mdx_denoise, vr_aggr, progress)
+
+ if isinstance(input, str):
+ if not os.path.exists(input):
+ raise ValueError(f"Входной файл не найден: {input}")
+
+ if not self.audio.check(input):
+ raise ValueError("Входной файл не содержит аудио")
+
+ basename = os.path.splitext(os.path.basename(input))[0]
+ seped = self.separator_base(input_file=input,
+ output_dir=output_dir,
+ model_type=model_type,
+ model_name=model_name,
+ ext_inst=ext_inst,
+ output_format=output_format,
+ output_bitrate=output_bitrate,
+ template=template,
+ selected_stems=selected_stems,
+ ckpt=ckpt,
+ conf=conf,
+ id=id,
+ progress=progress,
+ add_text_progress=add_progress_text_custom)
+ return seped
+
+ elif isinstance(input, list):
+ results = []
+ for i, f in enumerate(input, 1):
+ print(f"Файл {i} из {len(input)}: {f}")
+ if os.path.exists(f):
+ if self.audio.check(f):
+ basename = os.path.splitext(os.path.basename(f))[0]
+ seped = self.separator_base(input_file=f,
+ output_dir=output_dir,
+ model_type=model_type,
+ model_name=model_name,
+ ext_inst=ext_inst,
+ output_format=output_format,
+ output_bitrate=output_bitrate,
+ template=template,
+ selected_stems=selected_stems,
+ ckpt=ckpt,
+ conf=conf,
+ id=id,
+ progress=progress,
+ add_text_progress=f"({i} из {len(input)}) ")
+                    results.append((basename, seped))
+ return results
+
+
+ def UI(self, output_base_dir, output_temp_dir_check):
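+        """Build the separation tab; output_base_dir and output_temp_dir_check are gr.State objects shared from the settings tab."""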
+ default_mts = self.model_manager.get_mt()
+ default_mt = self.model_manager.get_mt()[0]
+ default_mns = self.model_manager.get_mn(default_mt)
+ default_mn = default_mns[0]
+ default_stems = self.model_manager.get_stems(default_mt, default_mn)
+ default_tgt_inst = self.model_manager.get_tgt_inst(default_mt, default_mn)
+ with gr.Blocks():
+ with gr.Row():
+ with gr.Column():
+ with gr.Group(visible=False) as add_inputs:
+ input_path = gr.Textbox(label="Путь к входному файлу", interactive=True)
+ add_inputs_btn = gr.Button("Добавить файл", variant="primary")
+ with gr.Group(visible=False) as add_inputs_from_url:
+ input_url = gr.Textbox(label="URL входного файла", interactive=True)
+ with gr.Row():
+ inputs_url_format = gr.Dropdown(label="Формат входного файла", interactive=True,
+ choices=self.audio.output_formats,
+ value="mp3", filterable=False)
+ inputs_url_bitrate = gr.Slider(label="Битрейт входного файла", minimum=64, maximum=512, step=32, value=320, interactive=True)
+ with gr.Row():
+ inputs_url_cookie = gr.UploadButton(label="Файл cookie (необязательно)", interactive=True, type="filepath", file_count="single", file_types=[".txt", ".cookies"], variant="secondary")
+ add_inputs_url_btn = gr.Button("Добавить файл", variant="primary")
+ with gr.Row(visible=True) as add_buttons_row:
+ add_path_btn = gr.Button("Добавить файл по пути", variant="secondary")
+ add_url_btn = gr.Button("Добавить файл по URL", variant="secondary")
+ with gr.Group():
+ input_audio = gr.File(label="Входные аудио", interactive=True, type="filepath", file_count="multiple", file_types=[f".{of}" for of in self.audio.input_formats])
+ sep_state = gr.Textbox(label="Состояние разделения", interactive=False, value="", visible=False)
+ status = gr.Textbox(container=False, lines=3, interactive=False, max_lines=3)
+ input_preview_check = gr.Checkbox(label="Показать плееры для входных аудио", interactive=True, value=False)
+ @gr.render(inputs=[input_preview_check, input_audio])
+ def show_input_players(preview, audios):
+ if preview:
+ if audios:
+ with gr.Group():
+ for file in audios:
+ gr.Audio(label=os.path.splitext(os.path.basename(file))[0], value=file, interactive=False, show_download_button=False, type="filepath")
+ with gr.Column():
+ with gr.Group():
+ with gr.Row():
+ model_type = gr.Dropdown(label="Тип модели", interactive=True, filterable=False,
+ choices=default_mts,
+ value=default_mt)
+ model_name = gr.Dropdown(label="Имя модели", interactive=True, filterable=False,
+ choices=default_mns, value=default_mn)
+ with gr.Group():
+ extract_instrumental = gr.Checkbox(label="Извлечь инструментал", interactive=True, value=True)
+ selected_stems = gr.CheckboxGroup(label="Выбранные стемы для разделения", interactive=False,
+ choices=default_stems, value=[])
+
+ with gr.Accordion(label="Дополнительные настройки", open=False):
+ vr_aggr_slider = gr.Slider(label="Сила подавления для VR моделей", minimum=-100, maximum=100, value=5, step=1)
+ mdx_denoise_check = gr.Checkbox(label="Включить шумоподавление для MDX-NET моделей (это повышает потребление памяти в два раза)", value=False)
+ with gr.Row():
+ output_format = gr.Dropdown(label="Формат выходного файла", interactive=True,
+ choices=self.audio.output_formats,
+ value="mp3", filterable=False)
+ output_bitrate = gr.Slider(label="Битрейт выходного файла", minimum=64, maximum=512, step=32, value=320, interactive=True)
+
+ template = gr.Textbox(label="Шаблон именования выходных файлов", interactive=True, value="NAME (STEM) MODEL", info="Используйте ключи: \nNAME - имя входного файла без расширения, \nSTEM - имя стема, \nMODEL - имя модели разделения")
+ separate_btn = gr.Button("Разделить", variant="primary")
+
+ @gr.render(inputs=[sep_state], triggers=[sep_state.change])
+ def players(state):
+ def create_archive_advanced(file_list, archive_name="archive.zip"):
+ """
+ Создает архив с расширенной обработкой ошибок
+ """
+ try:
+ print("Генерация ZIP-архива с результатами разделения...")
+ with zipfile.ZipFile(archive_name, 'w', zipfile.ZIP_DEFLATED) as zipf:
+ successful_files = 0
+
+ for basename, stems in file_list:
+ for stem_name, stem_path in stems:
+ try:
+ if os.path.exists(stem_path) and os.path.isfile(stem_path):
+ basename_ = os.path.basename(stem_path)
+ zipf.write(stem_path, basename_)
+ successful_files += 1
+ print(f"✓ Добавлен: {stem_path} -> {basename}")
+ else:
+ print(f"✗ Файл не найден или не является файлом: {stem_path}")
+
+ except Exception as e:
+ print(f"✗ Ошибка при добавлении {stem_path}: {e}")
+
+ print(f"\nАрхив создан: {archive_name}")
+ print(f"Успешно добавлено файлов: {successful_files}")
+ return os.path.abspath(archive_name)
+
+ except Exception as e:
+ print(f"Ошибка при создании архива: {e}")
+ if state != "":
+ state_loaded = ast.literal_eval(state)
+ archive_stems = create_archive_advanced(state_loaded, os.path.join(tempfile.tempdir, f"mvsepless_output_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip"))
+ for basename, stems in state_loaded:
+ with gr.Group():
+ gr.Markdown(f"
{basename}
")
+ for stem_name, stem_path in stems:
+ with gr.Row(equal_height=True):
+ output_stem = gr.Audio(value=stem_path, label=stem_name, type="filepath", interactive=False, show_download_button=True, scale=15)
+ reuse_btn = gr.Button("Использовать снова", variant="secondary")
+ @reuse_btn.click(
+ inputs=[output_stem, input_audio],
+ outputs=input_audio
+ )
+ def reuse_fn(stem_audio, input_a):
+ if input_a is None:
+ input_a = []
+ if isinstance(input_a, str):
+ input_a = [input_a]
+ if os.path.exists(stem_audio):
+ if self.audio.check(stem_audio):
+ input_a.append(stem_audio)
+ return input_a
+
+ gr.DownloadButton(label="Скачать как ZIP", value=archive_stems, interactive=True)
+
+ @add_inputs_btn.click(
+ inputs=[input_path, input_audio],
+ outputs=[add_inputs, input_audio, add_buttons_row])
+ def add_inputs_fn(input_p, input_a):
+ if input_p and os.path.exists(input_p):
+ if input_a is None:
+ input_a = []
+ if isinstance(input_a, str):
+ input_a = [input_a]
+ if self.audio.check(input_p):
+ input_a.append(input_p)
+ return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+ return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+
+ @add_inputs_url_btn.click(
+ inputs=[input_url, input_audio, inputs_url_format, inputs_url_bitrate, inputs_url_cookie],
+ outputs=[add_inputs_from_url, input_audio, add_buttons_row])
+ def add_inputs_from_url_fn(input_u, input_a, fmt, br, cookie):
+ if input_u:
+ if input_a is None:
+ input_a = []
+ if isinstance(input_a, str):
+ input_a = [input_a]
+ downloaded_file = dw_yt_dlp(
+ url=input_u,
+ output_format=fmt,
+ output_bitrate=str(int(br)),
+ cookie=cookie
+ )
+ if downloaded_file and os.path.exists(downloaded_file):
+ if self.audio.check(downloaded_file):
+ input_a.append(downloaded_file)
+ return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+ return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+
+ add_path_btn.click(
+ lambda: (gr.update(visible=True), gr.update(visible=False)),
+ outputs=[add_inputs, add_buttons_row])
+
+ add_url_btn.click(
+ lambda: (gr.update(visible=True), gr.update(visible=False)),
+ outputs=[add_inputs_from_url, add_buttons_row])
+
+ model_type.change(lambda x: gr.update(choices=self.model_manager.get_mn(x), value=self.model_manager.get_mn(x)[0]),
+ inputs=model_type, outputs=model_name)
+ model_name.change(lambda mt, mn: (gr.update(choices=self.model_manager.get_stems(mt, mn), value=[], interactive=False if self.model_manager.get_tgt_inst(mt, mn) else True), gr.update(value=True if self.model_manager.get_tgt_inst(mt, mn) else False)), inputs=[model_type, model_name], outputs=[selected_stems, extract_instrumental])
+ output_format.change(lambda x: gr.update(visible=False if x in ["wav", "flac", "aiff"] else True), inputs=output_format, outputs=output_bitrate)
+ inputs_url_format.change(lambda x: gr.update(visible=False if x in ["wav", "flac", "aiff"] else True), inputs=inputs_url_format, outputs=inputs_url_bitrate)
+ @separate_btn.click(
+ inputs=[
+ input_audio, model_type, model_name,
+ extract_instrumental, output_format, output_bitrate,
+ template, selected_stems, output_base_dir, output_temp_dir_check, mdx_denoise_check, vr_aggr_slider
+ ],
+ outputs=[sep_state, status],
+ show_progress="full"
+ )
+ def wrap(i, mt, mn, ei, of, ob, t, stems, o_dir, temp_save, mdx_denoise, vr_aggr, progress=gr.Progress(track_tqdm=True)):
+ if o_dir.strip() != "" and not temp_save:
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ o = os.path.join(o_dir, f"mvsepless_outputs_{timestamp}")
+ os.makedirs(o, exist_ok=True)
+ else:
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ o = tempfile.mkdtemp(prefix=f"mvsepless_outputs_{timestamp}_")
+ os.makedirs(o, exist_ok=True)
+ results = self.separate(i, o, mt, mn, ei, of, ob, t, stems, add_settings={"mdx_denoise": mdx_denoise, "vr_aggr": int(vr_aggr)}, progress=progress)
+ return str(results), ""
+
+class AutoEnsembless(Separator):
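+    """Auto-ensemble: separates the input with several models and blends the selected stems via ensemble_audio_files."""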
+
+ class ModelManager(MVSEPLESS):
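+        """Keeps the ensemble recipe as (model_type, model_name, primary_stem, invert_stem, weight) tuples and saves/loads it as JSON presets."""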
+ def __init__(self):
+ self.data = []
+ self.ensemble_methods = ("min_fft", "max_fft", "avg_fft", "median_fft")
+ self.ensemble_invert_methods_map = {"min_fft": "max_fft", "max_fft": "min_fft", "avg_fft": "avg_fft", "median_fft": "median_fft"}
+            self.dir_presets = os.path.join(tempfile.gettempdir(), "presets")
+ os.makedirs(self.dir_presets, exist_ok=True)
+
+ def save(self, name):
+ if not name:
+ name = "ensembless_preset"
+ filepath = os.path.join(self.dir_presets, f"{self.namer.short(self.namer.sanitize(name), length=50)}.json")
+ with open(filepath, "w") as f:
+ json.dump(self.data, f, indent=4, ensure_ascii=False)
+ return filepath
+
+ def load(self, filepath):
+ with open(filepath, "r") as f:
+ ensemble_data_temp = json.load(f)
+ self.data = []
+ for (mt, mn, s_stem, i_stem, weight) in ensemble_data_temp:
+ if {mt, mn} not in [{model[0], model[1]} for model in self.data]:
+ self.data.append((mt, mn, s_stem, i_stem, weight))
+
+ def add(self, mt, mn, s_stem, i_stem, weight):
+ if {mt, mn} not in [{model[0], model[1]} for model in self.data]:
+ if s_stem and i_stem:
+ self.data.append((mt, mn, s_stem, i_stem, weight))
+
+ def replace(self, mt, mn, s_stem, i_stem, weight, index=1):
+ if self.data:
+ len_data = len(self.data)
+ if index >= 1:
+ if index <= len_data:
+ self.data[index - 1] = (mt, mn, s_stem, i_stem, weight)
+ elif index == 0:
+ self.data[0] = (mt, mn, s_stem, i_stem, weight)
+
+ def remove(self, index=1):
+ if self.data:
+ len_data = len(self.data)
+ if index >= 1:
+ if index <= len_data:
+ del self.data[index - 1]
+ elif index == 0:
+ del self.data[0]
+
+ def clear(self):
+ self.data = []
+
+ def get_df(self):
+ if not self.data:
+ columns = ["#", "Имя модели", "Основной стем", "Инверсия", "Вес"]
+ return pd.DataFrame(columns=columns)
+
+ data = []
+ for i, model in enumerate(self.data):
+ data.append(
+ [
+ f"{i+1}",
+ model[1],
+ model[2],
+ model[3],
+ model[4],
+ ]
+ )
+ columns = ["#", "Имя модели", "Основной стем", "Инверсия", "Вес"]
+ return pd.DataFrame(data, columns=columns)
+
+ def UI(self, output_base_dir, output_temp_dir_check):
+ ensemble_model_manager = self.ModelManager()
+ def get_stems(mt, mn):
+ stems = []
+ for stem in self.model_manager.get_stems(mt, mn):
+ stems.append(stem)
+
+ if not self.model_manager.get_tgt_inst(mt, mn):
+ if set(stems) == {"bass", "drums", "other", "vocals"} or set(stems) == {"bass", "drums", "other", "vocals", "piano", "guitar"}:
+ stems.append("instrumental +")
+ stems.append("instrumental -")
+
+ return stems
+
+ def get_invert_stems(mt, mn, s_stem):
+ orig_stems = []
+ stems = []
+ for stem in self.model_manager.get_stems(mt, mn):
+ orig_stems.append(stem)
+
+ for stem in orig_stems:
+ if stem != s_stem:
+ stems.append(stem)
+
+ if not self.model_manager.get_tgt_inst(mt, mn):
+ if len(orig_stems) > 2:
+ if s_stem not in ["instrumental +", "instrumental -"]:
+ stems.append("inverted +")
+ stems.append("inverted -")
+
+ return stems
+
+ default_model = {
+ "mt": self.model_manager.get_mt(),
+ "mn": self.model_manager.get_mn(self.model_manager.get_mt()[0]),
+ "stem": get_stems(
+ self.model_manager.get_mt()[0],
+ self.model_manager.get_mn(self.model_manager.get_mt()[0])[0],
+ ),
+ "invert_stem": get_invert_stems(
+ self.model_manager.get_mt()[0],
+ self.model_manager.get_mn(self.model_manager.get_mt()[0])[0],
+ "vocals",
+ ),
+ "weight": 1,
+ }
+
+ gr.Markdown("Пресет
")
+ with gr.Group():
+ with gr.Row(equal_height=True):
+ export_preset_name = gr.Textbox(
+ label="Имя пресета",
+ interactive=True,
+ value="ensembless_preset", scale=9
+ )
+ export_btn = gr.DownloadButton("Экспорт", variant="secondary", scale=3, interactive=True)
+ import_btn = gr.UploadButton(
+ "Импорт", file_types=[".json"], file_count="single", scale=3, interactive=True
+ )
+ gr.Markdown("Ансамбль
")
+ with gr.Row():
+            with gr.Column(scale=3):  # model-adding controls
+ model_type = gr.Dropdown(label="Тип модели", choices=default_model["mt"], value=default_model["mt"][0], interactive=True, filterable=False)
+ model_name = gr.Dropdown(label="Имя модели", choices=default_model["mn"], value=default_model["mn"][0], interactive=True, filterable=False)
+ primary_stem = gr.Dropdown(label="Основной стем", choices=default_model["stem"], value=default_model["stem"][0], interactive=True, filterable=False)
+ secondary_stem = gr.Dropdown(label="Инверсия", choices=default_model["invert_stem"], value=default_model["invert_stem"][0], interactive=True, filterable=False)
+ weight = gr.Slider(label="Вес", minimum=0, maximum=10, step=0.01, value=1, interactive=True)
+ @model_type.change(
+ inputs=[model_type],
+ outputs=[model_name]
+ )
+ def update_model_names(mt):
+ model_names = self.model_manager.get_mn(mt)
+ new_mn = model_names[0] if model_names else ""
+
+ return gr.update(choices=model_names, value=new_mn)
+ @model_name.change(
+ inputs=[model_type, model_name],
+ outputs=[primary_stem, secondary_stem]
+ )
+ def update_stems_after_model_change(mt, mn):
+ stems = get_stems(mt, mn)
+ invert_stems = get_invert_stems(mt, mn, stems[0]) if stems else []
+
+ new_s_stem = stems[0] if stems else ""
+ new_i_stem = invert_stems[0] if invert_stems else ""
+
+ return (
+ gr.update(choices=stems, value=new_s_stem),
+ gr.update(choices=invert_stems, value=new_i_stem)
+ )
+ @primary_stem.change(
+ inputs=[model_type, model_name, primary_stem],
+ outputs=[secondary_stem]
+ )
+ def update_invert_stems(mt, mn, s_stem):
+ stems = get_invert_stems(mt, mn, s_stem)
+ new_i_stem = stems[0] if stems else ""
+ return gr.update(choices=stems, value=new_i_stem)
+
+ model_add_button = gr.Button("Добавить", interactive=True)
+ with gr.Column(scale=10):
+ df = gr.DataFrame(
+ value=ensemble_model_manager.get_df(),
+ headers=["#", "Имя модели", "Основной стем", "Инверсия", "Вес"],
+ datatype=["number", "str", "str", "str", "number"],
+ interactive=False
+ )
+
+ with gr.Group():
+ with gr.Row(equal_height=True):
+ with gr.Column():
+ model_index = gr.Number(label="Индекс модели", value=1, interactive=True)
+ model_clear_btn = gr.Button("Очистить", variant="stop", interactive=True)
+ with gr.Column():
+ model_replace_btn = gr.Button("Заменить", variant="primary", interactive=True)
+ model_delete_btn = gr.Button("Удалить", variant="stop", interactive=True)
+
+ @model_add_button.click(
+ inputs=[model_type, model_name, primary_stem, secondary_stem, weight],
+ outputs=df
+ )
+ def add_model_to_auto_ensemble(mt, mn, s_stem, i_stem, weight):
+ ensemble_model_manager.add(mt, mn, s_stem, i_stem, weight)
+ return ensemble_model_manager.get_df()
+
+ @model_replace_btn.click(
+ inputs=[model_type, model_name, primary_stem, secondary_stem, weight, model_index],
+ outputs=df
+ )
+ def replace_model_to_auto_ensemble(mt, mn, s_stem, i_stem, weight, index):
+ ensemble_model_manager.replace(mt, mn, s_stem, i_stem, weight, index)
+ return ensemble_model_manager.get_df()
+
+ @model_delete_btn.click(
+ inputs=[model_index],
+ outputs=df
+ )
+ def delete_model_to_auto_ensemble(index):
+ ensemble_model_manager.remove(index)
+ return ensemble_model_manager.get_df()
+
+ @model_clear_btn.click(
+ outputs=df
+ )
+ def clear_model_to_auto_ensemble():
+ ensemble_model_manager.clear()
+ return ensemble_model_manager.get_df()
+
+ gr.on(fn=ensemble_model_manager.get_df, outputs=df)
+
+ df.change(
+ fn=ensemble_model_manager.save,
+ inputs=export_preset_name,
+ outputs=export_btn
+ )
+
+ export_preset_name.change(
+ fn=ensemble_model_manager.save,
+ inputs=export_preset_name,
+ outputs=export_btn
+ )
+
+ @import_btn.upload(
+ inputs=import_btn,
+ outputs=df
+ )
+ def load_ensemble_preset(filepath):
+ ensemble_model_manager.load(filepath)
+ return ensemble_model_manager.get_df()
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown("Входное аудио
")
+ with gr.Group():
+ with gr.Group(visible=False) as add_inputs:
+ input_path = gr.Textbox(label="Путь к входному файлу", interactive=True)
+ add_inputs_btn = gr.Button("Загрузить файл", variant="primary")
+ with gr.Group(visible=False) as add_inputs_from_url:
+ input_url = gr.Textbox(label="URL входного файла", interactive=True)
+ with gr.Row():
+ inputs_url_format = gr.Dropdown(label="Формат входного файла", interactive=True,
+ choices=self.audio.output_formats,
+ value="mp3", filterable=False)
+ inputs_url_bitrate = gr.Slider(label="Битрейт входного файла", minimum=64, maximum=512, step=32, value=320, interactive=True)
+ with gr.Row():
+ inputs_url_cookie = gr.UploadButton(label="Файл cookie (необязательно)", interactive=True, type="filepath", file_count="single", file_types=[".txt", ".cookies"], variant="secondary")
+ add_inputs_url_btn = gr.Button("Загрузить файл", variant="primary")
+ with gr.Row(visible=True) as add_buttons_row:
+ add_path_btn = gr.Button("Загрузить файл по пути", variant="secondary")
+ add_url_btn = gr.Button("Загрузить файл по URL", variant="secondary")
+ with gr.Group():
+ input_audio = gr.File(label="Входное аудио", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.audio.input_formats])
+ with gr.Column():
+ gr.Markdown("Настройки
")
+ with gr.Group():
+ method = gr.Dropdown(label="Алгоритм склеивания", choices=["min_fft", "max_fft", "avg_fft", "median_fft"], value="avg_fft", filterable=False)
+ invert_ensemble = gr.Checkbox(label="Инверсия ансамбля", interactive=True, value=False)
+ output_format = gr.Dropdown(label="Формат выходного файла", interactive=True,
+ choices=self.audio.output_formats,
+ value="mp3", filterable=False)
+ run_btn = gr.Button("Создать ансамбль", variant="primary", interactive=True)
+
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown("Результаты
")
+ output_audio = gr.Audio(label="Результат", type="filepath", interactive=False, show_download_button=True)
+ output_audio_wav = gr.Textbox(label="Результат в WAV", interactive=False, visible=False)
+ with gr.Group():
+ invert_method = gr.Radio(
+ choices=["waveform", "spectrogram"],
+ label="Метод создания инверсии",
+ value="waveform",
+ )
+ invert_btn = gr.Button("Инвертировать")
+ output_inverted_audio = gr.Audio(label="Инверсия", type="filepath", interactive=False, show_download_button=True)
+ @invert_btn.click(inputs=[input_audio, output_audio_wav, invert_method, output_format], outputs=[output_inverted_audio])
+ def invert_result_ensemble(input_file, output_file, method, out_format):
+ if input_file and output_file:
+ o_dir = os.path.dirname(output_file)
+ basename = os.path.splitext(os.path.basename(input_file))[0]
+ output_path = os.path.join(o_dir, f"ensembless_{self.namer.short(basename, length=50)}_{method}_invert.{out_format}")
+ inverted = self.inverter.process_audio(audio1_path=input_file, audio2_path=output_file, out_format=out_format, method=method, output_path=output_path)
+ return inverted
+ else:
+ return None
+
+ with gr.Column():
+ gr.Markdown("Исходники ансамбля (WAV)
")
+ output_source_files = gr.Files(type="filepath", interactive=False, show_label=False)
+ output_source_preview_check = gr.Checkbox(label="Показать плееры для исходников ансамбля", interactive=True, value=False)
+ @gr.render(inputs=[output_source_preview_check, output_source_files])
+ def show_output_auto_ensemble_players(preview, audios):
+ if preview:
+ if audios:
+ with gr.Group():
+ for file in audios:
+ gr.Audio(label=os.path.splitext(os.path.basename(file))[0], value=file, interactive=False, show_download_button=False, type="filepath")
+
+ @run_btn.click(
+ inputs=[input_audio, method, output_format, invert_ensemble, output_base_dir, output_temp_dir_check],
+ outputs=[output_audio, output_audio_wav, output_inverted_audio, output_source_files]
+ )
+ def auto_ensemble_run(input_file, method, out_format, invert_ensemble, o_dir, temp_save, progress=gr.Progress(track_tqdm=True)):
+ ensemble_state = ensemble_model_manager.data
+ invert_methods_map = ensemble_model_manager.ensemble_invert_methods_map
+                if not input_file:
+                    return None, None, None, []
+                if not os.path.exists(input_file):
+                    return None, None, None, []
+                if not self.audio.check(input_file):
+                    return None, None, None, []
+                if not ensemble_state:
+                    # Nothing to do without an ensemble recipe
+                    return None, None, None, []
+ if o_dir.strip() != "" and not temp_save:
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ o = os.path.join(o_dir, f"ensembless_outputs_{timestamp}")
+ os.makedirs(o, exist_ok=True)
+ else:
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ o = tempfile.mkdtemp(prefix=f"ensembless_outputs_{timestamp}_")
+ os.makedirs(o, exist_ok=True)
+
+ basename = os.path.splitext(os.path.basename(input_file))[0]
+ def invert_weights(weights):
+ total_weight = sum(weights)
+ return [total_weight - w for w in weights]
+ # print(json.dumps(ensemble_state, indent=4, ensure_ascii=False))
+ success_separations = [] # list[tuple[str, str, float]] [(s_stem, i_stem, weight)]
+ ensemble_sources_list = [] # list[str, str, str, ...]
+ if ensemble_state:
+ total_ensemble_models = len(ensemble_state)
+ for i, model in enumerate(ensemble_state, start=1):
+
+ ens_mt = model[0]
+ ens_mn = model[1]
+ ens_s_stem = model[2]
+ ens_i_stem = model[3]
+ weight = model[4]
+
+ s_stem = None #path to primary stem
+ i_stem = None #path to invert stem
+
+ try:
+ result_seped_auto_ensemble = self.separate(input=input_file, output_dir=os.path.join(o, ens_mn), model_type=ens_mt, model_name=ens_mn, ext_inst=True, template="NAME - MODEL - STEM", output_format="wav", add_settings={"add_single_sep_text_progress": f"{i} из {total_ensemble_models}"}, progress=progress)
+ if result_seped_auto_ensemble:
+ for stem, path in result_seped_auto_ensemble:
+ ensemble_sources_list.append(path)
+ if stem == ens_s_stem:
+ s_stem = path
+ elif stem == ens_i_stem:
+ i_stem = path
+
+ if invert_ensemble:
+ if not i_stem:
+ result_seped_auto_ensemble_invert = self.separate(input=input_file, output_dir=os.path.join(o, f"{ens_mn}_invert"), model_type=ens_mt, model_name=ens_mn, ext_inst=True, template="NAME - MODEL - STEM", output_format="wav", selected_stems=[ens_s_stem], add_settings={"add_single_sep_text_progress": f"{i} из {total_ensemble_models} (инверт.)"}, progress=progress)
+ if result_seped_auto_ensemble_invert:
+                                            for stem, path in result_seped_auto_ensemble_invert:
+ if stem == ens_i_stem:
+ i_stem = path
+ ensemble_sources_list.append(path)
+
+ except Exception as e:
+ print(f"\nПроизошла ошибка при разделении: {e}")
+ progress(0, desc="Произошла ошибка при разделении, модель пропускается...")
+ continue
+ finally:
+ if s_stem:
+ success_separations.append((ens_mn, s_stem, i_stem, weight))
+
+ ensemble_sources_stems = []
+ ensemble_sources_invert_stems = []
+ weights = []
+
+ for out_mn, out_s_stem, out_i_stem, out_weight in success_separations:
+ ensemble_sources_stems.append(out_s_stem)
+ ensemble_sources_invert_stems.append(out_i_stem)
+ weights.append(out_weight)
+
+
+ auto_ensemble_invout_file = None
+ auto_ensemble_invout_file_wav = None
+
+ auto_ensemble_output_name = f"ensembless_{self.namer.short(basename, length=50)}_{len(ensemble_sources_stems)}_{method}"
+ auto_ensemble_inverted_output_name = f"ensembless_{self.namer.short(basename, length=50)}_{len(ensemble_sources_stems)}_{invert_methods_map[method]}_invert"
+ auto_ensemble_out_file, auto_ensemble_out_file_wav = ensemble_audio_files(files=ensemble_sources_stems, weights=weights, output=os.path.join(o, auto_ensemble_output_name), ensemble_type=method, out_format=out_format, add_wav=True)
+
+ if invert_ensemble:
+ auto_ensemble_invout_file, auto_ensemble_invout_file_wav = ensemble_audio_files(files=ensemble_sources_invert_stems, weights=invert_weights(weights), output=os.path.join(o, auto_ensemble_inverted_output_name), ensemble_type=invert_methods_map[method], out_format=out_format, add_wav=True)
+
+ return auto_ensemble_out_file, auto_ensemble_out_file_wav, auto_ensemble_invout_file, ensemble_sources_list
+
+ @add_inputs_btn.click(
+ inputs=[input_path, input_audio],
+ outputs=[add_inputs, input_audio, add_buttons_row])
+ def add_inputs_fn(input_p, input_a):
+ if input_p and os.path.exists(input_p):
+ if self.audio.check(input_p):
+ input_a = input_p
+ return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+ return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+
+ @add_inputs_url_btn.click(
+ inputs=[input_url, input_audio, inputs_url_format, inputs_url_bitrate, inputs_url_cookie],
+ outputs=[add_inputs_from_url, input_audio, add_buttons_row])
+ def add_inputs_from_url_fn(input_u, input_a, fmt, br, cookie):
+ if input_u:
+ downloaded_file = dw_yt_dlp(
+ url=input_u,
+ output_format=fmt,
+ output_bitrate=str(int(br)),
+ cookie=cookie
+ )
+ if downloaded_file and os.path.exists(downloaded_file):
+ if self.audio.check(downloaded_file):
+ input_a = downloaded_file
+ return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+ return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+
+ add_path_btn.click(
+ lambda: (gr.update(visible=True), gr.update(visible=False)),
+ outputs=[add_inputs, add_buttons_row])
+
+ add_url_btn.click(
+ lambda: (gr.update(visible=True), gr.update(visible=False)),
+ outputs=[add_inputs_from_url, add_buttons_row])
+
+ inputs_url_format.change(lambda x: gr.update(visible=False if x in ["wav", "flac", "aiff"] else True), inputs=inputs_url_format, outputs=inputs_url_bitrate)
+
+class ManualEnsembless(MVSEPLESS):
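+    """Manual ensemble: blends user-supplied stems after checking that their sample rates match."""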
+ def UI(self, output_base_dir, output_temp_dir_check):
+ with gr.Row():
+ with gr.Column():
+ with gr.Group(visible=False) as add_ensemble_inputs:
+ input_ensemble_path = gr.Textbox(label="Путь к входному файлу", interactive=True)
+ add_ensemble_inputs_btn = gr.Button("Добавить файл", variant="primary")
+ add_ensemble_path_btn = gr.Button("Добавить файл по пути", variant="secondary")
+ input_ensemble_files = gr.File(label="Входное аудио", interactive=True, type="filepath", file_count="multiple", file_types=[f".{of}" for of in self.audio.input_formats])
+ input_ensemble_preview_check = gr.Checkbox(label="Показать плееры для входных аудио", interactive=True, value=False)
+ @gr.render(inputs=[input_ensemble_preview_check, input_ensemble_files])
+ def show_input_ensemble_players(preview, audios):
+ if preview:
+ if audios:
+ with gr.Group():
+ for file in audios:
+ gr.Audio(label=os.path.splitext(os.path.basename(file))[0], value=file, interactive=False, show_download_button=False, type="filepath")
+
+ with gr.Column():
+ @gr.render(inputs=[input_ensemble_files])
+ def input_ensemble_files_fn(input_files):
+ check_ensemble_files_status = f"""Анализ входных файлов
+---"""
+ hz_ = []
+ err_list = []
+ if input_files:
+ for file in input_files:
+ basename = os.path.splitext(os.path.basename(file))[0]
+ if os.path.exists(file):
+ if self.audio.check(file):
+ info = self.audio.get_info(file)
+ hz = info[0].get("sample_rate")
+ check_ensemble_files_status += f"\n{basename} - {hz} hz"
+ hz_.append(hz)
+ else:
+ check_ensemble_files_status += f"\n{basename} - Нет аудио"
+ err_list.append(file)
+ else:
+ check_ensemble_files_status += f"\n{basename} - Файл не найден"
+ err_list.append(file)
+
+ check_ensemble_files_result = f"Действительных файлов: {len(hz_)}"
+
+ all_same = True
+
+ common_rate = None
+
+ for hz_hz in hz_:
+ if common_rate is None:
+ common_rate = hz_hz
+ elif common_rate != hz_hz:
+ all_same = False
+
+ if hz_ and len(hz_) > 1:
+ check_ensemble_files_result += "\nВсе действительные файлы имеют одинаковую частоту дискретизации" if all_same else "\nОшибка! Все действительные файлы имеют РАЗНУЮ частоту дискретизации"
+ else:
+ check_ensemble_files_result += "\nДля создания ансамбля нужно загрузить, как минимум - 2 файла, содержащие аудио"
+
+
+ check_ensemble_files_status += f"\n \n{check_ensemble_files_result}"
+
+ gr.Textbox(container=False, lines=len(check_ensemble_files_status.split("\n")), interactive=False, value=check_ensemble_files_status)
+
+ weights = gr.Textbox(label="Веса", value="1.0,1.0")
+
+ method = gr.Dropdown(label="Алгоритм склеивания", choices=["min_fft", "max_fft", "avg_fft", "median_fft"], value="avg_fft", filterable=False)
+
+ output_format = gr.Dropdown(label="Формат выходного файла", interactive=True,
+ choices=self.audio.output_formats,
+ value="mp3", filterable=False)
+
+ output_manual_ensemble_filename = gr.Textbox(label="Имя выходного файла", value="ensemble", interactive=True)
+
+ make_manual_ensemble_btn = gr.Button(value="Создать ансамбль", variant="primary")
+
+ manual_ensemble_output_audio = gr.Audio(label="Результат", type="filepath", interactive=False, show_download_button=True)
+
+ @make_manual_ensemble_btn.click(
+ inputs=[input_ensemble_files, method, output_format, output_base_dir, output_temp_dir_check, output_manual_ensemble_filename, weights], outputs=manual_ensemble_output_audio
+ )
+ def make_manual_ensemble_fn(input_files_list, method, out_format, o_dir, temp_save, o_filename, weights: str):
+ if o_dir.strip() != "" and not temp_save:
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ o = os.path.join(o_dir, f"ensembless_outputs_{timestamp}")
+ os.makedirs(o, exist_ok=True)
+ else:
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ o = tempfile.mkdtemp(prefix=f"ensembless_outputs_{timestamp}_")
+ os.makedirs(o, exist_ok=True)
+
+ o_filename = self.namer.sanitize(o_filename)
+ o_filename = self.namer.short(o_filename)
+
+ output_file = ensemble_audio_files(files=input_files_list, output=os.path.join(o, o_filename), weights=[float(x) for x in weights.split(",")], ensemble_type=method, out_format=out_format)
+ return output_file
+
+ add_ensemble_path_btn.click(
+ lambda: (gr.update(visible=True), gr.update(visible=False)),
+ outputs=[add_ensemble_inputs, add_ensemble_path_btn])
+
+ @add_ensemble_inputs_btn.click(
+ inputs=[input_ensemble_path, input_ensemble_files],
+ outputs=[add_ensemble_inputs, input_ensemble_files, add_ensemble_path_btn])
+ def add_ensemble_inputs_fn(input_p, input_a):
+ if input_p and os.path.exists(input_p):
+ if input_a is None:
+ input_a = []
+ if isinstance(input_a, str):
+ input_a = [input_a]
+ if self.audio.check(input_p):
+ input_a.append(input_p)
+ return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+ return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+
+class Inverter_UI(MVSEPLESS):
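+    """Subtraction tab: removes a stem from the original mix using waveform or spectrogram inversion."""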
+ def UI(self):
+ with gr.Group():
+ with gr.Row():
+ original_audio = gr.File(label="Оригинал", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.audio.input_formats])
+                stem_audio = gr.File(label="Стем, который будет вычтен из оригинала", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.audio.input_formats])
+ with gr.Group():
+ output_format = gr.Dropdown(label="Формат выходного файла", interactive=True,
+ choices=self.audio.output_formats,
+ value="mp3", filterable=False)
+ method = gr.Radio(
+ choices=["waveform", "spectrogram"],
+ label="Метод вычитания",
+ value="waveform",
+ )
+ btn = gr.Button("Вычесть")
+ output_audio = gr.Audio(label="Инверсия", type="filepath", interactive=False, show_download_button=True)
+ @btn.click(inputs=[original_audio, stem_audio, method, output_format], outputs=[output_audio])
+ def invert_result_ensemble(input_file, output_file, method, out_format):
+ if input_file and output_file:
+ o_dir = tempfile.mkdtemp(suffix="_inverter")
+ basename = os.path.splitext(os.path.basename(input_file))[0]
+ output_path = os.path.join(o_dir, f"inverter_{self.namer.short(basename, length=50)}_{method}.{out_format}")
+ inverted = self.inverter.process_audio(audio1_path=input_file, audio2_path=output_file, out_format=out_format, method=method, output_path=output_path)
+ return inverted
+ else:
+ return None
+
+class Vbach(MVSEPLESS):
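+    """Voice conversion tab (inference plus a voice-model manager) built on vbach_inference."""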
+ pitch_methods = ("rmvpe+", "fcpe", "mangio-crepe")
+ hop_length_values = (8, 512)
+ index_rates_values = (0, 1)
+ filter_radius_values = (0, 7)
+ protect_values = (0, 0.5)
+ rms_values = (0, 1)
+ f0_min_values = (50, 3000)
+ f0_max_values = (300, 6000)
+
+ def UI(self):
+ with gr.Tab("Инференс"):
+ with gr.Row():
+ with gr.Column():
+ with gr.Group():
+ input_audio = gr.File(label="Входные аудио", interactive=True, type="filepath", file_count="multiple", file_types=[f".{of}" for of in self.audio.input_formats])
+                        converted_state = gr.Textbox(label="Состояние преобразования", interactive=False, value="", visible=False)
+ status = gr.Textbox(container=False, lines=3, interactive=False, max_lines=3)
+ input_preview_check = gr.Checkbox(label="Показать плееры для входных аудио", interactive=True, value=False)
+ @gr.render(inputs=[input_preview_check, input_audio])
+ def show_input_players(preview, audios):
+ if preview:
+ if audios:
+ with gr.Group():
+ for file in audios:
+ gr.Audio(label=os.path.splitext(os.path.basename(file))[0], value=file, interactive=False, show_download_button=False, type="filepath")
+
+ with gr.Column():
+ with gr.Group():
+ model_name = gr.Dropdown(label="Имя модели", interactive=True)
+ model_list_refresh_btn = gr.Button("Обновить", variant="secondary", interactive=True)
+ @model_list_refresh_btn.click(
+ outputs=[model_name]
+ )
+ def refresh_list_voice_models():
+ models = []
+ models = self.vbach_model_manager.parse_voice_models()
+ first_model = None
+ if len(models) > 0:
+ first_model = models[0]
+ return gr.update(choices=models, value=first_model)
+ with gr.Group():
+ pitch_method = gr.Radio(label="Метод извлечения высоты тона", choices=self.pitch_methods, value=self.pitch_methods[0], interactive=True)
+ pitch = gr.Slider(label="Высота тона", minimum=-48, maximum=48, step=0.5, value=0, interactive=True)
+ hop_length = gr.Slider(label="Длина шага", info="Длина шага влияет на точность передачи высоты тона\nЧем меньше длина шага - тем точнее будет передана высота тона", minimum=self.hop_length_values[0], maximum=self.hop_length_values[1], step=8, value=128, interactive=True, visible=False)
+ @pitch_method.change(
+ inputs=[pitch_method],
+ outputs=[hop_length]
+ )
+ def show_mangio_crepe_hop_length(pitch_method):
+ return gr.update(visible=True if pitch_method in ["mangio-crepe"] else False)
+ stereo_mode = gr.Radio(
+ choices=["mono", "left/right", "sim/dif"],
+ label="Стерео режим",
+ info="mono - монофоническая обработка аудио, \nleft/right - обработка левого и правого каналов отдельно, \nsim/dif - обработка фантомного центра и стерео-базы, разделенную на левый и правый каналы",
+ value="mono",
+ interactive=True
+ )
+ with gr.Accordion(label="Дополнительные настройки",open=False):
+ with gr.Group():
+ with gr.Row():
+ index_rate = gr.Slider(label="Влияние индекса", info="Чем ниже значение, тем больше голос похож на исходный; чем выше, тем ближе к модели", minimum=self.index_rates_values[0], maximum=self.index_rates_values[1], step=0.05, value=0, interactive=True)
+ filter_radius = gr.Slider(label="Радиус фильтра", info="Сглаживает результаты извлечения тона\nМожет снизить дыхание и шумы на выходе", minimum=self.filter_radius_values[0], maximum=self.filter_radius_values[1], step=1, value=3, interactive=True)
+ with gr.Row():
+ rms = gr.Slider(label="Соотношение огибающих громкости", info="Значение 0 - огибающая громкости как у входного аудио, 1 - как у выходного сигнала", minimum=self.rms_values[0], maximum=self.rms_values[1], step=0.05, value=0.25, interactive=True)
+                                protect = gr.Slider(label="Защита согласных", info="Предотвращает роботизацию дыхания и согласных (Может влиять на четкость речи)\nЗначение 0.5 - выключает защиту, 0 - максимальная защита", minimum=self.protect_values[0], maximum=self.protect_values[1], step=0.05, value=0.35, interactive=True)
+ with gr.Group():
+ with gr.Row():
+ f0_min = gr.Slider(label="Нижний предел диапазона определения высоты тона", minimum=self.f0_min_values[0], maximum=self.f0_min_values[1], step=10, value=50, interactive=True)
+ f0_max = gr.Slider(label="Верхний предел диапазона определения высоты тона", minimum=self.f0_max_values[0], maximum=self.f0_max_values[1], step=10, value=1100, interactive=True)
+
+ with gr.Group():
+ output_name = gr.Textbox(label="Имя выходного файла", interactive=True, value="NAME - MODEL - F0METHOD - PITCH")
+                            format_output_name_check = gr.Checkbox(label="Форматировать имя", info="Используйте ключи: \nNAME - имя входного файла без расширения, \nPITCH - высота тона, \nF0METHOD - метод извлечения высоты тона, \nMODEL - имя голосовой модели", value=True, interactive=True)
+ output_format = gr.Dropdown(label="Формат выходного файла", interactive=True, choices=self.audio.output_formats, value=self.audio.output_formats[0], filterable=False)
+ convert_btn = gr.Button("Преобразовать", variant="primary", interactive=True)
+
+
+ @convert_btn.click(
+ inputs=[
+ input_audio,
+ model_name,
+ pitch_method,
+ pitch,
+ hop_length,
+ index_rate,
+ filter_radius,
+ rms,
+ protect,
+ f0_min,
+ f0_max,
+ output_name,
+ format_output_name_check,
+ output_format,
+ stereo_mode
+ ], outputs=[converted_state, status]
+ )
+ def vbach_convert_batch(ifl, mn, pm, p, hl, ir, fr, rms, pr, f0min, f0max, on, fn, of, sm):
+ output_converted_files = []
+ progress = gr.Progress()
+ if ifl:
+ for i, file in enumerate(ifl, start=1):
+ try:
+ print(f"Файл {i} из {len(ifl)}: {file}")
+ progress(progress=(i / len(ifl)), desc=f"Файл {i} из {len(ifl)}")
+ out_conv = vbach_inference(input_file=file, model_name=mn, output_dir=tempfile.mkdtemp(), output_name=on, format_name=True if len(ifl) > 1 else fn, output_format=of, pitch=p, method_pitch=pm, output_bitrate=320, add_params={ "index_rate": ir,"filter_radius": fr,"protect": pr,"rms": rms,"mangio_crepe_hop_length": hl,"f0_min": f0min,"f0_max": f0max,"stereo_mode": sm })
+ output_converted_files.append(out_conv)
+ except Exception as e:
+ print(e)
+ return str(output_converted_files), None
+
+ @gr.render(inputs=[converted_state])
+ def show_players_converted(state):
+ if state != "":
+ output_converted_files = ast.literal_eval(state)
+ if output_converted_files:
+ with gr.Group():
+ for conv_file in output_converted_files:
+ basename = os.path.splitext(os.path.basename(conv_file))[0]
+ gr.Audio(
+ label=basename,
+ value=conv_file,
+ type="filepath",
+ interactive=False,
+ show_download_button=True,
+ )
+
+ with gr.TabItem("Менеджер"):
+ with gr.TabItem("Загрузить по ссылке"):
+ with gr.TabItem("Через zip файл"):
+ with gr.Row():
+ with gr.Column(variant="panel"):
+ url_zip = gr.Text(label="Ссылка на zip файл")
+ with gr.Group():
+ url_zip_model_name = gr.Text(
+ label="Имя модели",
+ )
+ url_zip_download_btn = gr.Button("Загрузить", variant="primary")
+
+ url_zip_output = gr.Text(label="Статус", interactive=False, lines=5)
+ url_zip_download_btn.click(
+ (lambda x, y: self.vbach_model_manager.install_model_zip(x, self.namer.short(self.namer.sanitize(y), length=40), "url")),
+ inputs=[url_zip, url_zip_model_name],
+ outputs=url_zip_output,
+ )
+
+ with gr.TabItem("Через отдельные файлы"):
+ with gr.Row():
+ with gr.Column(variant="panel"):
+ url_pth = gr.Text(label="Ссылка на *.pth файл")
+ url_index = gr.Text(label="Ссылка на *.index файл (необязательно)")
+ with gr.Group():
+ url_file_model_name = gr.Text(
+ label="Имя модели",
+ )
+ url_file_download_btn = gr.Button("Загрузить", variant="primary")
+
+ url_file_output = gr.Text(label="Статус", interactive=False, lines=5)
+ url_file_download_btn.click(
+ (lambda x, y, z: self.vbach_model_manager.install_model_files(x, y, self.namer.short(self.namer.sanitize(z), length=40), "url")),
+ inputs=[url_index, url_pth, url_file_model_name],
+ outputs=url_file_output,
+ )
+
+ with gr.Tab("Загрузить с устройства"):
+ with gr.Tab("Через zip файл"):
+ with gr.Row():
+ with gr.Column():
+ local_zip = gr.File(
+ label="zip файл", file_types=[".zip"], file_count="single"
+ )
+ with gr.Column(variant="panel"):
+ with gr.Group():
+ local_zip_model_name = gr.Text(
+ label="Имя модели",
+ )
+ local_zip_upload_btn = gr.Button("Загрузить", variant="primary")
+
+ local_zip_output = gr.Text(label="Статус", interactive=False, lines=5)
+ local_zip_upload_btn.click(
+ (lambda x, y: self.vbach_model_manager.install_model_zip(x, self.namer.short(self.namer.sanitize(y), length=40), "local")),
+ inputs=[local_zip, local_zip_model_name],
+ outputs=local_zip_output,
+ )
+
+ with gr.TabItem("Через отдельные файлы"):
+ with gr.Group():
+ with gr.Row():
+ local_pth = gr.File(
+ label="*.pth файл", file_types=[".pth"], file_count="single"
+ )
+ local_index = gr.File(
+ label="*.index файл (необязательно)", file_types=[".index"], file_count="single"
+ )
+ with gr.Column(variant="panel"):
+ with gr.Group():
+ local_file_model_name = gr.Text(
+ label="Имя модели",
+ )
+ local_file_upload_btn = gr.Button("Загрузить", variant="primary")
+
+ local_file_output = gr.Text(
+ label="Статус", interactive=False
+ )
+ local_file_upload_btn.click(
+ (lambda x, y, z: self.vbach_model_manager.install_model_files(x, y, self.namer.short(self.namer.sanitize(z), length=40), "local")),
+ inputs=[local_index, local_pth, local_file_model_name],
+ outputs=local_file_output,
+ )
+
+ with gr.TabItem("Удалить модель"):
+ with gr.Column(variant="panel"):
+ with gr.Group():
+ delete_model_name = gr.Dropdown(
+ label="Имя модели",
+ choices=self.vbach_model_manager.parse_voice_models(),
+ interactive=True,
+ filterable=False
+ )
+ delete_refresh_btn = gr.Button("Обновить")
+ @delete_refresh_btn.click(inputs=None, outputs=delete_model_name)
+ def refresh_list_voice_models():
+ models = []
+ models = self.vbach_model_manager.parse_voice_models()
+ first_model = None
+ if len(models) > 0:
+ first_model = models[0]
+ return gr.update(choices=models, value=first_model)
+
+ delete_output = gr.Text(
+ label="Статус", interactive=False, lines=5
+ )
+ delete_btn = gr.Button("Удалить")
+ delete_btn.click(
+ fn=self.vbach_model_manager.del_voice_model,
+ inputs=delete_model_name,
+ outputs=delete_output
+ )
+
+ @gr.on(fn="decorator", inputs=None, outputs=[delete_model_name, model_name])
+ def refresh_list_voice_models():
+ models = []
+ models = self.vbach_model_manager.parse_voice_models()
+ first_model = None
+ if len(models) > 0:
+ first_model = models[0]
+ return gr.update(choices=models, value=first_model), gr.update(choices=models, value=first_model)
+
+
+class PluginManager(Separator):
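+    """Discovers and mounts Gradio plugins from the plugins/ directory.
+
+    Each plugin is a single .py file exposing a ``Plugin`` class with a ``name`` attribute
+    and a ``UI()`` method that builds its tab contents. Minimal illustrative sketch
+    (an assumed user-written plugin, not a file shipped with the package):
+
+        import gradio as gr
+
+        class Plugin:
+            name = "My plugin"
+            def UI(self):
+                gr.Markdown("Hello from a plugin")
+    """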
+ plugins_dir = os.path.join(script_dir, "plugins")
+ os.makedirs(plugins_dir, exist_ok=True)
+
+ def restart_after_install_plugin(self):
+        subprocess.Popen([sys.executable] + sys.argv)
+ os._exit(0)
+
+ def parse_plugins(self):
+ for plugin_file in os.listdir(self.plugins_dir):
+            # Skip non-Python files and __init__.py
+ if not plugin_file.endswith('.py') or plugin_file == '__init__.py':
+ continue
+
+            # Module name without the extension
+ plugin_module_name = os.path.splitext(plugin_file)[0]
+
+ try:
+                # Choose the import strategy depending on how the app is being run
+ if __package__:
+                    # Running as a package: import the plugin relative to it
+ plugin_module = importlib.import_module(f".plugins.{plugin_module_name}", package=__package__)
+ else:
+                    # Not running as a package: try several fallbacks
+ try:
+                        # Try an absolute import first
+ plugin_module = importlib.import_module(f"plugins.{plugin_module_name}")
+ except ImportError:
+                        # Otherwise load the module straight from its file path
+ plugin_path = os.path.join(self.plugins_dir, plugin_file)
+ spec = importlib.util.spec_from_file_location(plugin_module_name, plugin_path)
+ plugin_module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(plugin_module)
+
+                # Grab the Plugin class from the module
+ plugin_class = getattr(plugin_module, 'Plugin')
+
+                # Instantiate the plugin
+ plugin_instance = plugin_class()
+
+                # Build the plugin UI inside its own tab
+ with gr.Tab(plugin_instance.name):
+ plugin_instance.UI()
+
+ except Exception as e:
+ print(f"Ошибка загрузки плагина {plugin_module_name}: {e}")
+ continue
+
+ def UI(self):
+ with gr.Tab("Установка"):
+ upload_plugins_files = gr.File(label="Загрузить плагины", file_types=[".py"], file_count="multiple", interactive=True)
+ install_plugins_btn = gr.Button("Установить", interactive=True)
+
+ @install_plugins_btn.click(
+ inputs=[upload_plugins_files]
+ )
+ def upload_plugin_list(files):
+ if not files:
+ return
+ for file in files:
+ try:
+                    # Copy only .py files into the plugins directory
+ if file.name.endswith('.py'):
+ shutil.copy(
+ file, os.path.join(self.plugins_dir, os.path.basename(file).replace(" ", "_"))
+ )
+ except Exception as e:
+ print(f"Ошибка копирования файла {file}: {e}")
+ time.sleep(2)
+ self.restart_after_install_plugin()
+
+ self.parse_plugins()
+
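+# A minimal plugin sketch (assumption, based on parse_plugins() above: a plugin is a
+# .py file in plugins/ that exposes a `Plugin` class with a `name` attribute and a
+# `UI()` method that builds Gradio components; the names below are illustrative):
+#
+#     import gradio as gr
+#
+#     class Plugin:
+#         name = "Пример плагина"
+#
+#         def UI(self):
+#             text_in = gr.Textbox(label="Вход")
+#             text_out = gr.Textbox(label="Выход")
+#             run_btn = gr.Button("Запустить")
+#             run_btn.click(lambda s: s.upper(), inputs=text_in, outputs=text_out)
+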
+def mvsepless_app(theme):
+ css = None
+ with gr.Blocks(theme=theme, css=css, title="Разделение музыки и вокала") as app:
+
+ output_base_dir_state = gr.State(value=os.path.join(os.getcwd(), "outputs"))
+ output_temp_dir_state = gr.State(value=False)
+
+ with gr.Tab("Инференс"):
+ Separator().UI(output_base_dir_state, output_temp_dir_state)
+
+ with gr.Tab("Ансамбль"):
+ with gr.Tab("Авто-ансамбль"):
+ AutoEnsembless().UI(output_base_dir_state, output_temp_dir_state)
+
+ with gr.Tab("Ручной ансамбль"):
+ ManualEnsembless().UI(output_base_dir_state, output_temp_dir_state)
+
+ with gr.Tab("Вычитание"):
+ Inverter_UI().UI()
+
+ with gr.Tab("Преобразование"):
+ Vbach().UI()
+
+ with gr.Tab("Плагины"):
+ PluginManager().UI()
+
+ with gr.Tab("Настройки"):
+ with gr.Column():
+ output_base_dir_ui = gr.Textbox(
+ label="Базовый каталог для выходных файлов",
+ interactive=True,
+ value=os.path.join(os.getcwd(), "outputs"),
+ lines=5
+ )
+ output_temp_dir_check_ui = gr.Checkbox(
+ label="Использовать временный каталог для выходных файлов",
+ interactive=True,
+ value=False
+ )
+
+ # Связываем UI элементы с состоянием
+ output_base_dir_ui.change(
+ lambda x: x,
+ inputs=[output_base_dir_ui],
+ outputs=[output_base_dir_state]
+ )
+ output_temp_dir_check_ui.change(
+ lambda x: x,
+ inputs=[output_temp_dir_check_ui],
+ outputs=[output_temp_dir_state]
+ )
+
+ return app
+
diff --git a/mvsepless/__main__.py b/mvsepless/__main__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0009e708a942278fd688793386f12963522a822
--- /dev/null
+++ b/mvsepless/__main__.py
@@ -0,0 +1,67 @@
+import argparse
+import gradio as gr
+import os
+
+if not __package__:
+ from __init__ import mvsepless_app, Separator
+else:
+ from .__init__ import mvsepless_app, Separator
+
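+# Illustrative invocations of the two subcommands defined below
+# (port, file paths and output format are placeholders):
+#   python -m mvsepless app --port 7860 --share
+#   python -m mvsepless cli --input song.mp3 --output_dir ./outs --ext_inst --output_format flac
+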
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="MVSepless")
+ subparsers = parser.add_subparsers(dest="command")
+ app_parser = subparsers.add_parser("app", help="Приложение MVSepless")
+ app_parser.add_argument("--port", type=int, default=None, help="Порт для запуска сервера Gradio.")
+ app_parser.add_argument("--share", action="store_true", help="Создать публичную ссылку для приложения Gradio.")
+ cli_parser = subparsers.add_parser("cli", help="CLI MVSepless Lite")
+ cli_parser.add_argument("--input", type=str, required=True, help="Входной аудиофайл или каталог.")
+ cli_parser.add_argument("--output_dir", type=str, default=None, help="Каталог для выходных файлов.")
+ cli_parser.add_argument("--model_type", type=str, default="mel_band_roformer", help="Тип модели разделения.")
+ cli_parser.add_argument("--model_name", type=str, default="Mel-Band-Roformer_Vocals_kimberley_jensen", help="Имя модели разделения.")
+ cli_parser.add_argument("--ext_inst", action="store_true", help="Извлечь инструментал.")
+ cli_parser.add_argument("--output_format", type=str, default="mp3", choices=Separator.audio.output_formats, help="Формат выходного файла.")
+ cli_parser.add_argument("--output_bitrate", type=str, default="320k", help="Битрейт выходного файла.")
+ cli_parser.add_argument("--template", type=str, default="NAME (STEM) MODEL", help="Шаблон именования выходных файлов.")
+ cli_parser.add_argument("--selected_stems", type=str, nargs='*', default=None, help="Выбранные стемы для разделения.")
+ args = parser.parse_args()
+
+ if args.command == "app":
+ theme = gr.themes.Citrus(
+ primary_hue="teal",
+ secondary_hue="blue",
+ neutral_hue="blue",
+ spacing_size="sm",
+ font=[
+ gr.themes.GoogleFont("Montserrat"),
+ "ui-sans-serif",
+ "system-ui",
+ "sans-serif",
+ ],
+ )
+ mvsepless_lite_app = mvsepless_app(theme=theme)
+ mvsepless_lite_app.launch(server_name="0.0.0.0", server_port=args.port, share=args.share, allowed_paths=["/"], debug=True)
+ elif args.command == "cli":
+ input_data = args.input
+ if os.path.isdir(input_data):
+ list_valid_files = []
+ for file in os.listdir(args.input):
+ if os.path.isfile(os.path.join(args.input, file)):
+ if Separator.audio.check(os.path.join(args.input, file)):
+ list_valid_files.append(os.path.join(args.input, file))
+
+ input_files = list_valid_files
+ else:
+ input_files = input_data
+
+ results = Separator().separate(
+ input=input_files,
+ output_dir=args.output_dir,
+ model_type=args.model_type,
+ model_name=args.model_name,
+ ext_inst=args.ext_inst,
+ output_format=args.output_format,
+ output_bitrate=args.output_bitrate,
+ template=args.template,
+ selected_stems=args.selected_stems
+ )
+ print("Разделение завершено.")
\ No newline at end of file
diff --git a/mvsepless/audio.py b/mvsepless/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..15a6a874a99bf87b48c67fa130280fd65401232c
--- /dev/null
+++ b/mvsepless/audio.py
@@ -0,0 +1,781 @@
+import os
+from pathlib import Path
+import sys
+import json
+import subprocess
+import numpy as np
+from typing import Literal
+from collections.abc import Callable
+from numpy.typing import DTypeLike
+import tempfile
+import librosa
+if not __package__:
+ from namer import Namer
+else:
+ from .namer import Namer
+class NotInputFileSpecified(Exception): pass
+class NotOutputFileSpecified(Exception): pass
+class NotSupportedDataType(Exception): pass
+class ErrorDecode(Exception): pass
+class ErrorEncode(Exception): pass
+class NotSupportedFormat(Exception): pass
+class SampleRateError(Exception): pass
+class FileIsNotAudio(Exception): pass
+
+class Audio(Namer):
+ def __init__(self):
+ """
+Чтение и запись аудио файла через ffmpeg
+
+Поддерживаемые типы данных: - int16, int32, float32, float64
+ """
+ super().__init__()
+ self.ffmpeg_path = os.environ.get("MVSEPLESS_FFMPEG", "ffmpeg")
+ self.ffprobe_path = os.environ.get("MVSEPLESS_FFPROBE", "ffprobe")
+ self.output_formats = ("mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff")
+ self.input_formats = ("mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff", "mp4", "mkv", "webm", "avi", "mov", "ts")
+ self.supported_dtypes = ("int16", "int32", "float32", "float64")
+ self.dtypes_dict = {
+ "int16": "s16le",
+ "int32": "s32le",
+ "float32": "f32le",
+ "float64": "f64le",
+ np.int16: "s16le",
+ np.int32: "s32le",
+ np.float32: "f32le",
+ np.float64: "f64le",
+ }
+ self.bitrate_limit = {
+ "mp3": {"min": 8, "max": 320},
+ "aac": {"min": 8, "max": 512},
+ "m4a": {"min": 8, "max": 512},
+ "ac3": {"min": 32, "max": 640},
+ "ogg": {"min": 64, "max": 500},
+ "opus": {"min": 6, "max": 512},
+ }
+ self.sample_rates = {
+ "mp3": {
+ "supported": (44100, 48000, 32000, 22050, 24000, 16000, 11025, 12000, 8000)
+ },
+ "opus": {"supported": (48000, 24000, 16000, 12000, 8000)},
+ "m4a": {
+ "supported": (
+ 96000,
+ 88200,
+ 64000,
+ 48000,
+ 44100,
+ 32000,
+ 24000,
+ 22050,
+ 16000,
+ 12000,
+ 11025,
+ 8000,
+ 7350,
+ )
+ },
+ "aac": {
+ "supported": (
+ 96000,
+ 88200,
+ 64000,
+ 48000,
+ 44100,
+ 32000,
+ 24000,
+ 22050,
+ 16000,
+ 12000,
+ 11025,
+ 8000,
+ 7350,
+ )
+ },
+ "ac3": {
+ "supported": (
+ 48000,
+ 44100,
+ 32000,
+ )
+ },
+ "ogg": {"min": 6, "max": 192000},
+ "wav": {"min": 0, "max": float("inf")},
+ "aiff": {"min": 0, "max": float("inf")},
+ "flac": {"min": 0, "max": 192000},
+ }
+ self.check_ffmpeg()
+ self.check_ffprobe()
+
+ def check_ffmpeg(self):
+ """
+        Проверяет, установлен ли ffmpeg.
+ """
+ try:
+ ffmpeg_version_output = subprocess.check_output(
+ [self.ffmpeg_path, "-version"], text=True
+ )
+ except FileNotFoundError:
+ if "PYTEST_CURRENT_TEST" not in os.environ:
+ raise FileNotFoundError("FFMPEG не установлен. Укажите путь к установленному FFMPEG через переменную окружения MVSEPLESS_FFMPEG")
+
+ def check_ffprobe(self):
+ """
+        Проверяет, установлен ли ffprobe.
+ """
+ try:
+ ffmpeg_version_output = subprocess.check_output(
+ [self.ffprobe_path, "-version"], text=True
+ )
+ except FileNotFoundError:
+ if "PYTEST_CURRENT_TEST" not in os.environ:
+ raise FileNotFoundError("FFPROBE не установлен. Укажите путь к установленному FFPROBE через переменную окружения MVSEPLESS_FFPROBE")
+
+
+ def fit_sr(
+ self,
+ f: str | Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] = "mp3",
+ sr: int = 44100
+ ) -> int:
+ """
+        Исправляет значение частоты дискретизации выходного файла
+
+ Параметры:
+ f: Формат вывода
+ sr: Частота дискретизации (целое число)
+ Возвращает:
+ sr: Исправленная частота дискретизации
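+        Пример (для иллюстрации): fit_sr(f="opus", sr=44100) вернёт 48000,
+        ближайшую из поддерживаемых для opus частот.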
+ """
+ format_info = self.sample_rates.get(f.lower())
+
+ if not format_info:
+ return None # Формат не найден
+
+ if "supported" in format_info:
+ # Для форматов с конкретным списком
+ supported_rates = format_info["supported"]
+ if sr in supported_rates:
+ return sr
+
+ # Находим ближайшую поддерживаемую частоту
+ return min(supported_rates, key=lambda x: abs(x - sr))
+
+ elif "min" in format_info and "max" in format_info:
+ # Для форматов с диапазоном - обрезаем до границ
+ min_rate = format_info["min"]
+ max_rate = format_info["max"]
+
+ if sr < min_rate:
+ return min_rate
+ elif sr > max_rate:
+ return max_rate
+ else:
+ return sr
+
+ return None
+
+ def fit_br(
+ self,
+ f: str | Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] = "mp3",
+ br: int = 320
+ ) -> int:
+ """
+ Исправляет значение битрейта выходного файла
+
+ Параметры:
+ f: Формат вывода
+ br: Битрейт (целое число)
+ Возвращает:
+ br: Исправленный битрейт
+ """
+ if f not in self.bitrate_limit:
+ raise NotSupportedFormat(f"Формат {f} не поддерживается")
+
+ limits = self.bitrate_limit[f]
+
+ if br < limits["min"]:
+ return limits["min"]
+ elif br > limits["max"]:
+ return limits["max"]
+ else:
+ return br
+
+ def get_info(
+ self,
+ i: str | os.PathLike | Callable | None = None,
+ ) -> dict[int, dict[int, float]]:
+ """
+ Получает информацию о аудио потоках из файла напрямую через FFMPEG
+
+ Параметры:
+ i: Путь к выходному файлу
+ Возвращает:
+ audio_info: Словарь с информацией о аудиопотоках вида:
+
+ {Номер потока:
+ {
+ "sample_rate": Частота дисректизации (является целым числом),
+ "duration": Длительность аудиопотока (является числом с плавающей точкой)
+ }
+ }
+ """
+ audio_info = {}
+ if i:
+ if isinstance(i, Path):
+ i = str(i)
+ if os.path.exists(i):
+ cmd = [self.ffprobe_path, "-i", i, "-v", "quiet", "-hide_banner",
+ "-show_entries", "stream=index,sample_rate,duration", "-select_streams", "a", "-of", "json"]
+
+ process = subprocess.Popen(
+ cmd,
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+
+ stdout, stderr = process.communicate()
+
+                if process.returncode != 0:
+                    print(f"STDERR: {stderr.decode('utf-8')}")
+                    print(f"STDOUT: {stdout.decode('utf-8')}")
+                    return audio_info
+
+ json_output = json.loads(stdout)
+ streams = json_output["streams"]
+ if not streams:
+ pass
+
+ else:
+ for a, stream in enumerate(streams):
+ audio_info[a] = {
+ "sample_rate": int(stream.get("sample_rate", 0)),
+ "duration": float(stream.get("duration", 0))
+ }
+
+ return audio_info
+
+ else:
+                raise FileNotFoundError("Указанного файла не существует")
+
+ else:
+ raise NotInputFileSpecified("Не указан путь к файлу")
+
+ def check(
+ self,
+ i: str | os.PathLike | Callable | None = None
+ ) -> bool:
+ """
+ Проверяет, является ли файл аудио или видео файлом, поддерживаемым ffmpeg
+
+ Параметры:
+ i: Путь к выходному файлу
+ Возвращает:
+ is_audio_video: Булево значение, является ли файл аудио или видео файлом
+ """
+ if i:
+ if isinstance(i, Path):
+ i = str(i)
+ if os.path.exists(i):
+ info = self.get_info(i=i)
+ if info:
+ list_streams = list(info.keys())
+ if len(list_streams) > 0:
+ if info[0].get("sample_rate") > 0:
+ return True
+ else:
+ return False
+ else:
+ return False
+ else:
+ return False
+ else:
+                raise FileNotFoundError("Указанного файла не существует")
+ else:
+ raise NotInputFileSpecified("Не указан путь к файлу")
+
+ def read(
+ self,
+ i: str | os.PathLike | Callable | None = None,
+ sr: int | None = None,
+ mono: bool = False,
+ dtype: DTypeLike = np.float32,
+ s: int = 0
+ ) -> tuple[np.ndarray, int, float]:
+ """
+ Читает аудио-файл, преобразовывая его в массив с аудио данными напрямую через FFMPEG
+ Является заменой soundfile.read() и librosa.load()
+
+ Параметры:
+ i: Путь к выходному файлу
+ sr: Целевая частота дискретизации (Если не указана, то используется частота дискретизации входного файла)
+ mono: Конвертация в моно (по умолчанию отключена)
+ dtype: Тип данных (поддерживаются типы: int16, int32, float32, float64; по умолчанию - float32)
+ s: Номер аудиопотока (по умолчанию 0)
+ Возвращает:
+ audio_array: Массив с аудио данными
+ sr: Частота дискретизации массива
+ duration: Длительность аудио (количество сэмплов / частота дискретизации)
+ """
+ output_format = self.dtypes_dict.get(dtype, None)
+ if not output_format:
+ raise NotSupportedDataType(f"Этот тип данных не поддерживается {dtype}")
+ if i:
+ if isinstance(i, Path):
+ i = str(i)
+ if os.path.exists(i):
+ audio_info = self.get_info(i=i)
+ list_streams = list(audio_info.keys())
+ if audio_info.get(s, False):
+ stream = s
+ else:
+ if len(list_streams) > 0:
+ stream = 0
+ else:
+                        raise FileIsNotAudio("Во входном файле нет аудио потоков")
+
+ sample_rate_input = audio_info[stream]["sample_rate"]
+ if sample_rate_input == 0:
+                    raise FileIsNotAudio("Во входном файле нет аудио потоков")
+
+ cmd = [
+ self.ffmpeg_path,
+ "-i", i,
+ "-map", f"0:a:{stream}", "-vn",
+ "-f", output_format,
+ "-ac", "1" if mono else "2",
+ ]
+
+ if sr:
+ cmd.extend(["-ar", str(sr)])
+ else:
+ sr = sample_rate_input
+
+ cmd.append("pipe:1")
+
+ process = subprocess.Popen(
+ cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ bufsize=10**8
+ )
+
+ try:
+
+ raw_audio, stderr = process.communicate(timeout=300)
+
+ if process.returncode != 0:
+ raise ErrorDecode(f"FFmpeg error: {stderr.decode()}")
+
+ except subprocess.TimeoutExpired:
+ process.kill()
+ raise ErrorDecode("FFmpeg timeout при чтении файла")
+
+ audio_array = np.frombuffer(raw_audio, dtype=dtype)
+
+ channels = 1 if mono else 2
+ audio_array = audio_array.reshape((-1, channels)).T
+ if audio_array.ndim > 1 and channels == 1:
+ audio_array = np.mean(audio_array, axis=tuple(range(audio_array.ndim - 1)))
+
+ len_samples = float(audio_array.shape[-1])
+
+ duration = len_samples / sr
+
+ print(f"Частота дискретизации: {sr}")
+
+ return audio_array, sr, duration
+ else:
+                raise FileNotFoundError("Указанного файла не существует")
+
+ else:
+ raise NotInputFileSpecified("Не указан путь к файлу")
+
+ def write(
+ self,
+ o: str | os.PathLike | Callable | None = None,
+ array: np.ndarray = np.array([], dtype=np.float32),
+ sr: int = 44100,
+ of: str | Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] | None = None,
+ br: str | int | None = None
+ ) -> str:
+ """
+ Записывает numpy-массив с аудио данными в файл напрямую через ffmpeg.
+ Является заменой soundfile.write()
+
+ Параметры:
+ o: Путь к выходному файлу
+ array: Массив с аудио данными (поддерживаются типы: int16, int32, float32, float64)
+ sr: Частота дискретизации массива
+ of: Формат вывода (по умолчанию mp3)
+ br: Битрейт для кодеков, сжимающих аудио с потерями
+ Возвращает:
+ o: Путь к выходному файлу
+ """
+ if isinstance(array, np.ndarray):
+
+ if len(array.shape) == 1:
+ array = array.reshape(-1, 1)
+ elif len(array.shape) == 2:
+ if array.shape[0] == 2:
+ array = array.T
+ else:
+ raise ValueError("numpy-массив должен быть либо одномерным, либо двухмерным")
+
+ if array.dtype == np.int16:
+ input_format = "s16le"
+ elif array.dtype == np.int32:
+ input_format = "s32le"
+ elif array.dtype == np.float32:
+ input_format = "f32le"
+ elif array.dtype == np.float64:
+ input_format = "f64le"
+ else:
+ raise NotSupportedDataType(f"Этот тип данных не поддерживается {array.dtype}")
+
+ if array.shape[1] == 1:
+ audio_bytes = array.tobytes()
+
+ channels = 1
+
+ elif array.shape[1] == 2:
+ audio_bytes = array.tobytes()
+
+ channels = 2
+ else:
+ raise ValueError("numpy-массив должен содержать 1 или 2 канала")
+
+ else:
+ raise ValueError("Вход должен быть numpy-массивом")
+
+ if o:
+ if isinstance(o, Path):
+ o = str(o)
+ output_dir = os.path.dirname(o)
+ output_base = os.path.basename(o)
+ output_name, output_ext = os.path.splitext(output_base)
+ if output_dir != "":
+ os.makedirs(output_dir, exist_ok=True)
+ if output_ext == "":
+ if of:
+ o += f".{of}"
+ else:
+ o += f".mp3"
+ elif output_ext == ".":
+ if of:
+ o += f"{of}"
+ else:
+ o += f"mp3"
+ else:
+ raise NotOutputFileSpecified("Не указан путь к выходному файлу")
+
+        if of:
+            if of in self.output_formats:
+                output_name, output_ext = os.path.splitext(o)
+                if output_ext != f".{of}":
+                    # output_name уже содержит каталог, поэтому output_dir не добавляется повторно
+                    o = f"{output_name}.{of}"
+ else:
+ raise NotSupportedFormat(f"Неподдерживаемый формат: {of}")
+ else:
+ of = os.path.splitext(o)[1].strip(".")
+ if of in self.output_formats:
+ pass
+ else:
+ raise NotSupportedFormat(f"Неподдерживаемый формат: {of}")
+
+ if sr:
+ if isinstance(sr, int):
+ sample_rate_fixed = self.fit_sr(f=of, sr=sr)
+ elif isinstance(sr, float):
+ sr = int(sr)
+ sample_rate_fixed = self.fit_sr(f=of, sr=sr)
+ else:
+ raise SampleRateError(f"Частота дискретизации должна быть числом\n\nЗначение: {sr}\nТип: {type(sr)}")
+ else:
+ raise SampleRateError("Не указана частота дискретизации")
+
+ bitrate_fixed = "320k"
+
+ if of not in ["wav", "flac", "aiff"]:
+ if br:
+ if isinstance(br, int):
+ bitrate_fixed = self.fit_br(f=of, br=br)
+ elif isinstance(br, float):
+ bitrate_fixed = self.fit_br(f=of, br=int(br))
+ elif isinstance(br, str):
+ bitrate_fixed = self.fit_br(f=of, br=int(br.strip("k").strip("K")))
+ else:
+ bitrate_fixed = self.fit_br(f=of, br=320)
+ else:
+ bitrate_fixed = self.fit_br(of, 320)
+
+ format_settings = {
+ "wav": [
+ "-c:a",
+ "pcm_f32le",
+ "-sample_fmt",
+ "flt",
+ ],
+ "aiff": [
+ "-c:a",
+ "pcm_f32le",
+ "-sample_fmt",
+ "flt",
+ ],
+ "flac": [
+ "-c:a",
+ "flac",
+ "-compression_level",
+ "12",
+ "-sample_fmt",
+ "s32",
+ ],
+ "mp3": [
+ "-c:a",
+ "libmp3lame",
+ "-b:a",
+ f"{bitrate_fixed}k",
+ ],
+ "ogg": [
+ "-c:a",
+ "libvorbis",
+ "-b:a",
+ f"{bitrate_fixed}k",
+ ],
+ "opus": [
+ "-c:a",
+ "libopus",
+ "-b:a",
+ f"{bitrate_fixed}k",
+ ],
+ "m4a": [
+ "-c:a",
+ "aac",
+ "-b:a",
+ f"{bitrate_fixed}k",
+ ],
+ "aac": [
+ "-c:a",
+ "aac",
+ "-b:a",
+ f"{bitrate_fixed}k",
+ ],
+ "ac3": [
+ "-c:a",
+ "ac3",
+ "-b:a",
+ f"{bitrate_fixed}k",
+ ],
+ }
+
+ cmd = [
+ self.ffmpeg_path,
+ "-y",
+ "-f",
+ input_format,
+ "-ar",
+ str(sr),
+ "-ac",
+ str(channels),
+ "-i",
+ "pipe:0",
+ "-ac",
+ str(channels),
+ ]
+
+ cmd.extend(["-ar", str(sample_rate_fixed)])
+ cmd.extend(format_settings[of])
+ o_dir, o_base = os.path.split(o)
+ o_base_n, o_base_ext = os.path.splitext(o_base)
+ o_base_n = self.sanitize(o_base_n)
+ o_base_n = self.short(o_base_n)
+ o = os.path.join(o_dir, f"{o_base_n}{o_base_ext}")
+ o = self.iter(o)
+ cmd.append(o)
+
+ process = subprocess.Popen(
+ cmd,
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+
+ try:
+ stdout, stderr = process.communicate(input=audio_bytes, timeout=300)
+ except subprocess.TimeoutExpired:
+ process.kill()
+ raise ErrorEncode("FFmpeg timeout: операция заняла слишком много времени")
+
+        if process.returncode != 0:
+            raise ErrorEncode(f"FFmpeg завершился с ошибкой (код: {process.returncode}): {stderr.decode('utf-8', errors='ignore')}")
+
+ return os.path.abspath(o)
+
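+# A minimal round-trip sketch for Audio.read()/Audio.write() (illustrative;
+# "input.mp3" is a placeholder and must be a file ffmpeg can decode):
+#
+#     audio = Audio()
+#     data, sr, duration = audio.read(i="input.mp3", mono=False, dtype=np.float32)
+#     # data has shape (2, n_samples) for stereo; halve the volume and re-encode
+#     out_path = audio.write(o="quiet.flac", array=(data * 0.5).T, sr=sr, of="flac")
+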
+class Inverter(Audio):
+ def __init__(self):
+ super().__init__()
+ self.w_types = [
+ "boxcar", # Прямоугольное окно
+ "triang", # Треугольное окно
+ "blackman", # Окно Блэкмана
+ "hamming", # Окно Хэмминга
+ "hann", # Окно Ханна
+ "bartlett", # Окно Бартлетта
+ "flattop", # Окно с плоской вершиной
+ "parzen", # Окно Парзена
+ "bohman", # Окно Бохмана
+ "blackmanharris", # Окно Блэкмана-Харриса
+ "nuttall", # Окно Нуттала
+ "barthann", # Окно Бартлетта-Ханна
+ "cosine", # Косинусное окно
+ "exponential", # Экспоненциальное окно
+ "tukey", # Окно Туки
+ "taylor", # Окно Тейлора
+ "lanczos", # Окно Ланцоша
+ ]
+
+ def load_audio(self, filepath):
+ """Загрузка аудиофайла с помощью librosa"""
+ try:
+ y, sr, _ = self.read(i=filepath, sr=None, mono=False)
+ return y, sr
+ except Exception as e:
+ print(f"Ошибка загрузки аудио: {e}")
+ return None, None
+
+ def process_channel(self, y1_ch, y2_ch, sr, method, w_size=2048, overlap=2, w_type="hann"):
+ """Обработка одного аудиоканала"""
+ HOP_LENGTH = w_size // overlap
+ if method == "waveform":
+ return y1_ch - y2_ch
+
+ elif method == "spectrogram":
+ # Вычисляем спектрограммы
+            S1 = librosa.stft(
+                y1_ch, n_fft=w_size, hop_length=HOP_LENGTH, win_length=w_size, window=w_type
+            )
+            S2 = librosa.stft(
+                y2_ch, n_fft=w_size, hop_length=HOP_LENGTH, win_length=w_size, window=w_type
+            )
+
+ # Амплитудные спектрограммы
+ mag1 = np.abs(S1)
+ mag2 = np.abs(S2)
+
+ # Спектральное вычитание
+ mag_result = np.maximum(mag1 - mag2, 0)
+
+ # Сохраняем фазовую информацию исходного сигнала
+ phase = np.angle(S1)
+
+ # Комбинируем амплитуду результата с фазой
+ S_result = mag_result * np.exp(1j * phase)
+
+ # Обратное преобразование
+            return librosa.istft(
+                S_result,
+                n_fft=w_size,
+                hop_length=HOP_LENGTH,
+                win_length=w_size,
+                window=w_type,
+                length=len(y1_ch),
+            )
+
+ def process_audio(self, audio1_path, audio2_path, out_format, method, output_path="./inverted.mp3", w_size=2048, overlap=2, w_type="hann"):
+ # Загрузка аудиофайлов
+ y1, sr1 = self.load_audio(audio1_path)
+ y2, sr2 = self.load_audio(audio2_path)
+
+ if sr1 is None or sr2 is None:
+ raise Exception("Произошла ошибка при чтении файлов")
+
+ # Определяем количество каналов
+ channels1 = 1 if y1.ndim == 1 else y1.shape[0]
+ channels2 = 1 if y2.ndim == 1 else y2.shape[0]
+
+ # Преобразование в форму (samples, channels)
+ if channels1 > 1:
+ y1 = y1.T # (channels, samples) -> (samples, channels)
+ else:
+ y1 = y1.reshape(-1, 1)
+
+ if channels2 > 1:
+ y2 = y2.T # (channels, samples) -> (samples, channels)
+ else:
+ y2 = y2.reshape(-1, 1)
+
+ if sr1 != sr2:
+ if channels2 > 1:
+ # Ресемплинг для каждого канала отдельно
+ y2_resampled_list = []
+ for c in range(channels2):
+ channel_resampled = librosa.resample(
+ y2[:, c], orig_sr=sr2, target_sr=sr1
+ )
+ y2_resampled_list.append(channel_resampled)
+
+ # Находим минимальную длину среди всех каналов
+ min_channel_length = min(len(ch) for ch in y2_resampled_list)
+
+ # Обрезаем все каналы до одинаковой длины и собираем в массив
+ y2_resampled = np.zeros((min_channel_length, channels2), dtype=np.float32)
+ for c, channel in enumerate(y2_resampled_list):
+ y2_resampled[:, c] = channel[:min_channel_length]
+
+ y2 = y2_resampled
+ else:
+ y2 = librosa.resample(y2[:, 0], orig_sr=sr2, target_sr=sr1)
+ y2 = y2.reshape(-1, 1)
+ sr2 = sr1
+
+ # Приводим к одинаковой длине
+ min_len = min(len(y1), len(y2))
+ y1 = y1[:min_len]
+ y2 = y2[:min_len]
+
+ # Обрабатываем каждый канал отдельно
+ result_channels = []
+
+ # Если основной сигнал моно, а удаляемый стерео - преобразуем удаляемый в моно
+ if channels1 == 1 and channels2 > 1:
+ y2 = y2.mean(axis=1, keepdims=True)
+ channels2 = 1
+
+ for c in range(channels1):
+ # Выбираем канал для основного сигнала
+ y1_ch = y1[:, c]
+
+ # Выбираем канал для удаляемого сигнала
+ if channels2 == 1:
+ y2_ch = y2[:, 0]
+ else:
+ # Если каналов удаляемого сигнала больше, используем соответствующий канал
+ y2_ch = y2[:, min(c, channels2 - 1)]
+
+ # Обрабатываем канал
+ result_ch = self.process_channel(y1_ch, y2_ch, sr1, method, w_size=w_size, overlap=overlap, w_type=w_type)
+ result_channels.append(result_ch)
+
+ # Собираем каналы в один массив
+ if len(result_channels) > 1:
+ result = np.column_stack(result_channels)
+ else:
+ result = np.array(result_channels[0])
+
+ # Нормализация (предотвращение клиппинга)
+ if result.ndim > 1:
+ # Для многоканального аудио нормализуем каждый канал отдельно
+ for c in range(result.shape[1]):
+ channel = result[:, c]
+ max_val = np.max(np.abs(channel))
+ if max_val > 0:
+ result[:, c] = channel * 0.9 / max_val
+ else:
+ max_val = np.max(np.abs(result))
+ if max_val > 0:
+ result = result * 0.9 / max_val
+
+ inverted = self.write(o=output_path, array=result.T, sr=sr1, of=out_format, br="320k")
+ return inverted
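+
+# Sketch of a typical Inverter call (illustrative; file names are placeholders):
+# subtracting an instrumental from the full mix to approximate the vocals.
+#
+#     inverter = Inverter()
+#     vocals_path = inverter.process_audio(
+#         "mix.wav", "instrumental.wav",
+#         out_format="wav", method="spectrogram",
+#         output_path="./vocals.wav", w_size=2048, overlap=4, w_type="hann",
+#     )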
diff --git a/mvsepless/downloader.py b/mvsepless/downloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5eeb1ff7d4c518c36271ae1b7b87586bb83471c
--- /dev/null
+++ b/mvsepless/downloader.py
@@ -0,0 +1,92 @@
+import os
+import yt_dlp
+import time
+from tqdm import tqdm
+import urllib.request
+
+DOWNLOAD_DIR = os.environ.get(
+ "MVSEPLESS_DOWNLOAD_DIR", os.path.join(os.getcwd(), "downloaded")
+)
+
+def dw_file(url_model: str, local_path: str, retries: int = 30):
+ dir_name = os.path.dirname(local_path)
+ if dir_name != "":
+ os.makedirs(dir_name, exist_ok=True)
+
+ class TqdmUpTo(tqdm):
+ def update_to(self, b=1, bsize=1, tsize=None):
+ if tsize is not None:
+ self.total = tsize
+ self.update(b * bsize - self.n)
+
+ for attempt in range(retries):
+ try:
+ with TqdmUpTo(
+ unit="B",
+ unit_scale=True,
+ unit_divisor=1024,
+ miniters=1,
+ desc=os.path.basename(local_path),
+ ) as t:
+ urllib.request.urlretrieve(
+ url_model, local_path, reporthook=t.update_to
+ )
+ # Если дошли сюда - загрузка успешна, прерываем цикл
+ break
+ except Exception as e:
+ print(f"Попытка {attempt + 1}/{retries} не удалась. Ошибка: {e}")
+ if attempt < retries - 1: # Если это не последняя попытка
+ print("Повторная попытка...")
+ time.sleep(2) # Небольшая задержка перед повторной попыткой
+ else:
+ print("Все попытки загрузки завершились неудачно")
+ raise # Пробрасываем исключение дальше
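+
+# Example call (illustrative; the URL and destination path are placeholders):
+#   dw_file("https://example.com/checkpoints/model.ckpt", "ckpts/model.ckpt")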
+
+def dw_yt_dlp(
+ url,
+ output_dir=None,
+ cookie=None,
+ output_format="mp3",
+ output_bitrate="320",
+ title=None,
+):
+ # Подготовка шаблона имени файла
+ outtmpl = "%(title)s.%(ext)s" if title is None else f"{title}.%(ext)s"
+
+ ydl_opts = {
+ "format": "bestaudio/best",
+ "outtmpl": os.path.join(DOWNLOAD_DIR if not output_dir else output_dir, outtmpl),
+ "postprocessors": [
+ {
+ "key": "FFmpegExtractAudio",
+ "preferredcodec": output_format,
+ "preferredquality": output_bitrate,
+ }
+ ],
+ "noplaylist": True, # Скачивать только одно видео, не плейлист
+ "quiet": True, # Отключить вывод в консоль
+ "no_warnings": True, # Скрыть предупреждения
+ }
+
+ # Добавляем cookies если указаны
+ if cookie and os.path.exists(cookie):
+ ydl_opts["cookiefile"] = cookie
+
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+ try:
+ info = ydl.extract_info(url, download=True)
+ if "_type" in info and info["_type"] == "playlist":
+ # Для плейлистов берем первое видео
+ entry = info["entries"][0]
+ filename = ydl.prepare_filename(entry)
+ else:
+ # Для одиночного видео
+ filename = ydl.prepare_filename(info)
+
+ # Заменяем оригинальное расширение на выбранный формат
+ base, _ = os.path.splitext(filename)
+ audio_file = base + f".{output_format}"
+
+            # prepare_filename уже возвращает полный путь, построенный из outtmpl
+            return audio_file
+        except Exception as e:
+            print(f"Ошибка загрузки через yt-dlp: {e}")
+            return None
\ No newline at end of file
diff --git a/mvsepless/ensemble.py b/mvsepless/ensemble.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9545704eec617954188d2d74bb88eb30d9a8776
--- /dev/null
+++ b/mvsepless/ensemble.py
@@ -0,0 +1,224 @@
+# coding: utf-8
+__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
+
+import os
+import sys
+import librosa
+import tempfile
+import numpy as np
+import argparse
+if not __package__:
+ from audio import Audio
+ from namer import Namer
+else:
+ from .audio import Audio
+ from .namer import Namer
+
+audio = Audio()
+
+def stft(wave, nfft, hl):
+ wave_left = np.asfortranarray(wave[0])
+ wave_right = np.asfortranarray(wave[1])
+ spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
+ spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
+ spec = np.asfortranarray([spec_left, spec_right])
+ return spec
+
+
+def istft(spec, hl, length):
+ spec_left = np.asfortranarray(spec[0])
+ spec_right = np.asfortranarray(spec[1])
+ wave_left = librosa.istft(spec_left, hop_length=hl, length=length)
+ wave_right = librosa.istft(spec_right, hop_length=hl, length=length)
+ wave = np.asfortranarray([wave_left, wave_right])
+ return wave
+
+
+def absmax(a, *, axis):
+ dims = list(a.shape)
+ dims.pop(axis)
+ indices = np.ogrid[tuple(slice(0, d) for d in dims)]
+ argmax = np.abs(a).argmax(axis=axis)
+ # Convert indices to list before insertion
+ indices = list(indices)
+ indices.insert(axis % len(a.shape), argmax)
+ return a[tuple(indices)]
+
+
+def absmin(a, *, axis):
+ dims = list(a.shape)
+ dims.pop(axis)
+ indices = np.ogrid[tuple(slice(0, d) for d in dims)]
+ argmax = np.abs(a).argmin(axis=axis)
+ indices.insert((len(a.shape) + axis) % len(a.shape), argmax)
+ return a[tuple(indices)]
+
+
+def lambda_max(arr, axis=None, key=None, keepdims=False):
+ idxs = np.argmax(key(arr), axis)
+ if axis is not None:
+ idxs = np.expand_dims(idxs, axis)
+ result = np.take_along_axis(arr, idxs, axis)
+ if not keepdims:
+ result = np.squeeze(result, axis=axis)
+ return result
+ else:
+ return arr.flatten()[idxs]
+
+
+def lambda_min(arr, axis=None, key=None, keepdims=False):
+ idxs = np.argmin(key(arr), axis)
+ if axis is not None:
+ idxs = np.expand_dims(idxs, axis)
+ result = np.take_along_axis(arr, idxs, axis)
+ if not keepdims:
+ result = np.squeeze(result, axis=axis)
+ return result
+ else:
+ return arr.flatten()[idxs]
+
+
+def average_waveforms(pred_track, weights, algorithm):
+ """
+ :param pred_track: shape = (num, channels, length)
+ :param weights: shape = (num, )
+ :param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft
+ :return: averaged waveform in shape (channels, length)
+ """
+
+ pred_track = np.array(pred_track)
+ final_length = pred_track.shape[-1]
+
+ mod_track = []
+ for i in range(pred_track.shape[0]):
+ if algorithm == "avg_wave":
+ mod_track.append(pred_track[i] * weights[i])
+ elif algorithm in ["median_wave", "min_wave", "max_wave"]:
+ mod_track.append(pred_track[i])
+ elif algorithm in ["avg_fft", "min_fft", "max_fft", "median_fft"]:
+ spec = stft(pred_track[i], nfft=2048, hl=1024)
+ if algorithm in ["avg_fft"]:
+ mod_track.append(spec * weights[i])
+ else:
+ mod_track.append(spec)
+ pred_track = np.array(mod_track)
+
+ if algorithm in ["avg_wave"]:
+ pred_track = pred_track.sum(axis=0)
+        pred_track /= np.array(weights).sum()
+ elif algorithm in ["median_wave"]:
+ pred_track = np.median(pred_track, axis=0)
+ elif algorithm in ["min_wave"]:
+ pred_track = np.array(pred_track)
+ pred_track = lambda_min(pred_track, axis=0, key=np.abs)
+ elif algorithm in ["max_wave"]:
+ pred_track = np.array(pred_track)
+ pred_track = lambda_max(pred_track, axis=0, key=np.abs)
+ elif algorithm in ["avg_fft"]:
+ pred_track = pred_track.sum(axis=0)
+ pred_track /= np.array(weights).sum()
+ pred_track = istft(pred_track, 1024, final_length)
+ elif algorithm in ["min_fft"]:
+ pred_track = np.array(pred_track)
+ pred_track = lambda_min(pred_track, axis=0, key=np.abs)
+ pred_track = istft(pred_track, 1024, final_length)
+ elif algorithm in ["max_fft"]:
+ pred_track = np.array(pred_track)
+ pred_track = absmax(pred_track, axis=0)
+ pred_track = istft(pred_track, 1024, final_length)
+ elif algorithm in ["median_fft"]:
+ pred_track = np.array(pred_track)
+ pred_track = np.median(pred_track, axis=0)
+ pred_track = istft(pred_track, 1024, final_length)
+ return pred_track
+
+
+def ensemble_audio_files(
+ files, output="res.wav", ensemble_type="avg_wave", weights=None, out_format="wav", add_wav=False
+) -> str | tuple[str, str]:
+ """
+ Основная функция для объединения аудиофайлов
+
+ :param files: список путей к аудиофайлам
+ :param output: путь для сохранения результата
+ :param ensemble_type: метод объединения (avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft)
+ :param weights: список весов для каждого файла (None для равных весов)
+    :return: путь к выходному файлу (и путь к wav-копии, если add_wav=True)
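+
+    Пример (для иллюстрации; пути условные):
+        ensemble_audio_files(
+            ["vocals_model_a.wav", "vocals_model_b.wav"],
+            output="vocals_ens", ensemble_type="avg_wave", weights=[2, 1], out_format="flac"
+        )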
+ """
+ print("Алгоритм склеивания: {}".format(ensemble_type))
+ print("Количество входных файлов: {}".format(len(files)))
+ if weights is not None:
+ weights = np.array(weights)
+ else:
+ weights = np.ones(len(files))
+ print("Весы: {}".format(weights))
+ print("Имя выходного файла: {}".format(output))
+
+ data = []
+ sr = None
+ max_length = 0
+ max_channels = 0
+
+ # Первый проход: определяем максимальную длину и количество каналов
+ for f in files:
+ if not os.path.isfile(f):
+ print("Не удается найти файл: {}. Check paths.".format(f))
+ exit()
+ print("Читается файл: {}".format(f))
+ wav, current_sr, _ = audio.read(i=f, sr=None, mono=False)
+ if sr is None:
+ sr = current_sr
+ elif sr != current_sr:
+ print("Частота дискретизации на всех файлах должна быть одинаковой")
+ exit()
+
+ # Определяем количество каналов
+ if wav.ndim == 1:
+ channels = 1
+ length = len(wav)
+ else:
+ channels = wav.shape[0]
+ length = wav.shape[1]
+
+ max_length = max(max_length, length)
+ max_channels = max(max_channels, channels)
+ print("Форма сигнала: {} частота дискретизации: {}".format(wav.shape, sr))
+
+ # Второй проход: обработка и выравнивание файлов
+ for f in files:
+ wav, current_sr, _ = audio.read(i=f, sr=None, mono=False)
+
+ # Обработка каналов
+ if wav.ndim == 1:
+ # Моно -> стерео
+ wav = np.vstack([wav, wav])
+ elif wav.shape[0] == 1:
+ # Один канал -> стерео
+ wav = np.vstack([wav[0], wav[0]])
+ elif wav.shape[0] > 2:
+ # Более 2 каналов -> берем первые два
+ wav = wav[:2, :]
+
+ # Выравнивание длины
+ if wav.shape[1] < max_length:
+ pad_width = ((0, 0), (0, max_length - wav.shape[1]))
+ wav = np.pad(wav, pad_width, mode="constant")
+ elif wav.shape[1] > max_length:
+ wav = wav[:, :max_length]
+
+ data.append(wav)
+
+ data = np.array(data)
+ res = average_waveforms(data, weights, ensemble_type)
+ print("Форма результата: {}".format(res.shape))
+
+ output_wav = f"{output}_orig.wav"
+ output = f"{output}.{out_format}"
+
+ output = audio.write(o=output, array=res.T, sr=sr, of=out_format, br="320k")
+ if add_wav:
+ output_wav = audio.write(o=output_wav, array=res.T, sr=sr, of="wav")
+ return output, output_wav
+
+ else:
+ return output
\ No newline at end of file
diff --git a/mvsepless/infer.py b/mvsepless/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..be215b92020d0fb23238451583a0f08afda43139
--- /dev/null
+++ b/mvsepless/infer.py
@@ -0,0 +1,623 @@
+import os
+import sys
+import json
+import argparse
+import time
+from datetime import datetime
+import gc
+import glob
+import yaml
+import torch
+import numpy as np
+import soundfile as sf
+import torch.nn as nn
+
+from typing import Literal
+
+from audio import Audio
+from namer import Namer
+
+namer = Namer()
+audio = Audio()
+
+from infer_utils import (
+ prefer_target_instrument,
+ demix,
+ get_model_from_config
+)
+
+
+def normalize_peak(audio, peak):
+ current_peak = np.max(np.abs(audio))
+ if current_peak == 0:
+ return audio # избегаем деления на ноль
+ scale_factor = peak / current_peak
+ return audio * scale_factor
+
+
+gc.enable()
+
+
+def cleanup_model(model):
+ try:
+ if isinstance(model, torch.nn.DataParallel):
+ model = model.module
+
+ model.to("cpu")
+
+ for name, param in list(model.named_parameters()):
+ del param
+ for name, buf in list(model.named_buffers()):
+ del buf
+
+ del model
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+
+ gc.collect()
+ sys.stdout.write(json.dumps({"cleanup": "Модель выгружена из памяти"}, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+ except Exception as e:
+ sys.stdout.write(json.dumps({"error": f"Ошибка при выгрузке модели: {str(e)}"}, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+
+
+def once_inference(
+ path: str = None,
+ model: any = None,
+ config: any = None,
+ device: any = None,
+ model_type: str = None,
+ extract_instrumental: bool = False,
+ detailed_pbar: bool = False,
+ output_format: Literal[
+ "mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "aiff"
+ ] = "mp3",
+ output_bitrate: str = "320k",
+ use_tta: bool = False,
+ verbose: bool = False,
+ model_name: str = None,
+ sample_rate: int = 44100,
+ instruments: list = [],
+ store_dir: str = None,
+ template: str = None,
+ selected_instruments: list = [],
+ model_id: int = 0,
+):
+ results = []
+ sys.stdout.write(json.dumps({"reading": path}, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+ sys.stdout.write(json.dumps({"selected_stems": selected_instruments}, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+ sys.stdout.write(json.dumps({"stems": list(instruments)}, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+
+ if config.training.target_instrument is not None:
+ sys.stdout.write(json.dumps({"target_instrument": config.training.target_instrument}, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+
+ try:
+ mix, sr, _ = audio.read(i=path, sr=sample_rate, mono=False)
+ except Exception as e:
+ error_msg = f"Не удалось прочитать аудио: {path}\nОшибка: {e}"
+ sys.stdout.write(json.dumps({"error": error_msg}, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+ return results
+
+ mix_orig = mix.copy()
+
+ mean = std = None
+ if config.inference.get("normalize", False):
+ mono = mix.mean(0)
+ mean = mono.mean()
+ std = mono.std()
+ mix = (mix - mean) / std
+
+ if use_tta:
+ track_proc_list = [mix.copy(), mix[::-1].copy(), -1.0 * mix.copy()]
+ else:
+ track_proc_list = [mix.copy()]
+ full_result = []
+ for m in track_proc_list:
+ try:
+ waveforms = demix(
+ config, model, m, device, pbar=detailed_pbar, model_type=model_type
+ )
+
+ full_result.append(waveforms)
+ except Exception as e:
+ sys.stdout.write(json.dumps({"error": f"Ошибка при демиксе: {e}"}, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+ del m
+ gc.collect()
+
+ if not full_result:
+ sys.stdout.write(json.dumps({"error": "Пустой результат демикса."}, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+ return results
+
+ waveforms = full_result[0]
+ for i in range(1, len(full_result)):
+ d = full_result[i]
+ for el in d:
+ if i == 2:
+ waveforms[el] += -1.0 * d[el]
+ elif i == 1:
+ waveforms[el] += d[el][::-1].copy()
+ else:
+ waveforms[el] += d[el]
+ for el in waveforms:
+ waveforms[el] /= len(full_result)
+
+ if (
+ extract_instrumental and config.training.target_instrument is not None
+ ): # Если включен "Extract Instrumental / Извлечь инструментал" и найден целевой инструмент
+ second_stem = [
+ s
+ for s in config.training.instruments
+ if s != config.training.target_instrument
+ ]
+ if second_stem:
+ second_stem_key = second_stem[0]
+ if second_stem_key not in instruments:
+ instruments.append(second_stem_key)
+ waveforms[second_stem_key] = mix_orig - waveforms[instruments[0]]
+
+ elif (
+ extract_instrumental
+ and selected_instruments
+ and config.training.target_instrument is None
+ ): # Если включен "Extract Instrumental / Извлечь инструментал" и выбраны инструменты, то создаются стемы "inverted -" и "inverted +" (если не найден целевого инструмент)
+
+
+ all_instruments = config.training.instruments
+ if len(all_instruments) > 2:
+
+ waveforms["inverted -"] = mix_orig.copy()
+ for instr in instruments:
+ if instr in waveforms:
+ waveforms["inverted -"] -= waveforms[
+ instr
+ ] # стем "inverted -": вычитание выбранного стема из оригинального сигнала (не всегда хорошо)
+
+ if "inverted -" not in instruments:
+ instruments.append("inverted -")
+
+ unselected_stems = [s for s in all_instruments if s not in selected_instruments]
+ if unselected_stems:
+ waveforms["inverted +"] = np.zeros_like(mix_orig)
+ for stem in unselected_stems:
+ if stem in waveforms:
+ waveforms["inverted +"] += waveforms[
+ stem
+ ] # стем "inverted +": сложение не выбранных инструментов в один стем
+ if "inverted +" not in instruments:
+ instruments.append("inverted +")
+
+            if "inverted +" in waveforms:  # стем мог не создаться, если выбраны все инструменты
+                peak = np.max(np.abs(waveforms["inverted -"]))
+                waveforms["inverted +"] = normalize_peak(waveforms["inverted +"], peak)
+
+ elif (
+ extract_instrumental
+ and not selected_instruments
+ and config.training.target_instrument is None
+ and (
+ all(
+ instr in config.training.instruments
+ for instr in ["bass", "drums", "other", "vocals"]
+ )
+ or all(
+ instr in config.training.instruments
+ for instr in ["bass", "drums", "other", "vocals", "piano", "guitar"]
+ )
+ )
+ ):
+
+ waveforms["instrumental -"] = mix_orig.copy()
+ waveforms["instrumental -"] -= waveforms[
+ "vocals"
+ ] # стем "inverted -": вычитание выбранного стема из оригинального сигнала (не всегда хорошо)
+
+ if "instrumental -" not in instruments:
+ instruments.append("instrumental -")
+
+ all_instruments = config.training.instruments
+ non_vocal_stems = [s for s in all_instruments if s not in ["vocals"]]
+ if non_vocal_stems:
+ waveforms["instrumental +"] = np.zeros_like(mix_orig)
+ for stem in non_vocal_stems:
+ if stem in waveforms:
+ waveforms["instrumental +"] += waveforms[
+ stem
+ ] # стем "inverted +": сложение не выбранных инструментов в один стем
+ if "instrumental +" not in instruments:
+ instruments.append("instrumental +")
+
+ peak = np.max(np.abs(waveforms["instrumental -"]))
+ waveforms["instrumental +"] = normalize_peak(waveforms["instrumental +"], peak)
+
+ template = namer.sanitize(template)
+ template = namer.dedup_template(template, keys=["NAME", "MODEL", "STEM", "ID"])
+ template = namer.short(template, length=40)
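+    # Illustrative template expansion (assumption: Namer.template substitutes the uppercase
+    # keys): "NAME (STEM) MODEL" with NAME="song", STEM="vocals", MODEL="model_x"
+    # is expected to become "song (vocals) model_x".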
+
+ for instr in instruments:
+ try:
+ estimates = waveforms[instr].T
+ if mean is not None and std is not None:
+ estimates = estimates * std + mean
+
+ file_name = os.path.splitext(os.path.basename(path))[0]
+ file_name_shorted = namer.short_input_name_template(template, STEM=instr, MODEL=model_name, ID=model_id, NAME=file_name)
+ custom_name = namer.template(
+ template, STEM=instr, MODEL=model_name, ID=model_id, NAME=file_name_shorted
+ )
+ output_path = os.path.join(store_dir, f"{custom_name}.{output_format}")
+
+ sys.stdout.write(json.dumps({"writing": output_path}, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+
+ output_path = audio.write(
+ o=output_path, array=estimates, sr=sr, of=output_format, br=output_bitrate
+ ) # запись стема в аудио файл с помощью универсальной функции
+
+ results.append(
+ (instr, output_path)
+ ) # запись информации о разделении: (название стема, путь к файлу)
+ del estimates
+ except Exception as e:
+ sys.stdout.write(json.dumps({"error": f"Ошибка при обработке {instr}: {e}"}, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+ gc.collect()
+
+ del mix, mix_orig, waveforms, full_result
+ gc.collect()
+
+ return results
+
+
+def run_inference(
+ model: any = None,
+ config: any = None,
+ input_path: str = None,
+ store_dir: str = None,
+ device: any = None,
+ model_type: str = None,
+ extract_instrumental: bool = False,
+ disable_detailed_pbar: bool = False,
+ output_format: Literal[
+ "mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "aiff"
+ ] = "mp3",
+ output_bitrate: str = "320k",
+ use_tta: bool = False,
+ verbose: bool = False,
+ model_name: str = None,
+ template: str = "NAME_STEM",
+ selected_instruments: list = [],
+ model_id: int = 0,
+):
+ start_time = time.time()
+ if model_type != "vr":
+ model.eval()
+ sample_rate = 44100
+ if "sample_rate" in config.audio:
+ sample_rate = config.audio["sample_rate"]
+
+ instruments = prefer_target_instrument(config)
+
+ if config.training.target_instrument is not None:
+ sys.stdout.write(json.dumps({"info": "Целевой инструмент найден в конфигурации модели. Выбранные стемы будут проигнорированы."}, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+ else:
+ if selected_instruments is not None and selected_instruments != []:
+ instruments = [
+ instr for instr in instruments if instr in selected_instruments
+ ]
+ if verbose:
+ sys.stdout.write(json.dumps({"selected_stems": instruments}, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+
+ os.makedirs(store_dir, exist_ok=True)
+
+ detailed_pbar = not disable_detailed_pbar
+
+ results = once_inference(
+ path=input_path,
+ model=model,
+ config=config,
+ device=device,
+ model_type=model_type,
+ extract_instrumental=extract_instrumental,
+ detailed_pbar=detailed_pbar,
+ output_format=output_format,
+ output_bitrate=output_bitrate,
+ use_tta=use_tta,
+ verbose=verbose,
+ model_name=model_name,
+ sample_rate=sample_rate,
+ instruments=instruments,
+ store_dir=store_dir,
+ template=template,
+ selected_instruments=selected_instruments,
+ model_id=model_id,
+ )
+
+ time.sleep(1)
+ time_taken = time.time() - start_time
+ sys.stdout.write(json.dumps({"time": f"{time_taken:.2f} сек."}, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+ sys.stdout.write(json.dumps({"done": results}, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+ return results
+
+
+def load_model(model_type, config_path, start_check_point, device_ids, force_cpu=False):
+ device = "cpu"
+ if force_cpu:
+ device = "cpu"
+ elif torch.cuda.is_available():
+ sys.stdout.write(json.dumps({"info": "Разделение выполняется на ядрах CUDA. Для выполнения на процессоре установите force_cpu=True."}, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+ device = "cuda"
+
+ if device_ids is None:
+ device = "cuda:0"
+ elif isinstance(device_ids, (list, tuple)):
+ device = f"cuda:{device_ids[0]}" if device_ids else "cuda:0"
+ elif isinstance(device_ids, bool):
+ device = "cuda:0"
+ else:
+ device = f"cuda:{int(device_ids)}"
+ elif torch.backends.mps.is_available():
+ device = "mps"
+
+ sys.stdout.write(json.dumps({"device": device}, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+
+ model_load_start_time = time.time()
+ torch.backends.cudnn.benchmark = True
+
+ model, config = get_model_from_config(model_type, config_path)
+
+ if model_type == "vr":
+ model.load_checkpoint(start_check_point, device)
+ model.settings(enable_post_process=False,
+ post_process_threshold=config.inference.post_process_threshold,
+ batch_size=config.inference.batch_size,
+ window_size=config.inference.window_size,
+ high_end_process=config.inference.high_end_process,
+ primary_stem=config.training.instruments[0],
+ secondary_stem=config.training.instruments[1]
+ )
+ return model, config, device
+
+ elif model_type == "mdxnet":
+ if start_check_point != "":
+ sys.stdout.write(json.dumps({"checkpoint": start_check_point}) + '\n')
+ sys.stdout.flush()
+ model.init_onnx_session(start_check_point, device)
+
+ return model, config, device
+
+ else:
+ if start_check_point != "":
+ sys.stdout.write(json.dumps({"checkpoint": start_check_point}) + '\n')
+ sys.stdout.flush()
+
+ if model_type in ["htdemucs", "apollo"]:
+ state_dict = torch.load(
+ start_check_point, map_location=device, weights_only=False
+ )
+ if "state" in state_dict:
+ state_dict = state_dict["state"]
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+ else:
+ try:
+ state_dict = torch.load(
+ start_check_point, map_location=device, weights_only=True
+ )
+ except torch.serialization.pickle.UnpicklingError:
+ with torch.serialization.safe_globals([torch._C._nn.gelu]):
+ state_dict = torch.load(
+ start_check_point, map_location=device, weights_only=True
+ )
+ try:
+ model.load_state_dict(state_dict)
+ except RuntimeError:
+ model.load_state_dict(state_dict, strict=False)
+
+ sys.stdout.write(json.dumps({"stems": list(config.training.instruments)}, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+
+ if (
+ isinstance(device_ids, (list, tuple))
+ and len(device_ids) > 1
+ and not force_cpu
+ and torch.cuda.is_available()
+ ):
+ model = nn.DataParallel(model, device_ids=[int(d) for d in device_ids])
+
+ model = model.to(device)
+
+ load_time = time.time() - model_load_start_time
+
+ sys.stdout.write(json.dumps({"model_load_time": f"{load_time:.2f} сек."}, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+
+ return model, config, device
+
+
+def mvsep_offline(
+ input_path,
+ store_dir,
+ model_type,
+ config_path,
+ start_check_point,
+ extract_instrumental,
+ output_format,
+ output_bitrate,
+ model_name,
+ template,
+ device_ids=None,
+ disable_detailed_pbar=False,
+ use_tta=False,
+ force_cpu=False,
+ verbose=False,
+ selected_instruments=None,
+ model_id=0,
+):
+ model, config, device = load_model(
+ model_type, config_path, start_check_point, device_ids, force_cpu
+ )
+
+ results = run_inference(
+ model=model,
+ config=config,
+ input_path=input_path,
+ store_dir=store_dir,
+ device=device,
+ model_type=model_type,
+ extract_instrumental=extract_instrumental,
+ disable_detailed_pbar=disable_detailed_pbar,
+ output_format=output_format,
+ output_bitrate=output_bitrate,
+ use_tta=use_tta,
+ verbose=verbose,
+ model_name=model_name,
+ template=template,
+ selected_instruments=selected_instruments,
+ model_id=model_id,
+ )
+
+ if model_type != "vr":
+ cleanup_model(model)
+ del config
+ gc.collect()
+ return results
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="Модифицированный Music-Source-Separation-Training для разделения аудио на источники"
+ )
+
+ # Обязательные аргументы
+ parser.add_argument("--input", type=str, help="Путь к входному файлу или папке")
+ parser.add_argument(
+ "--store_dir", type=str, required=True, help="Путь для сохранения результатов"
+ )
+
+ # Основные параметры модели
+ parser.add_argument(
+ "--model_type",
+ type=str,
+ default="htdemucs",
+ choices=[
+ "mel_band_roformer",
+ "bs_roformer",
+ "mdx23c",
+ "scnet",
+ "htdemucs",
+ "bandit",
+ "bandit_v2",
+ "mdxnet",
+ "vr"
+ ],
+ help="Тип модели (по умолчанию: htdemucs)",
+ )
+ parser.add_argument(
+ "--config_path",
+ type=str,
+ required=True,
+ help="Путь к конфигурационному файлу модели",
+ )
+ parser.add_argument(
+ "--start_check_point", type=str, required=True, help="Путь к чекпоинту модели"
+ )
+
+ # Параметры вывода
+ parser.add_argument(
+ "--output_format",
+ type=str,
+ default="wav",
+ choices=audio.output_formats,
+ help="Формат выходных файлов",
+ )
+ parser.add_argument(
+ "--output_bitrate", type=str, required=True, help="Битрейт выходного файла"
+ )
+
+ parser.add_argument(
+ "--selected_instruments",
+ nargs="+",
+ help="Список стемов для сохранения (например: vocals drums)",
+ )
+ parser.add_argument(
+ "--extract_instrumental",
+ action="store_true",
+ help="Извлечь инструментальную версию",
+ )
+ parser.add_argument(
+ "--template",
+ type=str,
+ default="NAME_STEM",
+ help="Шаблон для имен выходных файлов",
+ )
+ parser.add_argument(
+ "--model_name",
+ type=str,
+ default="model",
+ help="Имя модели для шаблона имен файлов",
+ )
+ parser.add_argument("-m_id", "--model_id", type=int, required=True, help="Model ID")
+ parser.add_argument(
+ "--device_ids", nargs="+", help="ID GPU устройств для использования"
+ )
+ parser.add_argument(
+ "--force_cpu", action="store_true", help="Принудительно использовать CPU"
+ )
+ parser.add_argument(
+ "--use_tta", action="store_true", help="Использовать тестовую аугментацию"
+ )
+ parser.add_argument(
+ "--disable_detailed_pbar",
+ action="store_true",
+ help="Отключить детальный прогресс-бар",
+ )
+ parser.add_argument("--verbose", action="store_true", help="Подробный вывод")
+
+ return parser.parse_args()
+
+
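+# Illustrative CLI invocation (paths and model files are placeholders; --store_dir,
+# --config_path, --start_check_point, --output_bitrate and --model_id are required):
+#   python infer.py --input song.wav --store_dir ./out \
+#       --model_type mel_band_roformer --config_path config.yaml \
+#       --start_check_point model.ckpt --output_format flac --output_bitrate 320 \
+#       --model_id 0 --extract_instrumental
+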
+def main():
+ args = parse_args()
+
+ device_ids = None
+ if args.device_ids:
+ device_ids = [int(x) for x in args.device_ids]
+
+ results = mvsep_offline(
+ input_path=args.input,
+ store_dir=args.store_dir,
+ model_type=args.model_type,
+ config_path=args.config_path,
+ start_check_point=args.start_check_point,
+ extract_instrumental=args.extract_instrumental,
+ output_format=args.output_format,
+ output_bitrate=args.output_bitrate,
+ model_name=args.model_name,
+ template=args.template,
+ device_ids=device_ids,
+ disable_detailed_pbar=args.disable_detailed_pbar,
+ use_tta=args.use_tta,
+ force_cpu=args.force_cpu,
+ verbose=args.verbose,
+ selected_instruments=args.selected_instruments,
+ model_id=args.model_id,
+ )
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/mvsepless/infer_utils.py b/mvsepless/infer_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..57b2bb76b9765b809cd29595b6d8bb81b861c86f
--- /dev/null
+++ b/mvsepless/infer_utils.py
@@ -0,0 +1,382 @@
+# coding: utf-8
+__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
+
+import sys
+import json
+import numpy as np
+import torch
+import torch.nn as nn
+import yaml
+import librosa
+import torch.nn.functional as F
+from ml_collections import ConfigDict
+from omegaconf import OmegaConf
+from typing import Dict, List, Tuple, Any, Optional
+
+
+def load_config(model_type: str, config_path: str) -> Any:
+ """
+ Load the configuration from the specified path based on the model type.
+ """
+ try:
+ with open(config_path, "r") as f:
+ if model_type == "htdemucs":
+ config = OmegaConf.load(config_path)
+ else:
+ config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
+ return config
+ except FileNotFoundError:
+ raise FileNotFoundError(f"Configuration file not found at {config_path}")
+ except Exception as e:
+ raise ValueError(f"Error loading configuration: {e}")
+
+
+def get_model_from_config(model_type: str, config_path: str) -> Tuple:
+ """
+ Load the model specified by the model type and configuration file.
+ """
+ config = load_config(model_type, config_path)
+
+ if model_type == "mdx23c":
+ from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net
+
+ model = TFC_TDF_net(config)
+
+ elif model_type == "mdxnet":
+ from models.mdx_net import MDXNet
+
+ model = MDXNet(**dict(config.model))
+
+ elif model_type == "vr":
+ from models.vr_arch import VRNet
+ # Передаем instruments из config.training в модель
+ model = VRNet(**dict(config.model))
+
+ elif model_type == "htdemucs":
+ from models.demucs4ht import get_model
+
+ model = get_model(config)
+
+ elif model_type == "mel_band_roformer":
+ if hasattr(config, "windowed"): # Это не нарушает совместимость со обычными моделями на Mel-Band Roformer
+ from models.windowed_roformer.model import MelBandRoformerWSA
+
+ model = MelBandRoformerWSA(**dict(config.model))
+
+ else:
+ from models.bs_roformer import MelBandRoformer
+
+ model = MelBandRoformer(**dict(config.model))
+
+ elif model_type == "bs_roformer":
+ if hasattr(config.model, "use_shared_bias"):
+ from models.bs_roformer import BSRoformer_SW
+
+ model = BSRoformer_SW(**dict(config.model))
+ elif hasattr(config.model, "fno"):
+ from models.bs_roformer import BSRoformer_FNO
+
+ model = BSRoformer_FNO(**dict(config.model))
+ else:
+ from models.bs_roformer import BSRoformer
+
+ model = BSRoformer(**dict(config.model))
+
+ elif model_type == "bandit":
+ from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple
+
+ model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)
+
+ elif model_type == "bandit_v2":
+ from models.bandit_v2.bandit import Bandit
+
+ model = Bandit(**config.kwargs)
+ elif model_type == "scnet_unofficial":
+ from models.scnet_unofficial import SCNet
+
+ model = SCNet(**config.model)
+ elif model_type == "scnet":
+ from models.scnet import SCNet
+
+ model = SCNet(**config.model)
+ else:
+ raise ValueError(f"Unknown model type: {model_type}")
+
+ return model, config
+
+def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor:
+ """
+ Generate a windowing array with a linear fade-in at the beginning and a fade-out at the end.
+ """
+ fadein = torch.linspace(0, 1, fade_size)
+ fadeout = torch.linspace(1, 0, fade_size)
+
+ window = torch.ones(window_size)
+ window[-fade_size:] = fadeout
+ window[:fade_size] = fadein
+ return window
+
+def demix_mdxnet(
+ config: Any,
+ model: Any,
+ mix: np.ndarray,
+ device: torch.device,
+ pbar: bool = False,
+) -> Dict[str, np.ndarray]:
+ """
+ MDX-Net specific demixing function с поддержкой overlap
+ """
+ mix_tensor = torch.tensor(mix, dtype=torch.float32)
+ inv_mix_tensor = torch.tensor(-mix, dtype=torch.float32)
+
+ num_overlap = config.inference.num_overlap
+ denoise = config.inference.denoise
+ stem_name = model.primary_stem
+ if denoise:
+ processed_wav = model.process_wave(mix_tensor, device, num_overlap, pbar=pbar)
+ inv_processed_wav = model.process_wave(inv_mix_tensor, device, num_overlap, pbar=pbar)
+ result = processed_wav.cpu().numpy()
+ inv_result = inv_processed_wav.cpu().numpy()
+        result_separation = (result - inv_result) * 0.5
+ else:
+ processed_wav = model.process_wave(mix_tensor, device, num_overlap, pbar=pbar)
+ result_separation = processed_wav.cpu().numpy()
+
+ result_separation = np.nan_to_num(result_separation, nan=0.0, posinf=0.0, neginf=0.0)
+
+ return {stem_name: result_separation} # Перемещаем на CPU для возврата
+
+def demix_vr(
+ config: Any,
+ model: Any,
+ mix: np.ndarray,
+ device: torch.device,
+ pbar: bool = False,
+) -> Dict[str, np.ndarray]:
+ """
+ VR-specific demixing function that processes the entire audio at once
+ since VR architecture doesn't support chunk-based processing
+ """
+ # Convert to tensor and add batch dimension
+ return model.demix(mix, config.audio.sample_rate, device, config.inference.aggression)
+
+def demix_demucs(config, model, mix, device, pbar=False):
+ mix = torch.tensor(mix, dtype=torch.float32)
+ chunk_size = config.training.samplerate * config.training.segment
+ num_instruments = len(config.training.instruments)
+ num_overlap = config.inference.num_overlap
+ step = chunk_size // num_overlap
+ fade_size = chunk_size // 10  # fade length for the windowing function
+ windowing_array = _getWindowingArray(chunk_size, fade_size)  # build the window
+
+ batch_size = config.inference.batch_size
+ use_amp = getattr(config.training, "use_amp", True)
+
+ with torch.cuda.amp.autocast(enabled=use_amp):
+ with torch.inference_mode():
+ req_shape = (num_instruments,) + mix.shape
+ result = torch.zeros(req_shape, dtype=torch.float32)
+ counter = torch.zeros(req_shape, dtype=torch.float32)
+
+ i = 0
+ batch_data = []
+ batch_locations = []
+
+ while i < mix.shape[1]:
+ part = mix[:, i : i + chunk_size].to(device)
+ chunk_len = part.shape[-1]
+ pad_mode = "reflect" if chunk_len > chunk_size // 2 else "constant"
+ part = nn.functional.pad(
+ part, (0, chunk_size - chunk_len), mode=pad_mode, value=0
+ )
+
+ batch_data.append(part)
+ batch_locations.append((i, chunk_len))
+ i += step
+
+ if len(batch_data) >= batch_size or i >= mix.shape[1]:
+ arr = torch.stack(batch_data, dim=0)
+ x = model(arr)
+
+ window = windowing_array.clone()
+ if i - step == 0:  # first chunk, no fade-in
+ window[:fade_size] = 1
+ elif i >= mix.shape[1]:  # last chunk, no fade-out
+ window[-fade_size:] = 1
+
+ for j, (start, seg_len) in enumerate(batch_locations):
+ result[..., start : start + seg_len] += (
+ x[j, ..., :seg_len].cpu() * window[..., :seg_len]
+ )
+ counter[..., start : start + seg_len] += window[..., :seg_len]
+
+ # Output progress
+ processed = min(i, mix.shape[1])
+ total = mix.shape[1]
+ sys.stdout.write(json.dumps({"processing": {"processed": processed, "total": total}}) + '\n')
+ sys.stdout.flush()
+
+ batch_data.clear()
+ batch_locations.clear()
+
+ estimated_sources = result / counter
+ estimated_sources = estimated_sources.cpu().numpy()
+ np.nan_to_num(estimated_sources, copy=False, nan=0.0)
+
+ if num_instruments <= 1:
+ return estimated_sources
+ else:
+ instruments = config.training.instruments
+ return {k: v for k, v in zip(instruments, estimated_sources)}
+
+def demix_generic(
+ config: ConfigDict,
+ model: torch.nn.Module,
+ mix: np.ndarray,
+ device: torch.device,
+ pbar: bool = False,
+) -> Dict[str, np.ndarray]:
+ """
+ Generic demixing function for models that support chunk-based processing
+ """
+ mix = torch.tensor(mix, dtype=torch.float32)
+ chunk_size = config.audio.chunk_size
+ instruments = prefer_target_instrument(config)
+ num_instruments = len(instruments)
+ num_overlap = config.inference.num_overlap
+
+ fade_size = chunk_size // 10
+ step = chunk_size // num_overlap
+ border = chunk_size - step
+ length_init = mix.shape[-1]
+ windowing_array = _getWindowingArray(chunk_size, fade_size)
+
+ # Add padding to handle edge artifacts
+ if length_init > 2 * border and border > 0:
+ mix = nn.functional.pad(mix, (border, border), mode="reflect")
+
+ batch_size = config.inference.batch_size
+ use_amp = getattr(config.training, "use_amp", True)
+
+ with torch.cuda.amp.autocast(enabled=use_amp):
+ with torch.inference_mode():
+ # Initialize result and counter tensors
+ req_shape = (num_instruments,) + mix.shape
+ result = torch.zeros(req_shape, dtype=torch.float32)
+ counter = torch.zeros(req_shape, dtype=torch.float32)
+
+ i = 0
+ batch_data = []
+ batch_locations = []
+
+ while i < mix.shape[1]:
+ # Extract chunk and apply padding if necessary
+ part = mix[:, i : i + chunk_size].to(device)
+ chunk_len = part.shape[-1]
+
+ pad_mode = "reflect" if chunk_len > chunk_size // 2 else "constant"
+ part = nn.functional.pad(
+ part, (0, chunk_size - chunk_len), mode=pad_mode, value=0
+ )
+
+ batch_data.append(part)
+ batch_locations.append((i, chunk_len))
+ i += step
+
+ # Process batch if it's full or the end is reached
+ if len(batch_data) >= batch_size or i >= mix.shape[1]:
+ arr = torch.stack(batch_data, dim=0)
+ x = model(arr)
+
+ window = windowing_array.clone()
+ if i - step == 0: # First audio chunk, no fadein
+ window[:fade_size] = 1
+ elif i >= mix.shape[1]: # Last audio chunk, no fadeout
+ window[-fade_size:] = 1
+
+ for j, (start, seg_len) in enumerate(batch_locations):
+ result[..., start : start + seg_len] += (
+ x[j, ..., :seg_len].cpu() * window[..., :seg_len]
+ )
+ counter[..., start : start + seg_len] += window[..., :seg_len]
+
+ # Output progress
+ processed = min(i, mix.shape[1])
+ total = mix.shape[1]
+ sys.stdout.write(json.dumps({"processing": {"processed": processed, "total": total}}, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+
+ batch_data.clear()
+ batch_locations.clear()
+
+ # Compute final estimated sources
+ estimated_sources = result / counter
+ estimated_sources = estimated_sources.cpu().numpy()
+ np.nan_to_num(estimated_sources, copy=False, nan=0.0)
+
+ # Remove padding
+ if length_init > 2 * border and border > 0:
+ estimated_sources = estimated_sources[..., border:-border]
+
+ # Return the result as a dictionary
+ return {k: v for k, v in zip(instruments, estimated_sources)}
+
+def demix(
+ config: ConfigDict,
+ model: torch.nn.Module,
+ mix: np.ndarray,
+ device: torch.device,
+ model_type: str,
+ pbar: bool = False,
+) -> Dict[str, np.ndarray]:
+ """
+ Unified function for audio source separation with support for multiple processing modes.
+ """
+ # Handle different model types
+ if model_type == "vr":
+ return demix_vr(config, model, mix, device, pbar)
+ elif model_type == "mdxnet":
+ return demix_mdxnet(config, model, mix, device, pbar)
+ elif model_type == "htdemucs":
+ # HTDemucs uses its own processing
+ return demix_demucs(config, model, mix, device, pbar)
+ else:
+ # Generic processing for other models
+ return demix_generic(config, model, mix, device, pbar)
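+
+# Minimal usage sketch (illustration only; assumes the model/config pair comes
+# from the loader above, that the model is already on `device`, and that the
+# `soundfile` package is available):
+#
+#   import soundfile as sf
+#   mix, sr = sf.read("song.wav", always_2d=True)  # (samples, channels)
+#   mix = mix.T                                    # demix expects (channels, samples)
+#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+#   stems = demix(config, model, mix, device, model_type="bs_roformer")
+#   for name, audio in stems.items():
+#       sf.write(f"{name}.wav", audio.T, sr)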
+
+
+def prefer_target_instrument(config: ConfigDict) -> List[str]:
+ """
+ Return the list of target instruments based on the configuration.
+ If a specific target instrument is specified in the configuration,
+ it returns a list with that instrument. Otherwise, it returns the list of instruments.
+ """
+ if config.training.get("target_instrument"):
+ return [config.training.target_instrument]
+ else:
+ return config.training.instruments
+
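+# Example of the selection logic above (illustrative values for a typical
+# two-stem config):
+#   config.training.instruments       -> ["vocals", "other"]
+#   config.training.target_instrument -> "vocals"
+#   prefer_target_instrument(config)  -> ["vocals"]
+# With target_instrument unset (null), the full instrument list is returned.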
+
+def prefer_target_instrument_test(
+ config: ConfigDict, selected_instruments: Optional[List[str]] = None
+) -> List[str]:
+ """
+ Return the list of target instruments based on the configuration and selected instruments.
+ If selected_instruments is specified, returns the intersection with available instruments.
+ Otherwise, if a target instrument is specified, returns it, else returns all instruments.
+ """
+ available_instruments = config.training.instruments
+
+ if selected_instruments is not None:
+ # Return only selected instruments that are available
+ return [
+ instr for instr in selected_instruments if instr in available_instruments
+ ]
+ elif config.training.get("target_instrument"):
+ # Default behavior if no selection - return target instrument
+ return [config.training.target_instrument]
+ else:
+ # If no target and no selection, return all instruments
+ return available_instruments
\ No newline at end of file
diff --git a/mvsepless/model_manager.py b/mvsepless/model_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..0492e403eee6eedb9e35db7a7ae1d5d4a35b9b5f
--- /dev/null
+++ b/mvsepless/model_manager.py
@@ -0,0 +1,540 @@
+import os
+import sys
+import json
+import yaml
+from tabulate import tabulate
+import shutil
+from tqdm import tqdm
+import urllib.request
+import gdown
+import requests
+import zipfile
+import tempfile
+import secrets
+import string
+import argparse
+from typing import Dict, Any
+script_dir = os.path.dirname(os.path.abspath(__file__))
+if not __package__:
+ from downloader import dw_file
+else:
+ from .downloader import dw_file
+
+def generate_secure_random(length=10):
+ """Генерирует криптографически безопасную случайную строку"""
+ characters = string.ascii_letters + string.digits
+ return ''.join(secrets.choice(characters) for _ in range(length))
+
+class MvseplessModelManager:
+ def __init__(
+ self,
+ models_info_path=os.path.join(script_dir, "models.json"),
+ cache_dir=os.path.join(script_dir, "mvsepless_models_cache"),
+ ):
+ self.models_cache_dir = cache_dir
+ self.models_info_path = models_info_path
+ with open(self.models_info_path, "r", encoding="utf-8") as f:
+ models_info = json.load(f)
+ self.models_info = models_info
+
+ def get_mt(self):
+ return list(self.models_info.keys())
+
+ def get_mn(self, model_type):
+ try:
+ mt = self.models_info.get(model_type, None)
+ if mt:
+ return list(self.models_info[model_type].keys())
+ return []
+ except (KeyError, TypeError):
+ return []
+
+ def get_stems(self, model_type, model_name):
+ try:
+ mt = self.models_info.get(model_type, None)
+ if mt:
+ mn = self.models_info[model_type].get(model_name, None)
+ if mn and "stems" in self.models_info[model_type][model_name]:
+ return self.models_info[model_type][model_name]["stems"]
+ return []
+ except (KeyError, TypeError):
+ return []
+
+ def get_id(self, model_type, model_name):
+ try:
+ mt = self.models_info.get(model_type, None)
+ if mt:
+ mn = self.models_info[model_type].get(model_name, None)
+ if mn and "id" in self.models_info[model_type][model_name]:
+ return self.models_info[model_type][model_name]["id"]
+ return 0
+ except (KeyError, TypeError):
+ return 0
+
+ def get_tgt_inst(self, model_type, model_name):
+ try:
+ mt = self.models_info.get(model_type, None)
+ if mt:
+ mn = self.models_info[model_type].get(model_name, None)
+ if mn and "target_instrument" in self.models_info[model_type][model_name]:
+ return self.models_info[model_type][model_name]["target_instrument"]
+ return None
+ except (KeyError, TypeError):
+ return None
+
+ def display_models_info(self, filter: str = None):
+ # Collect rows for the table
+ table_data = []
+ headers = [
+ "Тип модели",
+ "ID",
+ "Имя модели",
+ "Стемы",
+ "Целевой инструмент",
+ ]
+
+ for model_type, models in self.models_info.items():
+ for model_name, model_info in models.items():
+ try:
+ stems_list = model_info.get("stems", [])
+ id = model_info.get("id", "н/д")
+ # Apply the filter (case-insensitive)
+ if filter:
+ filter_lower = filter.lower()
+ if not any(filter_lower == s.lower() for s in stems_list):
+ continue
+
+ # Prepare the table row
+ row = [
+ model_type,
+ id,
+ model_name,
+ ", ".join(stems_list) or "н/д",
+ model_info.get("target_instrument", "н/д"),
+ ]
+ table_data.append(row)
+ except (KeyError, TypeError, AttributeError) as e:
+ print(f"Ошибка при обработке модели {model_type}/{model_name}: {e}")
+ continue
+
+ # Print the result
+ if table_data:
+ print(tabulate(table_data, headers=headers, tablefmt="grid"))
+ else:
+ print("Нет моделей, которые содержат указанный стем")
+
+ def download_model(
+ self, model_paths, model_name, model_type, ckpt_url, conf_url
+ ):
+ model_dir = os.path.join(model_paths, model_type)
+ os.makedirs(model_dir, exist_ok=True)
+
+ config_path = None
+ checkpoint_path = None
+
+ if model_type == "mel_band_roformer":
+ config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
+ checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt")
+
+ elif model_type == "vr":
+ config_path = os.path.join(model_dir, f"{model_name}.yaml")
+ checkpoint_path = os.path.join(model_dir, f"{model_name}.pth")
+
+ elif model_type == "mdxnet":
+ config_path = os.path.join(model_dir, f"{model_name}.yaml")
+ checkpoint_path = os.path.join(model_dir, f"{model_name}.onnx")
+
+ elif model_type == "bs_roformer":
+ config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
+ checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt")
+
+ elif model_type == "mdx23c":
+ config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
+ checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt")
+
+ elif model_type == "scnet":
+ config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
+ checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt")
+
+ elif model_type == "bandit":
+ config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
+ checkpoint_path = os.path.join(model_dir, f"{model_name}.chpt")
+
+ elif model_type == "bandit_v2":
+ config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
+ checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt")
+
+ elif model_type == "htdemucs":
+ config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
+ checkpoint_path = os.path.join(model_dir, f"{model_name}.th")
+
+ else:
+ raise ValueError(f"Unsupported model type: {model_type}")
+
+ # Sanity check that both paths were assigned
+ if config_path is None or checkpoint_path is None:
+ raise RuntimeError("Could not determine config/checkpoint paths")
+
+ # Download the checkpoint and config only when they are missing or empty
+ for local_path, url_model in [
+ (checkpoint_path, ckpt_url),
+ (config_path, conf_url),
+ ]:
+ if not os.path.exists(local_path) or os.path.getsize(local_path) == 0:
+ dw_file(url_model, local_path)
+
+ return config_path, checkpoint_path
+
+ def conf_editor(self, config_path, mdx_denoise, vr_aggr, model_type):
+
+ class IndentDumper(yaml.Dumper):
+ def increase_indent(self, flow=False, indentless=False):
+ return super(IndentDumper, self).increase_indent(flow, False)
+
+ def tuple_constructor(loader, node):
+ # Load the sequence of values from the YAML node
+ values = loader.construct_sequence(node)
+ # Return a tuple constructed from the sequence
+ return tuple(values)
+
+ # Register the constructor with PyYAML
+ yaml.SafeLoader.add_constructor(
+ "tag:yaml.org,2002:python/tuple", tuple_constructor
+ )
+
+ def conf_edit(config_path, mdx_denoise, vr_aggr, model_type):
+ with open(config_path, "r") as f:
+ data = yaml.load(f, Loader=yaml.SafeLoader)
+
+ # handle cases where 'use_amp' is missing from the config:
+ if "use_amp" not in data["training"]:
+ data["training"]["use_amp"] = True
+
+ if model_type != "vr":
+ if data["inference"]["num_overlap"] != 2:
+ data["inference"]["num_overlap"] = 2
+
+ if data["inference"]["batch_size"] != 1:
+ data["inference"]["batch_size"] = 1
+
+ if model_type == "mdxnet":
+ data["inference"]["denoise"] = mdx_denoise
+
+ elif model_type == "vr":
+ data["inference"]["aggression"] = vr_aggr
+
+ with open(config_path, "w") as f:
+ yaml.dump(
+ data,
+ f,
+ default_flow_style=False,
+ sort_keys=False,
+ Dumper=IndentDumper,
+ allow_unicode=True,
+ )
+
+ conf_edit(config_path, mdx_denoise, vr_aggr, model_type)
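+
+ # Effect sketch (illustration only): for a non-VR model the edited YAML ends
+ # up containing at least
+ #   training:
+ #     use_amp: true      # inserted when the key was missing
+ #   inference:
+ #     num_overlap: 2
+ #     batch_size: 1
+ # plus inference.denoise for "mdxnet" models and inference.aggression for "vr".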
+
+class VbachModelManager:
+ def __init__(self):
+ self.rmvpe_path = os.path.join(script_dir, "predictors", "rmvpe.pt")
+ self.fcpe_path = os.path.join(script_dir, "predictors", "fcpe.pt")
+ self.hubert_path = os.path.join(script_dir, "embedders", "hubert_base.pt")
+ self.requirements = [["https://huggingface.co/Politrees/RVC_resources/resolve/main/predictors/rmvpe.pt", self.rmvpe_path], ["https://huggingface.co/Politrees/RVC_resources/resolve/main/predictors/fcpe.pt", self.fcpe_path], ["https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/hubert_base.pt", self.hubert_path]]
+ self.voicemodels_dir = os.path.join(script_dir, "vbach_models_cache")
+ os.makedirs(self.voicemodels_dir, exist_ok=True)
+ self.voicemodels_info = os.path.join(self.voicemodels_dir, "vbach_models.json")
+ self.voicemodels: Dict[str, Dict[str, str]] = {}
+ self.download_requirements()
+ self.check_and_load()
+
+ def write_voicemodels_info(self):
+ with open(self.voicemodels_info, "w") as f:
+ json.dump(self.voicemodels, f, indent=4)
+
+ def load_voicemodels_info(self):
+ with open(self.voicemodels_info, "r") as f:
+ return json.load(f)
+
+ def add_voice_model(
+ self,
+ name,
+ pth_path,
+ index_path,
+ ):
+ self.voicemodels[name] = {"pth": pth_path, "index": index_path}
+ self.write_voicemodels_info()
+
+ def del_voice_model(
+ self, name
+ ):
+ if name in self.parse_voice_models():
+ pth = self.voicemodels[name].get("pth", None)
+ index = self.voicemodels[name].get("index", None)
+ if index:
+ os.remove(index)
+ if pth:
+ os.remove(pth)
+ del self.voicemodels[name]
+ self.write_voicemodels_info()
+ return f"Модель {name} удалена"
+ else:
+ return f"Модель не была удалена, как так её не существует"
+
+ def parse_voice_models(self):
+ list_models = list(self.voicemodels.keys())
+ return list_models
+
+ def parse_pth_and_index(self, name):
+ pth = self.voicemodels[name].get("pth", None)
+ index = self.voicemodels[name].get("index", None)
+ return pth, index
+
+ def check_and_load(self):
+ if os.path.exists(self.voicemodels_info):
+ self.voicemodels = self.load_voicemodels_info()
+ else:
+ self.write_voicemodels_info()
+
+ def clear_voicemodels_info(self):
+ self.voicemodels: Dict[str, Dict[str, str]] = {}
+ self.write_voicemodels_info()
+
+ def download_file(self, url_model, local_path):
+ dir_name = os.path.dirname(local_path)
+ if dir_name != "":
+ os.makedirs(dir_name, exist_ok=True)
+ class TqdmUpTo(tqdm):
+ def update_to(self, b=1, bsize=1, tsize=None):
+ if tsize is not None:
+ self.total = tsize
+ self.update(b * bsize - self.n)
+
+ with TqdmUpTo(
+ unit="B",
+ unit_scale=True,
+ unit_divisor=1024,
+ miniters=1,
+ desc=os.path.basename(local_path),
+ ) as t:
+ urllib.request.urlretrieve(
+ url_model, local_path, reporthook=t.update_to
+ )
+
+ def download_requirements(self):
+ for url, file in self.requirements:
+ if not os.path.exists(file):
+ self.download_file(url_model=url, local_path=file)
+
+ def download_voice_model_file(self, url, zip_name):
+ try:
+ if "drive.google.com" in url:
+ self.download_from_google_drive(url, zip_name)
+ elif "pixeldrain.com" in url:
+ self.download_from_pixeldrain(url, zip_name)
+ elif "disk.yandex.ru" in url or "yadi.sk" in url:
+ self.download_from_yandex(url, zip_name)
+ else:
+ self.download_file(url, zip_name)
+ except Exception as e:
+ print(e)
+
+ def download_from_google_drive(self, url, zip_name):
+ file_id = (
+ url.split("file/d/")[1].split("/")[0]
+ if "file/d/" in url
+ else url.split("id=")[1].split("&")[0]
+ )
+ gdown.download(id=file_id, output=str(zip_name), quiet=False)
+
+ def download_from_pixeldrain(self, url, zip_name):
+ file_id = url.split("pixeldrain.com/u/")[1]
+ response = requests.get(f"https://pixeldrain.com/api/file/{file_id}")
+ with open(zip_name, "wb") as f:
+ f.write(response.content)
+
+ def download_from_yandex(self, url, zip_name):
+ yandex_public_key = f"download?public_key={url}"
+ yandex_api_url = f"https://cloud-api.yandex.net/v1/disk/public/resources/{yandex_public_key}"
+ response = requests.get(yandex_api_url)
+ if response.status_code == 200:
+ download_link = response.json().get("href")
+ urllib.request.urlretrieve(download_link, zip_name)
+ else:
+ print(response.status_code)
+
+ def extract_zip(self, zip_name, model_name):
+ model_dir = os.path.join(self.voicemodels_dir, f"{model_name}_{generate_secure_random(17)}")
+ os.makedirs(model_dir, exist_ok=True)
+ try:
+ with zipfile.ZipFile(zip_name, "r") as zip_ref:
+ zip_ref.extractall(model_dir)
+ os.remove(zip_name)
+
+ added_voice_models = []
+
+ index_filepath, model_filepaths = None, []
+ for root, _, files in os.walk(model_dir):
+ for name in files:
+ file_path = os.path.join(root, name)
+ if name.endswith(".index") and os.stat(file_path).st_size > 1024 * 100:
+ index_filepath = file_path
+ if name.endswith(".pth") and os.stat(file_path).st_size > 1024 * 1024 * 20:
+ model_filepaths.append(file_path)
+
+ if len(model_filepaths) == 1:
+ self.add_voice_model(model_name, model_filepaths[0], index_filepath)
+ added_voice_models.append(model_name)
+ else:
+ for i, pth in enumerate(model_filepaths):
+ self.add_voice_model(f"{model_name}_{i + 1}", pth, index_filepath)
+ added_voice_models.append(f"{model_name}_{i + 1}")
+ list_models_str = '\n'.join(added_voice_models)
+ return f"Добавленные модели:\n{list_models_str}"
+ except Exception as e:
+ return f"Произошла ошибка при загрузке модели: {e}"
+
+ def install_model_zip(self, zip, model_name, mode="url"):
+ if model_name in self.parse_voice_models():
+ print("Эта модель уже есть в списке установленных моделей. Она будут перезаписана")
+ if mode == "url":
+ with tempfile.TemporaryDirectory(prefix="vbach_temp_model", ignore_cleanup_errors=True) as tmp:
+ zip_path = os.path.join(tmp, "model.zip")
+ self.download_voice_model_file(zip, zip_path)
+ status = self.extract_zip(zip_path, model_name)
+ if mode == "local":
+ status = self.extract_zip(zip, model_name)
+ return status
+
+ def install_model_files(self, index, pth, model_name, mode="url"):
+ if model_name in self.parse_voice_models():
+ print("Эта модель уже есть в списке установленных моделей. Она будут перезаписана")
+ model_dir = os.path.join(self.voicemodels_dir, f"{model_name}_{generate_secure_random(17)}")
+ os.makedirs(model_dir, exist_ok=True)
+ local_index_path = None
+ local_pth_path = None
+ try:
+ if mode == "url":
+ if index:
+ local_index_path = os.path.join(model_dir, "model.index")
+ self.download_voice_model_file(index, local_index_path)
+ if pth:
+ local_pth_path = os.path.join(model_dir, "model.pth")
+ self.download_voice_model_file(pth, local_pth_path)
+
+ if mode == "local":
+ if index:
+ if os.path.exists(index):
+ local_index_path = os.path.join(model_dir, os.path.basename(index))
+ shutil.copy(index, local_index_path)
+ if pth:
+ if os.path.exists(pth):
+ local_pth_path = os.path.join(model_dir, os.path.basename(pth))
+ shutil.copy(pth, local_pth_path)
+
+ self.add_voice_model(model_name, local_pth_path, local_index_path)
+ return f"Модель {model_name} добавлена"
+ except Exception as e:
+ return f"Произошла ошибка при загрузке модели: {e}"
+
+
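+# Example invocations (illustrative; run from the package directory, model and
+# type names come from models.json):
+#   python model_manager.py mvsepless --model_type mel_band_roformer --model_name mbr_vocals_kim
+#   python model_manager.py vbach install_local --model_name my_voice --pth /path/to/model.pth --index /path/to/model.index
+#   python model_manager.py vbach install_url_zip --model_name my_voice --url https://example.com/voice_model.zip
+#   python model_manager.py vbach list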
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Менеджер моделей")
+ subparsers = parser.add_subparsers(title="subcommands", dest="command", required=True)
+
+ # Mvsepless subcommand
+ mvsepless_parser = subparsers.add_parser("mvsepless", help="Скачивание моделей в MVSepLess")
+ mvsepless_parser.add_argument("--model_type", required=True, help="Тип модели")
+ mvsepless_parser.add_argument("--model_name", required=True, help="Имя модели")
+
+ # Vbach subcommand
+ vbach_parser = subparsers.add_parser("vbach", help="Установка голосовых моделей в Vbach")
+ vbach_subparsers = vbach_parser.add_subparsers(title="vbach_commands", dest="vbach_command", required=True)
+
+ # Vbach install_local
+ install_local_parser = vbach_subparsers.add_parser("install_local", help="Установка голосовой модели по локальным файлам")
+ install_local_parser.add_argument("--model_name", required=True, help="Имя голосовой модели")
+ install_local_parser.add_argument("--pth", required=True, help="Путь к *.pth файлу")
+ install_local_parser.add_argument("--index", required=False, help="Путь к *.index файлу")
+
+ # Vbach install_url_zip
+ install_url_zip_parser = vbach_subparsers.add_parser("install_url_zip", help="Установка голосовой модели по URL (архив с файлами)")
+ install_url_zip_parser.add_argument("--model_name", required=True, help="Имя голосовой модели")
+ install_url_zip_parser.add_argument("--url", required=True, help="URL *.zip файла")
+
+ # Vbach install_url_files
+ install_url_files_parser = vbach_subparsers.add_parser("install_url_files", help="Установка голосовой модели по URL (отдельные файлы)")
+ install_url_files_parser.add_argument("--model_name", required=True, help="Имя голосовой модели")
+ install_url_files_parser.add_argument("--pth_url", required=True, help="URL *.pth файла")
+ install_url_files_parser.add_argument("--index_url", required=False, help="URL *.index файла")
+
+ # Vbach list
+ list_parser = vbach_subparsers.add_parser("list", help="List installed voice models")
+
+ args = parser.parse_args()
+
+ if args.command == "mvsepless":
+
+ _model_manager = MvseplessModelManager()
+ info = _model_manager.models_info[args.model_type].get(args.model_name, None)
+ if not info:
+ raise ValueError(f"Модель {args.model_name} не найдена для типа {args.model_type}")
+ conf, ckpt = _model_manager.download_model(
+ _model_manager.models_cache_dir,
+ args.model_name,
+ args.model_type,
+ info["checkpoint_url"],
+ info["config_url"],
+ )
+
+ elif args.command == "vbach":
+ model_manager = VbachModelManager()
+
+ if args.vbach_command == "install_local":
+ status = model_manager.install_model_files(
+ args.index, args.pth, args.model_name, mode="local"
+ )
+ print(status)
+
+ elif args.vbach_command == "install_url_zip":
+ status = model_manager.install_model_zip(
+ args.url, args.model_name, mode="url"
+ )
+ print(status)
+
+ elif args.vbach_command == "install_url_files":
+ status = model_manager.install_model_files(
+ args.index_url, args.pth_url, args.model_name, mode="url"
+ )
+ print(status)
+
+ elif args.vbach_command == "list":
+ models = model_manager.parse_voice_models()
+ if models:
+ print("Установленные модели:")
+ for model in models:
+ print(f" - {model}")
+ else:
+ print("Нет установленных моделей")
+
+
+
+
+
+
diff --git a/mvsepless/models.json b/mvsepless/models.json
new file mode 100644
index 0000000000000000000000000000000000000000..c28d8e7b9db1735ab51795fd67f74e8b91336c17
--- /dev/null
+++ b/mvsepless/models.json
@@ -0,0 +1,2859 @@
+{
+ "mel_band_roformer": {
+ "mbr_vocals_kim": {
+ "category": "Вокал",
+ "id": 1000,
+ "full_name": "Mel-Band Roformer Vocals by Kimberley Jensen",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/KimberleyJSN/melbandroformer/resolve/main/MelBandRoformer.ckpt?download=true",
+ "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/KimberleyJensen/config_vocals_mel_band_roformer_kj.yaml"
+ },
+ "mbr_wsa": {
+ "category": "Вокал",
+ "id": 1910,
+ "full_name": "Windowed Sink Attention Mel-Band Roformer Vocals by Smule Labs",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/smulelabs/windowed-roformer/resolve/main/mbr-win10-sink8.ckpt?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/config_windowed_roformer_by_smulelabs_wsa.yaml?download=true"
+ },
+ "mbr_instvoc_duality1_unwa": {
+ "category": "Инструментал и вокал",
+ "id": 1010,
+ "full_name": "Mel-Band Roformer InstVoc Duality v1 by Unwa",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-InstVoc-Duality/resolve/main/melband_roformer_instvoc_duality_v1.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-InstVoc-Duality/resolve/main/config_melbandroformer_instvoc_duality.yaml?download=true"
+ },
+ "mbr_instvoc_duality2_unwa": {
+ "category": "Инструментал и вокал",
+ "id": 1011,
+ "full_name": "Mel-Band Roformer InstVoc Duality v2 by Unwa",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-InstVoc-Duality/resolve/main/melband_roformer_instvox_duality_v2.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-InstVoc-Duality/resolve/main/config_melbandroformer_instvoc_duality.yaml?download=true"
+ },
+ "mbr_kimft1_unwa": {
+ "category": "Вокал",
+ "id": 1020,
+ "full_name": "Mel-Band Roformer Kim FT v1 by Unwa",
+ "stems": [
+ "Vocals",
+ "other"
+ ],
+ "target_instrument": "Vocals",
+ "checkpoint_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/kimmel_unwa_ft.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/config_kimmel_unwa_ft.yaml?download=true"
+ },
+ "mbr_kimft2_unwa": {
+ "category": "Вокал",
+ "id": 1021,
+ "full_name": "Mel-Band Roformer Kim FT v2 by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/kimmel_unwa_ft2.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/config_kimmel_unwa_ft.yaml?download=true"
+ },
+ "mbr_kimft2b_unwa": {
+ "category": "Вокал",
+ "id": 1022,
+ "full_name": "Mel-Band Roformer Kim FT v2 Bleedless by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/kimmel_unwa_ft2_bleedless.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/config_kimmel_unwa_ft.yaml?download=true"
+ },
+ "mbr_kimft3_prev_unwa": {
+ "category": "Вокал",
+ "id": 1023,
+ "full_name": "Mel-Band Roformer Kim FT v3 preview by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/kimmel_unwa_ft3_prev.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/config_kimmel_unwa_ft.yaml?download=true"
+ },
+ "mbr_bigbeta1_unwa": {
+ "category": "Вокал",
+ "id": 1030,
+ "full_name": "Mel-Band Roformer Big Beta v1 by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/melband_roformer_big_beta1.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/config_melbandroformer_big.yaml?download=true"
+ },
+ "mbr_bigbeta2_unwa": {
+ "category": "Вокал",
+ "id": 1031,
+ "full_name": "Mel-Band Roformer Big Beta v2 by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/melband_roformer_big_beta2.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/config_melbandroformer_big.yaml?download=true"
+ },
+ "mbr_bigbeta3_unwa": {
+ "category": "Вокал",
+ "id": 1032,
+ "full_name": "Mel-Band Roformer Big Beta v3 by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/melband_roformer_big_beta3.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/config_melbandroformer_big.yaml?download=true"
+ },
+ "mbr_bigbeta4_unwa": {
+ "category": "Вокал",
+ "id": 1033,
+ "full_name": "Mel-Band Roformer Big Beta v4 by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/melband_roformer_big_beta4.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/config_melbandroformer_big_beta4.yaml?download=true"
+ },
+ "mbr_bigbeta5e_unwa": {
+ "category": "Вокал",
+ "id": 1034,
+ "full_name": "Mel-Band Roformer Vocals Big Beta v5e by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta5e.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta5e.yaml?download=true"
+ },
+ "mbr_bigbeta6_unwa": {
+ "category": "Вокал",
+ "id": 1035,
+ "full_name": "Mel-Band Roformer Big Beta v6 by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta6.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta6.yaml?download=true"
+ },
+ "mbr_bigbeta6x_unwa": {
+ "category": "Вокал",
+ "id": 1036,
+ "full_name": "Mel-Band Roformer Big Beta v6x by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta6x.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta6x.yaml?download=true"
+ },
+ "mbr_inst1_unwa": {
+ "category": "Инструментал",
+ "id": 1040,
+ "full_name": "Mel-Band Roformer Instrumental v1 by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "other",
+ "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/melband_roformer_inst_v1.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/config_melbandroformer_inst.yaml?download=true"
+ },
+ "mbr_inst1+_unwa": {
+ "category": "Инструментал",
+ "id": 1041,
+ "full_name": "Mel-Band Roformer Instrumental v1+ by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "other",
+ "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/inst_v1_plus_test.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/config_melbandroformer_inst.yaml?download=true"
+ },
+ "mbr_inst1e_unwa": {
+ "category": "Инструментал",
+ "id": 1042,
+ "full_name": "Mel-Band Roformer Instrumental v1e by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "other",
+ "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/inst_v1e.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/config_melbandroformer_inst.yaml?download=true"
+ },
+ "mbr_inst1e+_unwa": {
+ "category": "Инструментал",
+ "id": 1043,
+ "full_name": "Mel-Band Roformer Instrumental v1e Plus by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "other",
+ "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/inst_v1e_plus.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/config_melbandroformer_inst.yaml?download=true"
+ },
+ "mbr_inst2_unwa": {
+ "category": "Инструментал",
+ "id": 1044,
+ "full_name": "Mel-Band Roformer Instrumental v2 by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "other",
+ "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/melband_roformer_inst_v2.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/config_melbandroformer_inst_v2.yaml?download=true"
+ },
+ "mbr_small_unwa": {
+ "category": "Вокал",
+ "id": 1050,
+ "full_name": "Mel-Band Roformer Small v1 by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-small/resolve/main/melband_roformer_small_v1.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-small/resolve/main/config_melbandroformer_small.yaml?download=true"
+ },
+ "mbr_bleed_supressor_unwa_97chris": {
+ "category": "Шум",
+ "id": 1051,
+ "full_name": "Mel-Band Roformer Bleed Suppressor v1 by Unwa / 97chris",
+ "stems": [
+ "Instrumental",
+ "Bleed"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/jarredou/bleed_suppressor_melband_rofo_by_unwa_97chris/resolve/main/bleed_suppressor_v1.ckpt?download=true",
+ "config_url": "https://huggingface.co/jarredou/bleed_suppressor_melband_rofo_by_unwa_97chris/resolve/main/config_bleed_suppressor_v1.yaml?download=true"
+ },
+ "mbr_inst_becruily": {
+ "category": "Инструментал",
+ "id": 1060,
+ "full_name": "Mel-Band Roformer Instrumental by Becruily",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/becruily/mel-band-roformer-instrumental/resolve/main/mel_band_roformer_instrumental_becruily.ckpt?download=true",
+ "config_url": "https://huggingface.co/becruily/mel-band-roformer-instrumental/resolve/main/config_instrumental_becruily.yaml?download=true"
+ },
+ "mbr_guitar_becruily": {
+ "category": "Гитара",
+ "id": 1061,
+ "full_name": "Mel-Band Roformer Instrumental by Becruily",
+ "stems": [
+ "Guitar",
+ "Other"
+ ],
+ "target_instrument": "Guitar",
+ "checkpoint_url": "https://huggingface.co/becruily/mel-band-roformer-guitar/resolve/main/becruily_guitar.ckpt?download=true",
+ "config_url": "https://huggingface.co/becruily/mel-band-roformer-guitar/resolve/main/config_guitar_becruily.yaml?download=true"
+ },
+ "mbr_karaoke_becruily": {
+ "category": "Караоке",
+ "id": 1062,
+ "full_name": "Mel-Band Roformer Karaoke by Becruily",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/becruily/mel-band-roformer-karaoke/resolve/main/mel_band_roformer_karaoke_becruily.ckpt?download=true",
+ "config_url": "https://huggingface.co/becruily/mel-band-roformer-karaoke/resolve/main/config_karaoke_becruily.yaml?download=true"
+ },
+ "mbr_vocals_becruily": {
+ "category": "Вокал",
+ "id": 1063,
+ "full_name": "Mel-Band Roformer Vocals by Becruily",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/becruily/mel-band-roformer-vocals/resolve/main/mel_band_roformer_vocals_becruily.ckpt?download=true",
+ "config_url": "https://huggingface.co/becruily/mel-band-roformer-vocals/resolve/main/config_vocals_becruily.yaml?download=true"
+ },
+ "mbr_syhft1": {
+ "category": "Вокал",
+ "id": 1070,
+ "full_name": "Mel-Band Roformer SYHFT v1 by SYH99999",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFT/resolve/main/MelBandRoformerSYHFT.ckpt?download=true",
+ "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFT/resolve/main/config_vocals_mel_band_roformer_ft.yaml?download=true"
+ },
+ "mbr_syhft2": {
+ "category": "Вокал",
+ "id": 1071,
+ "full_name": "Mel-Band Roformer SYHFT v2 by SYH99999",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTV2/resolve/main/MelBandRoformerSYHFTV2.ckpt?download=true",
+ "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFT/resolve/main/config_vocals_mel_band_roformer_ft.yaml?download=true"
+ },
+ "mbr_syhft2.5": {
+ "category": "Вокал",
+ "id": 1072,
+ "full_name": "Mel-Band Roformer SYHFT v2.5 by SYH99999",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTV2.5/resolve/main/MelBandRoformerSYHFTV2.5.ckpt/MelBandRoformerSYHFTV2.5.ckpt?download=true",
+ "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFT/resolve/main/config_vocals_mel_band_roformer_ft.yaml?download=true"
+ },
+ "mbr_syhft3": {
+ "category": "Вокал",
+ "id": 1073,
+ "full_name": "Mel-Band Roformer SYHFT v3 by SYH99999",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTV3Epsilon/resolve/main/MelBandRoformerSYHFTV3Epsilon.ckpt?download=true",
+ "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFT/resolve/main/config_vocals_mel_band_roformer_ft.yaml?download=true"
+ },
+ "mbr_bigsyhft1fast": {
+ "category": "Вокал",
+ "id": 1074,
+ "full_name": "Mel-Band Roformer Big SYHFT v1 Fast by SYH99999",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerBigSYHFTV1Fast/resolve/main/MelBandRoformerBigSYHFTV1.ckpt?download=true",
+ "config_url": "https://huggingface.co/SYH99999/MelBandRoformerBigSYHFTV1Fast/resolve/main/config.yaml?download=true"
+ },
+ "mbr_syhftbeta1": {
+ "category": "Вокал",
+ "id": 1075,
+ "full_name": "Mel-Band Roformer Merged Beta v1 by SYH99999",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerMergedSYHFTBeta1/resolve/main/merge_syhft.ckpt?download=true",
+ "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFT/resolve/main/config_vocals_mel_band_roformer_ft.yaml?download=true"
+ },
+ "mbr_syhftB1_1": {
+ "category": "Вокал",
+ "id": 1076,
+ "full_name": "Mel-Band Roformer SYHFT B1 1 by SYH99999",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/model.ckpt?download=true",
+ "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/config.yaml?download=true"
+ },
+ "mbr_syhftB1_2": {
+ "category": "Вокал",
+ "id": 1077,
+ "full_name": "Mel-Band Roformer SYHFT B1 2 by SYH99999",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/model2.ckpt?download=true",
+ "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/config.yaml?download=true"
+ },
+ "mbr_syhftB1_3": {
+ "category": "Вокал",
+ "id": 1078,
+ "full_name": "Mel-Band Roformer SYHFT B1 3 by SYH99999",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/model3.ckpt?download=true",
+ "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/config.yaml?download=true"
+ },
+ "mbr_syhft_4stem": {
+ "category": "4 стема",
+ "id": 1079,
+ "full_name": "Mel-Band Roformer 4 Stems FT Large v1 by SYH99999",
+ "stems": [
+ "vocals",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/MelBandRoformer4StemFTLarge.ckpt?download=true",
+ "config_url": "https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/config.yaml?download=true"
+ },
+ "mbr_syhft_4stem2": {
+ "category": "4 стема",
+ "id": 1080,
+ "full_name": "Mel-Band Roformer 4 Stems FT Large v2 by SYH99999",
+ "stems": [
+ "vocals",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/ver2.ckpt?download=true",
+ "config_url": "https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/config.yaml?download=true"
+ },
+ "mbr_inst_1652_essid": {
+ "category": "Инструментал",
+ "id": 1085,
+ "full_name": "Mel-Band Roformer Instrumental by Essid (sdr 16.52)",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/Essid/Essid-MelBandRoformer/resolve/3960860f7895c87a12707ca6b378df2b3c68e2c0/model_mel_band_roformer_ep_17_sdr_16.5244.ckpt?download=true",
+ "config_url": "https://huggingface.co/Essid/Essid-MelBandRoformer/resolve/4768859bd59bc699d33f4567e82082993dde7eb9/config_vocals_mel_band_roformer_essid.yaml?download=true"
+ },
+ "mbr_inst_1681_essid": {
+ "category": "Инструментал",
+ "id": 1086,
+ "full_name": "Mel-Band Roformer Instrumental by Essid (sdr 16.81)",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/Essid/Essid-MelBandRoformer/resolve/main/essid_mel_inst_old.ckpt?download=true",
+ "config_url": "https://huggingface.co/Essid/Essid-MelBandRoformer/resolve/4768859bd59bc699d33f4567e82082993dde7eb9/config_vocals_mel_band_roformer_essid.yaml?download=true"
+ },
+ "mbr_instfv1_gabox": {
+ "category": "Инструментал",
+ "id": 1100,
+ "full_name": "Mel-Band Roformer Instrumental Fv1 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxFv1.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_instfv2_gabox": {
+ "category": "Инструментал",
+ "id": 1101,
+ "full_name": "Mel-Band Roformer Instrumental Fv2 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxFv2.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_instfv3_gabox": {
+ "category": "Инструментал",
+ "id": 1102,
+ "full_name": "Mel-Band Roformer Instrumental Fv3 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxFv3.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_instfv4_gabox": {
+ "category": "Инструментал",
+ "id": 1117,
+ "full_name": "Mel-Band Roformer Instrumental Fv4 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_Fv4.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_instfv4n_gabox": {
+ "category": "Инструментал",
+ "id": 1103,
+ "full_name": "Mel-Band Roformer Instrumental Fv4 Noise by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_Fv4Noise.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_instfv5_gabox": {
+ "category": "Инструментал",
+ "id": 1104,
+ "full_name": "Mel-Band Roformer Instrumental Fv5 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV5.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_instfv5n_gabox": {
+ "category": "Инструментал",
+ "id": 1105,
+ "full_name": "Mel-Band Roformer Instrumental Fv5 Noise by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV5N.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_instfv6_gabox": {
+ "category": "Инструментал",
+ "id": 1106,
+ "full_name": "Mel-Band Roformer Instrumental Fv6 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV6.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_instfv6n_gabox": {
+ "category": "Инструментал",
+ "id": 1107,
+ "full_name": "Mel-Band Roformer Instrumental Fv6 Noise by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV6N.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_instfv7_gabox": {
+ "category": "Инструментал",
+ "id": 1108,
+ "full_name": "Mel-Band Roformer Instrumental Fv7 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/Inst_GaboxV7.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_instfv7n_gabox": {
+ "category": "Инструментал",
+ "id": 1109,
+ "full_name": "Mel-Band Roformer Instrumental Fv7 Noise by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV7N.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_instfv7+_gabox": {
+ "category": "Инструментал",
+ "id": 1110,
+ "full_name": "Mel-Band Roformer Instrumental Fv7+ by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/5fba9605d4b6bc1a31c04c50d08d757c5107d23f/melbandroformers/experimental/instv7plus.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_instfv7z_gabox": {
+ "category": "Инструментал",
+ "id": 1111,
+ "full_name": "Mel-Band Roformer Instrumental Fv7z by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/Inst_GaboxFv7z.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_instfv8_gabox": {
+ "category": "Инструментал",
+ "id": 1112,
+ "full_name": "Mel-Band Roformer Instrumental Fv8 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/Inst_GaboxFv8.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_instfv8b_gabox": {
+ "category": "Инструментал",
+ "id": 1113,
+ "full_name": "Mel-Band Roformer Instrumental Fv8b by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/experimental/Inst_FV8b.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_instfv9_gabox": {
+ "category": "Инструментал",
+ "id": 1114,
+ "full_name": "Mel-Band Roformer Instrumental Fv9 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/experimental/Inst_Fv9.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_instfv10_gabox": {
+ "category": "Инструментал",
+ "id": 1115,
+ "full_name": "Mel-Band Roformer Instrumental Fv10 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/experimental/INSTV10.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_instfvx_gabox": {
+ "category": "Инструментал",
+ "id": 1116,
+ "full_name": "Mel-Band Roformer Instrumental FvX by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/Inst_GaboxFVX.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_instbv1_gabox": {
+ "category": "Инструментал",
+ "id": 1120,
+ "full_name": "Mel-Band Roformer Instrumental Bv1 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxBv1.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_instbv2_gabox": {
+ "category": "Инструментал",
+ "id": 1121,
+ "full_name": "Mel-Band Roformer Instrumental Bv2 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxBv2.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_instbv3_gabox": {
+ "category": "Инструментал",
+ "id": 1122,
+ "full_name": "Mel-Band Roformer Instrumental Bv3 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxBv3.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_vocalsfv1_gabox": {
+ "category": "Вокал",
+ "id": 1130,
+ "full_name": "Mel-Band Roformer Vocals Fv1 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Vocals",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gaboxFv1.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml?download=true"
+ },
+ "mbr_vocalsfv2_gabox": {
+ "category": "Вокал",
+ "id": 1131,
+ "full_name": "Mel-Band Roformer Vocals Fv2 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Vocals",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gaboxFv2.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml?download=true"
+ },
+ "mbr_vocalsfv3_gabox": {
+ "category": "Вокал",
+ "id": 1132,
+ "full_name": "Mel-Band Roformer Vocals Fv3 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_Fv3.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml?download=true"
+ },
+ "mbr_vocalsfv4_gabox": {
+ "category": "Вокал",
+ "id": 1133,
+ "full_name": "Mel-Band Roformer Vocals Fv4 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Vocals",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_fv4.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml?download=true"
+ },
+ "mbr_vocalsfv5_gabox": {
+ "category": "Вокал",
+ "id": 1134,
+ "full_name": "Mel-Band Roformer Vocals Fv5 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Vocals",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_fv5.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml?download=true"
+ },
+ "mbr_vocalsfv6_gabox": {
+ "category": "Вокал",
+ "id": 1135,
+ "full_name": "Mel-Band Roformer Vocals Fv6 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Vocals",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_fv6.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml?download=true"
+ },
+ "mbr_karaoke25022025_gabox": {
+ "category": "Караоке",
+ "id": 1140,
+ "full_name": "Mel-Band Roformer Karaoke 25-02-2025 by GaboxR67",
+ "stems": [
+ "karaoke",
+ "other"
+ ],
+ "target_instrument": "karaoke",
+ "checkpoint_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/gabox_karaoke_25_02_2025.ckpt?download=true",
+ "config_url": "https://huggingface.co/jarredou/aufr33-viperx-karaoke-melroformer-model/resolve/main/config_mel_band_roformer_karaoke.yaml?download=true"
+ },
+ "mbr_karaoke28022025_gabox": {
+ "category": "Караоке",
+ "id": 1141,
+ "full_name": "Mel-Band Roformer Karaoke 28-02-2025 by GaboxR67",
+ "stems": [
+ "karaoke",
+ "other"
+ ],
+ "target_instrument": "karaoke",
+ "checkpoint_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/gabox_karaoke_28_02_2025.ckpt?download=true",
+ "config_url": "https://huggingface.co/jarredou/aufr33-viperx-karaoke-melroformer-model/resolve/main/config_mel_band_roformer_karaoke.yaml?download=true"
+ },
+ "mbr_karaoke1_gabox": {
+ "category": "Караоке",
+ "id": 1142,
+ "full_name": "Mel-Band Roformer Karaoke v1 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Vocals",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/27e73ca2beec0ab7daa46e366159753a166612e1/melbandroformers/karaoke/Karaoke_GaboxV1.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/27e73ca2beec0ab7daa46e366159753a166612e1/melbandroformers/karaoke/karaokegabox_1750911344.yaml?download=true"
+ },
+ "mbr_karaoke2_gabox": {
+ "category": "Караоке",
+ "id": 1143,
+ "full_name": "Mel-Band Roformer Karaoke v2 by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Vocals",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/experimental/Karaoke_GaboxV2.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/27e73ca2beec0ab7daa46e366159753a166612e1/melbandroformers/karaoke/karaokegabox_1750911344.yaml?download=true"
+ },
+ "mbr_leadvoc_dereverb_gabox": {
+ "category": "Реверб",
+ "id": 1144,
+ "full_name": "Mel-Band Roformer Lead Vocals DeReverb by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Vocals",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/experimental/Lead_VocalDereverb.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/27e73ca2beec0ab7daa46e366159753a166612e1/melbandroformers/karaoke/karaokegabox_1750911344.yaml?download=true"
+ },
+ "mbr_denoise_debleed_gabox": {
+ "category": "Шум",
+ "id": 1150,
+ "full_name": "Mel-Band Roformer Denoise DeBleed by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/denoisedebleed.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true"
+ },
+ "mbr_karaoke_fusion_gonzaluigi": {
+ "category": "Караоке",
+ "id": 1160,
+ "full_name": "Mel-Band Roformer Karaoke Fusion by Gonzaluigi",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Vocals",
+ "checkpoint_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Karaoke-Fusion/resolve/main/mel_band_karaoke_fusion_standard.ckpt?download=true",
+ "config_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Karaoke-Fusion/resolve/main/melband_karaokefusion_gonza.yaml?download=true"
+ },
+ "mbr_karaoke_fusion_aggr_gonzaluigi": {
+ "category": "Караоке",
+ "id": 1161,
+ "full_name": "Mel-Band Roformer Karaoke Fusion Aggressive by Gonzaluigi",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Vocals",
+ "checkpoint_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Karaoke-Fusion/resolve/main/mel_band_karaoke_fusion_aggressive.ckpt?download=true",
+ "config_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Karaoke-Fusion/resolve/main/melband_karaokefusion_gonza.yaml?download=true"
+ },
+ "mbr_bve_gonzaluigi": {
+ "category": "Караоке",
+ "id": 1162,
+ "full_name": "Mel-Band Roformer BVE by Gonzaluigi",
+ "stems": [
+ "Lead",
+ "Back"
+ ],
+ "target_instrument": "Lead",
+ "checkpoint_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Roformer-BVE-Gonzaluigi/resolve/main/mel_band_roformer_bve_gonza.ckpt?download=true",
+ "config_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Roformer-BVE-Gonzaluigi/resolve/main/config_bve_gonza.yaml?download=true"
+ },
+ "mbr_karaoke_fusion2_aggr_gonzaluigi": {
+ "category": "Караоке",
+ "id": 1163,
+ "full_name": "Mel-Band Roformer Karaoke Fusion Aggressive by Gonzaluigi",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Vocals",
+ "checkpoint_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Karaoke-Fusion/resolve/main/mel_band_karaoke_fusion_agg_v2.ckpt?download=true",
+ "config_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Karaoke-Fusion/resolve/main/karaoke_fusion_v2_config.yaml?download=true"
+ },
+ "mbr_karaoke_fusion_total_aggr_gonzaluigi": {
+ "category": "Караоке",
+ "id": 1164,
+ "full_name": "Mel-Band Roformer Karaoke Fusion Total by Gonzaluigi",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Vocals",
+ "checkpoint_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Karaoke-Fusion/resolve/main/mel_band_karaoke_fusion_total.ckpt?download=true",
+ "config_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Karaoke-Fusion/resolve/main/karaoke_fusion_total_config.yaml?download=true"
+ },
+ "mbr_dereverb_anvuew": {
+ "category": "Реверб",
+ "id": 1170,
+ "full_name": "Mel-Band Roformer DeReverb by Anvuew",
+ "stems": [
+ "reverb",
+ "noreverb"
+ ],
+ "target_instrument": "noreverb",
+ "checkpoint_url": "https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt?download=true",
+ "config_url": "https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew.yaml?download=true"
+ },
+ "mbr_dereverb_less_aggr_anvuew": {
+ "category": "Реверб",
+ "id": 1171,
+ "full_name": "Mel-Band Roformer DeReverb Less Aggressive by Anvuew",
+ "stems": [
+ "reverb",
+ "noreverb"
+ ],
+ "target_instrument": "noreverb",
+ "checkpoint_url": "https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_less_aggressive_anvuew_sdr_18.8050.ckpt?download=true",
+ "config_url": "https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew.yaml?download=true"
+ },
+ "mbr_dereverb_mono_anvuew": {
+ "category": "Реверб",
+ "id": 1172,
+ "full_name": "Mel-Band Roformer DeReverb Mono by Anvuew",
+ "stems": [
+ "reverb",
+ "noreverb"
+ ],
+ "target_instrument": "noreverb",
+ "checkpoint_url": "https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_mono_anvuew_sdr_20.4029.ckpt?download=true",
+ "config_url": "https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew.yaml?download=true"
+ },
+ "mbr_aspiration_sucial": {
+ "category": "Дыхание",
+ "id": 1180,
+ "full_name": "Mel-Band Roformer Aspiration by Sucial",
+ "stems": [
+ "aspiration",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Sucial/Aspiration_Mel_Band_Roformer/resolve/main/aspiration_mel_band_roformer_less_aggr_sdr_18.1201.ckpt?download=true",
+ "config_url": "https://huggingface.co/Sucial/Aspiration_Mel_Band_Roformer/resolve/main/config_aspiration_mel_band_roformer.yaml?download=true"
+ },
+ "mbr_derverb_echo1_sucial": {
+ "category": "Реверб и эхо",
+ "id": 1181,
+ "full_name": "Mel-Band Roformer DeReverb-Echo by Sucial",
+ "stems": [
+ "dry",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/dereverb-echo_mel_band_roformer_sdr_10.0169.ckpt?download=true",
+ "config_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/config_dereverb-echo_mel_band_roformer.yaml?download=true"
+ },
+ "mbr_debigreverb_sucial": {
+ "category": "Реверб",
+ "id": 1182,
+ "full_name": "Mel-Band Roformer DeBigReverb by Sucial",
+ "stems": [
+ "dry",
+ "other"
+ ],
+ "target_instrument": "dry",
+ "checkpoint_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/de_big_reverb_mbr_ep_362.ckpt?download=true",
+ "config_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/config_dereverb_echo_mbr_v2.yaml?download=true"
+ },
+ "mbr_desuperbigreverb_sucial": {
+ "category": "Реверб",
+ "id": 1183,
+ "full_name": "Mel-Band Roformer Super Big DeReverb by Sucial",
+ "stems": [
+ "dry",
+ "other"
+ ],
+ "target_instrument": "dry",
+ "checkpoint_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/de_super_big_reverb_mbr_ep_346.ckpt?download=true",
+ "config_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/config_dereverb_echo_mbr_v2.yaml?download=true"
+ },
+ "mbr_dereverb-echo_fused_sucial": {
+ "category": "Реверб и эхо",
+ "id": 1184,
+ "full_name": "Mel-Band Roformer DeReverb-Echo Fused by Sucial",
+ "stems": [
+ "dry",
+ "other"
+ ],
+ "target_instrument": "dry",
+ "checkpoint_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/dereverb_echo_mbr_fused_0.5_v2_0.25_big_0.25_super.ckpt?download=true",
+ "config_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/config_dereverb_echo_mbr_v2.yaml?download=true"
+ },
+ "mbr_dereverb-echo2_sucial": {
+ "category": "Реверб и эхо",
+ "id": 1185,
+ "full_name": "Mel-Band Roformer DeReverb-Echo v2 by Sucial",
+ "stems": [
+ "dry",
+ "other"
+ ],
+ "target_instrument": "dry",
+ "checkpoint_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/dereverb_echo_mbr_v2_sdr_dry_13.4843.ckpt?download=true",
+ "config_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/config_dereverb_echo_mbr_v2.yaml?download=true"
+ },
+ "mbr_karaoke_aufr33_viperx": {
+ "category": "Караоке",
+ "id": 1190,
+ "full_name": "Mel-Band Roformer Karaoke by Aufr33 & ViperX",
+ "stems": [
+ "karaoke",
+ "other"
+ ],
+ "target_instrument": "karaoke",
+ "checkpoint_url": "https://huggingface.co/jarredou/aufr33-viperx-karaoke-melroformer-model/resolve/main/mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956.ckpt?download=true",
+ "config_url": "https://huggingface.co/jarredou/aufr33-viperx-karaoke-melroformer-model/resolve/main/config_mel_band_roformer_karaoke.yaml?download=true"
+ },
+ "mbr_denoise_aufr33": {
+ "category": "Шум",
+ "id": 1191,
+ "full_name": "Mel-Band Roformer DeNoise by Aufr33",
+ "stems": [
+ "dry",
+ "other"
+ ],
+ "target_instrument": "dry",
+ "checkpoint_url": "https://huggingface.co/jarredou/aufr33_MelBand_Denoise/resolve/main/denoise_mel_band_roformer_aufr33_aggr_sdr_27.9768.ckpt?download=true",
+ "config_url": "https://huggingface.co/jarredou/aufr33_MelBand_Denoise/resolve/main/model_mel_band_roformer_denoise.yaml?download=true"
+ },
+ "mbr_denoise_aggr_aufr33": {
+ "category": "Шум",
+ "id": 1192,
+ "full_name": "Mel-Band Roformer DeNoise Aggressive by Aufr33",
+ "stems": [
+ "dry",
+ "other"
+ ],
+ "target_instrument": "dry",
+ "checkpoint_url": "https://huggingface.co/jarredou/aufr33_MelBand_Denoise/resolve/main/denoise_mel_band_roformer_aufr33_aggr_sdr_27.9768.ckpt?download=true",
+ "config_url": "https://huggingface.co/jarredou/aufr33_MelBand_Denoise/resolve/main/model_mel_band_roformer_denoise.yaml?download=true"
+ },
+ "mbr_crowd_aufr33_viperx": {
+ "category": "Звуки толпы",
+ "id": 1193,
+ "full_name": "Mel-Band Roformer Crowd by Aufr33 & ViperX",
+ "stems": [
+ "crowd",
+ "other"
+ ],
+ "target_instrument": "crowd",
+ "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.4/mel_band_roformer_crowd_aufr33_viperx_sdr_8.7144.ckpt",
+ "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.4/model_mel_band_roformer_crowd.yaml"
+ },
+ "mbr_vocals_viperx": {
+ "category": "Вокал",
+ "id": 1194,
+ "full_name": "Mel-Band Roformer Vocals by ViperX",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt",
+ "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/viperx/model_mel_band_roformer_ep_3005_sdr_11.4360.yaml"
+ },
+ "mbr_vocalsf_aname": {
+ "category": "Вокал",
+ "id": 1200,
+ "full_name": "Mel-Band Roformer Vocals Fullness by Aname",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/FullnessVocalModel.ckpt?download=true",
+ "config_url": "https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/config.yaml?download=true"
+ },
+ "mbr_kinft1_aname": {
+ "category": "Вокал",
+ "id": 1201,
+ "full_name": "Mel-Band Roformer Kim FT v1 by Aname",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Aname-Tommy/Test/resolve/main/model_kim.ckpt?download=true",
+ "config_url": "https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/config.yaml?download=true"
+ },
+ "mbr_kinft2_aname": {
+ "category": "Вокал",
+ "id": 1202,
+ "full_name": "Mel-Band Roformer Kim FT v2 by Aname",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Aname-Tommy/Test/resolve/main/model_kim_2.ckpt?download=true",
+ "config_url": "https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/config.yaml?download=true"
+ },
+ "mbr_kinft2f_aname": {
+ "category": "Вокал",
+ "id": 1203,
+ "full_name": "Mel-Band Roformer Kim FT v2 Fullness by Aname",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Aname-Tommy/Test/resolve/main/model_kim_2_fullness.ckpt?download=true",
+ "config_url": "https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/config.yaml?download=true"
+ },
+ "mbr_kinft3_aname": {
+ "category": "Вокал",
+ "id": 1204,
+ "full_name": "Mel-Band Roformer Kim FT v3 by Aname",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Aname-Tommy/Test/resolve/main/model_kim_3.ckpt?download=true",
+ "config_url": "https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/config.yaml?download=true"
+ },
+ "mbr_small_aname": {
+ "category": "Вокал",
+ "id": 1205,
+ "full_name": "Mel-Band Roformer Small by Aname",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Aname-Tommy/Mel_Band_Roformer_small/resolve/main/mel_band_roformer_small.ckpt?download=true",
+ "config_url": "https://huggingface.co/Aname-Tommy/Mel_Band_Roformer_small/resolve/main/config.yaml?download=true"
+ },
+ "mbr_duality1_aname": {
+ "category": "Вокал",
+ "id": 1209,
+ "full_name": "Mel-Band Roformer Duality v1 by Aname",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Aname-Tommy/Mel-Band-Roformer_Duality/resolve/main/duality_v1.ckpt?download=true",
+ "config_url": "https://huggingface.co/Aname-Tommy/Mel-Band-Roformer_Duality/resolve/main/config_v1.yaml?download=true"
+ },
+ "mbr_4stemlarge1_aname": {
+ "category": "4 стема",
+ "id": 1206,
+ "full_name": "Mel-Band Roformer 4 Stems Large by Aname",
+ "stems": [
+ "vocals",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Aname-Tommy/melbandroformer4stems/resolve/main/mel_band_roformer_4stems_large_ver1.ckpt?download=true",
+ "config_url": "https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/config.yaml?download=true"
+ },
+ "mbr_4stemlarge2_aname": {
+ "category": "4 стема",
+ "id": 1207,
+ "full_name": "Mel-Band Roformer 4 Stems v2 Large by Aname",
+ "stems": [
+ "vocals",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Aname-Tommy/Test/resolve/main/4stemsver2.ckpt?download=true",
+ "config_url": "https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/config.yaml?download=true"
+ },
+ "mbr_4stemxl1_aname": {
+ "category": "4 стема",
+ "id": 1208,
+ "full_name": "Mel-Band Roformer 4 Stems XL by Aname",
+ "stems": [
+ "vocals",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Aname-Tommy/melbandroformer4stems/resolve/main/mel_band_roformer_4stems_xl_ver1.ckpt?download=true",
+ "config_url": "https://huggingface.co/Aname-Tommy/melbandroformer4stems/resolve/main/config_xl.yaml?download=true"
+ },
+ "mbr_percussion_yolkispaliks": {
+ "category": "Ударные",
+ "id": 1900,
+ "full_name": "Mel-Band Roformer Drums Experimental by yolkispalkis",
+ "stems": [
+ "percussions",
+ "other"
+ ],
+ "target_instrument": "percussions",
+ "checkpoint_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/model_mel_band_roformer_ep_11_sdr_7.6853.ckpt?download=true",
+ "config_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/config_drums_musdb18_moises_mel_band_roformer.yaml?download=true"
+ },
+ "mbr_inst_metal_prev_meskvlla33": {
+ "category": "Инструментал",
+ "id": 1901,
+ "full_name": "Mel-Band Roformer Metal Inst Preview by Mesk",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "other",
+ "checkpoint_url": "https://huggingface.co/meskvlla33/metal_roformer_preview/resolve/main/metal_roformer_inst_mesk_preview.ckpt?download=true",
+ "config_url": "https://huggingface.co/meskvlla33/metal_roformer_preview/resolve/main/config_inst_metal_roformer_mesk.yaml?download=true"
+ },
+ "mbr_neo_inst_vfx": {
+ "category": "Инструментал",
+ "id": 1902,
+ "full_name": "Mel-Band Roformer NEO Inst VFX by natanworkspace",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Instrumental",
+ "checkpoint_url": "https://huggingface.co/natanworkspace/melband_roformer/resolve/main/Neo_InstVFX.ckpt?download=true",
+ "config_url": "https://huggingface.co/natanworkspace/melband_roformer/resolve/main/config_neo_inst.yaml?download=true"
+ }
+ },
+ "bs_roformer": {
+ "bs_drums_beatloo_labs": {
+ "category": "Ударные",
+ "id": 200,
+ "full_name": "BS Roformer Drums Experimental by BeatLoo Labs",
+ "stems": [
+ "drums",
+ "other"
+ ],
+ "target_instrument": "drums",
+ "checkpoint_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/model_drums_bs_roformer_ep_12_sdr_7.2279.ckpt?download=true",
+ "config_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/BS-Roformer_Drums_beatloo_labs_config.yaml?download=true"
+ },
+ "bs_bass_beatloo_labs": {
+ "category": "Басс",
+ "id": 201,
+ "full_name": "BS Roformer Bass Experimental by BeatLoo Labs",
+ "stems": [
+ "bass",
+ "other"
+ ],
+ "target_instrument": "bass",
+ "checkpoint_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/model_bass_bs_roformer_ep_10_sdr_5.7862.ckpt?download=true",
+ "config_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/BS-Roformer_Bass_beatloo_labs_config.yaml?download=true"
+ },
+ "bs_vocals_1296_viperx": {
+ "category": "Вокал",
+ "id": 202,
+ "full_name": "BS Roformer Vocals (sdr 12.96) by ViperX",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Vocals",
+ "checkpoint_url": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_bs_roformer_ep_368_sdr_12.9628.ckpt",
+ "config_url": "https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/mdx_c_configs/model_bs_roformer_ep_368_sdr_12.9628.yaml"
+ },
+ "bs_other_viperx": {
+ "category": "Прочее",
+ "id": 203,
+ "full_name": "BS Roformer Other by ViperX",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "other",
+ "checkpoint_url": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_bs_roformer_ep_937_sdr_10.5309.ckpt",
+ "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/viperx/model_bs_roformer_ep_937_sdr_10.5309.yaml"
+ },
+ "bs_revive1_unwa": {
+ "category": "Вокал",
+ "id": 210,
+ "full_name": "BS Roformer Vocals Revive v1 by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/pcunwa/BS-Roformer-Revive/resolve/main/bs_roformer_revive.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/BS-Roformer-Revive/resolve/main/config.yaml?download=true"
+ },
+ "bs_revive2_unwa": {
+ "category": "Вокал",
+ "id": 211,
+ "full_name": "BS Roformer Vocals Revive v2 by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/pcunwa/BS-Roformer-Revive/resolve/main/bs_roformer_revive2.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/BS-Roformer-Revive/resolve/main/config.yaml?download=true"
+ },
+ "bs_revive3e_unwa": {
+ "category": "Вокал",
+ "id": 212,
+ "full_name": "BS Roformer Vocals Revive v3e by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/pcunwa/BS-Roformer-Revive/resolve/main/bs_roformer_revive3e.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/BS-Roformer-Revive/resolve/main/config.yaml?download=true"
+ },
+ "bs_resurrection_unwa": {
+ "category": "Вокал",
+ "id": 213,
+ "full_name": "BS Roformer Vocals Resurrection by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/pcunwa/BS-Roformer-Resurrection/resolve/main/BS-Roformer-Resurrection.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/BS-Roformer-Resurrection/resolve/main/BS-Roformer-Resurrection-Config.yaml?download=true"
+ },
+ "bs_resurrection_inst_unwa": {
+ "category": "Инструментал",
+ "id": 214,
+ "full_name": "BS Roformer Instrumental Resurrection by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "other",
+ "checkpoint_url": "https://huggingface.co/pcunwa/BS-Roformer-Resurrection/resolve/main/BS-Roformer-Resurrection-Inst.ckpt?download=true",
+ "config_url": "https://huggingface.co/pcunwa/BS-Roformer-Resurrection/resolve/main/BS-Roformer-Resurrection-Inst-Config.yaml?download=true"
+ },
+ "bs_inst_fno_unwa": {
+ "category": "Инструментал",
+ "id": 226,
+ "full_name": "BS Roformer Instrumental FNO by Unwa",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "other",
+ "checkpoint_url": "https://huggingface.co/pcunwa/BS-Roformer-Inst-FNO/resolve/main/bs_roformer_fno.ckpt?download=true",
+ "config_url": "https://raw.githubusercontent.com/noblebarkrr/mvsepless/refs/heads/beta/fixed_configs/BS-Roformer_Inst_FNO_unwa_config.yaml"
+ },
+ "bs_karaoke_becruily": {
+ "category": "Караоке",
+ "id": 215,
+ "full_name": "BS Roformer Karaoke by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Vocals",
+ "checkpoint_url": "https://huggingface.co/becruily/bs-roformer-karaoke/resolve/main/bs_roformer_karaoke_frazer_becruily.ckpt?download=true",
+ "config_url": "https://huggingface.co/becruily/bs-roformer-karaoke/resolve/main/config_karaoke_frazer_becruily.yaml?download=true"
+ },
+ "bs_voctest_gabox": {
+ "category": "Вокал",
+ "id": 216,
+ "full_name": "BS Roformer Vocals by GaboxR67",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Vocals",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/BSRoformerVocTest/resolve/main/voc_gaboxBSR.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/BSRoformerVocTest/resolve/main/voc_gaboxBSroformer.yaml?download=true"
+ },
+ "bs_karaoke_gabox": {
+ "category": "Караоке",
+ "id": 229,
+ "full_name": "BS Roformer Karaoke by GaboxR67",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/aa0a94097a222fcc8ebdd691b1435145eda31b46/bsroformers/bs_karaoke_gabox_IS.ckpt?download=true",
+ "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/aa0a94097a222fcc8ebdd691b1435145eda31b46/bsroformers/karaoke_bs_roformer.yaml?download=true"
+ },
+ "bs_6stem": {
+ "category": "6 стемов",
+ "id": 217,
+ "full_name": "BS Roformer SW",
+ "stems": [
+ "bass",
+ "drums",
+ "other",
+ "piano",
+ "guitar",
+ "vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://github.com/undef13/splifft/releases/download/v0.0.1/roformer-fp16.pt",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/BS-Roformer_SW_config.yaml?download=true"
+ },
+ "bs_6stem_fixed": {
+ "category": "6 стемов",
+ "id": 227,
+ "full_name": "BS Roformer SW Fixed by jarredou",
+ "stems": [
+ "bass",
+ "drums",
+ "other",
+ "piano",
+ "guitar",
+ "vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/jarredou/BS-ROFO-SW-Fixed/resolve/main/BS-Rofo-SW-Fixed.ckpt?download=true",
+ "config_url": "https://huggingface.co/jarredou/BS-ROFO-SW-Fixed/resolve/main/BS-Rofo-SW-Fixed.yaml?download=true"
+ },
+ "bs_4stem_zfturbo": {
+ "category": "4 стема",
+ "id": 218,
+ "full_name": "BS Roformer 4 Stems by ZFTurbo",
+ "stems": [
+ "vocals",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.12/model_bs_roformer_ep_17_sdr_9.6568.ckpt",
+ "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.12/config_bs_roformer_384_8_2_485100.yaml"
+ },
+ "bs_4stemft_syh99999": {
+ "category": "4 стема",
+ "id": 219,
+ "full_name": "BS Roformer 4 Stems FT by SYH99999",
+ "stems": [
+ "vocals",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/SYH99999/bs_roformer_4stems_ft/resolve/main/bs_roformer_4stems_ft.pth?download=true",
+ "config_url": "https://huggingface.co/SYH99999/bs_roformer_4stems_ft/resolve/main/config.yaml?download=true"
+ },
+ "bs_male_female_146_sucial": {
+ "category": "Мужской/Женский вокал",
+ "id": 220,
+ "full_name": "BS Roformer Male-Female (ep 146) by Sucial",
+ "stems": [
+ "male",
+ "female"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Sucial/Chorus_Male_Female_BS_Roformer/resolve/main/model_chorus_bs_roformer_ep_146_sdr_23.8613.ckpt?download=true",
+ "config_url": "https://huggingface.co/Sucial/Chorus_Male_Female_BS_Roformer/resolve/main/config_chorus_male_female_bs_roformer.yaml?download=true"
+ },
+ "bs_male_female_267_sucial": {
+ "category": "Мужской/Женский вокал",
+ "id": 221,
+ "full_name": "BS Roformer Male-Female (ep 267) by Sucial",
+ "stems": [
+ "male",
+ "female"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Sucial/Chorus_Male_Female_BS_Roformer/resolve/main/model_chorus_bs_roformer_ep_267_sdr_24.1275.ckpt?download=true",
+ "config_url": "https://huggingface.co/Sucial/Chorus_Male_Female_BS_Roformer/resolve/main/config_chorus_male_female_bs_roformer.yaml?download=true"
+ },
+ "bs_male_female_aufr33": {
+ "category": "Мужской/Женский вокал",
+ "id": 222,
+ "full_name": "BS Roformer Male-Female by Aufr33",
+ "stems": [
+ "male",
+ "female"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/RareSirMix/AIModelRehosting/resolve/main/bs_roformer_male_female_by_aufr33_sdr_7.2889.ckpt",
+ "config_url": "https://huggingface.co/Sucial/Chorus_Male_Female_BS_Roformer/resolve/main/config_chorus_male_female_bs_roformer.yaml"
+ },
+ "bs_deverb_256_8_anvuew": {
+ "category": "Реверб",
+ "id": 223,
+ "full_name": "BS Roformer Deverb 256-8 by Anvuew",
+ "stems": [
+ "reverb",
+ "noreverb"
+ ],
+ "target_instrument": "noreverb",
+ "checkpoint_url": "https://huggingface.co/anvuew/deverb_bs_roformer/resolve/main/deverb_bs_roformer_8_256dim_8depth.ckpt?download=true",
+ "config_url": "https://huggingface.co/anvuew/deverb_bs_roformer/resolve/main/deverb_bs_roformer_8_256dim_8depth.yaml?download=true"
+ },
+ "bs_deverb_384_10_anvuew": {
+ "category": "Реверб",
+ "id": 224,
+ "full_name": "BS Roformer Deverb 384-10 by Anvuew",
+ "stems": [
+ "reverb",
+ "noreverb"
+ ],
+ "target_instrument": "noreverb",
+ "checkpoint_url": "https://huggingface.co/anvuew/deverb_bs_roformer/resolve/main/deverb_bs_roformer_8_384dim_10depth.ckpt?download=true",
+ "config_url": "https://huggingface.co/anvuew/deverb_bs_roformer/resolve/main/deverb_bs_roformer_8_384dim_10depth.yaml?download=true"
+ },
+ "bs_karaoke_anvuew": {
+ "category": "Караоке",
+ "id": 228,
+ "full_name": "BS Roformer Karaoke by Anvuew",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": "Vocals",
+ "checkpoint_url": "https://huggingface.co/anvuew/karaoke_bs_roformer/resolve/main/karaoke_bs_roformer_anvuew.ckpt?download=true",
+ "config_url": "https://huggingface.co/anvuew/karaoke_bs_roformer/resolve/main/karaoke_bs_roformer_anvuew.yaml?download=true"
+ },
+ "bs_vocals_anvuew": {
+ "category": "Вокал",
+ "id": 230,
+ "full_name": "BS Roformer Vocals by Anvuew",
+ "stems": [
+ "vocals",
+ "instrument"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/anvuew/BS-RoFormer/resolve/main/bs_roformer_anvuew_sdr_12.45.ckpt?download=true",
+ "config_url": "https://huggingface.co/anvuew/BS-RoFormer/resolve/main/config.yaml?download=true"
+ },
+ "bs_4stem_aname": {
+ "category": "4 стема",
+ "id": 225,
+ "full_name": "BS Roformer 4 stems by Aname",
+ "stems": [
+ "vocals",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/RareSirMix/AIModelRehosting/resolve/main/Amane4stem_bs_roformer.ckpt",
+ "config_url": "https://huggingface.co/RareSirMix/AIModelRehosting/resolve/main/Amane4stem_bs_roformer.yaml"
+ }
+ },
+ "mdx23c": {
+ "mdx23c_instvoc_zfturbo": {
+ "category": "Инструментал и вокал",
+ "id": 300,
+ "full_name": "MDX23C Inst-Voc HQ by ZFTurbo",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.0/model_vocals_mdx23c_sdr_10.17.ckpt",
+ "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_vocals_mdx23c.yaml"
+ },
+ "mdx23c_instvoc_hq1": {
+ "category": "Инструментал и вокал",
+ "id": 301,
+ "full_name": "MDX23C 8k FFT Inst-Voc HQ v1",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDX23C/MDX23C-8KFFT-InstVoc_HQ.ckpt?download=true",
+ "config_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDX23C/model_2_stem_full_band_8k.yaml?download=true"
+ },
+ "mdx23c_instvoc_hq2": {
+ "category": "Инструментал и вокал",
+ "id": 302,
+ "full_name": "MDX23C 8k FFT Inst-Voc HQ v2",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDX23C/MDX23C-8KFFT-InstVoc_HQ_2.ckpt?download=true",
+ "config_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDX23C/model_2_stem_full_band_8k.yaml?download=true"
+ },
+ "mdx23c_d1581": {
+ "category": "Инструментал и вокал",
+ "id": 303,
+ "full_name": "MDX23C D1581",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDX23C/MDX23C_D1581.ckpt?download=true",
+ "config_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDX23C/model_2_stem_061321.yaml?download=true"
+ },
+ "mdx23c_drumsep_6stem_aufr33_jarredou": {
+ "category": "DrumSep",
+ "id": 304,
+ "full_name": "MDX23C DrumSep 6 stems by Aufr33 & Jarredou",
+ "stems": [
+ "kick",
+ "snare",
+ "toms",
+ "hh",
+ "ride",
+ "crash"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://github.com/jarredou/models/releases/download/aufr33-jarredou_MDX23C_DrumSep_model_v0.1/aufr33-jarredou_DrumSep_model_mdx23c_ep_141_sdr_10.8059.ckpt",
+ "config_url": "https://github.com/jarredou/models/releases/download/aufr33-jarredou_MDX23C_DrumSep_model_v0.1/aufr33-jarredou_DrumSep_model_mdx23c_ep_141_sdr_10.8059.yaml"
+ },
+ "mdx23c_drumsep_5stem_aufr33_jarredou": {
+ "category": "DrumSep",
+ "id": 309,
+ "full_name": "MDX23C DrumSep 5 stems by Aufr33 & Jarredou",
+ "stems": [
+ "kick",
+ "snare",
+ "toms",
+ "hh",
+ "cymbals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://github.com/jarredou/models/releases/download/DrumSep/drumsep_5stems_mdx23c_jarredou.ckpt",
+ "config_url": "https://github.com/jarredou/models/releases/download/DrumSep/config_mdx23c.yaml"
+ },
+ "mdx23c_derverb_aufr33_jarredou": {
+ "category": "Реверб",
+ "id": 305,
+ "full_name": "MDX23C DeReverb by Aufr33 & Jarredou",
+ "stems": [
+ "dry",
+ "other"
+ ],
+ "target_instrument": "dry",
+ "checkpoint_url": "https://huggingface.co/jarredou/aufr33_jarredou_MDXv3_DeReverb/resolve/main/dereverb_mdx23c_sdr_6.9096.ckpt",
+ "config_url": "https://huggingface.co/jarredou/aufr33_jarredou_MDXv3_DeReverb/resolve/main/config_dereverb_mdx23c.yaml"
+ },
+ "mdx23c_mid_side_wesleyr36": {
+ "category": "Фантомный центр",
+ "id": 306,
+ "full_name": "MDX23C Mid-Side by WesleyR36",
+ "stems": [
+ "similarity",
+ "difference"
+ ],
+ "target_instrument": "similarity",
+ "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.10/model_mdx23c_ep_271_l1_freq_72.2383.ckpt",
+ "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.10/config_mdx23c_similarity.yaml"
+ },
+ "mdx23c_4stem_zfturbo": {
+ "category": "4 стема",
+ "id": 307,
+ "full_name": "MDX23C 4 Stems by ZFTurbo",
+ "stems": [
+ "vocals",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.1/model_mdx23c_ep_168_sdr_7.0207.ckpt",
+ "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.1/config_musdb18_mdx23c.yaml"
+ },
+ "mdx23c_orch_verosment": {
+ "category": "Оркестр",
+ "id": 308,
+ "full_name": "MDX23C Orchestra Experimental by Verosment",
+ "stems": [
+ "inst",
+ "orch"
+ ],
+ "target_instrument": "orch",
+ "checkpoint_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/model_mdx23c_ep_120_sdr_4.4174.ckpt?download=true",
+ "config_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/config_orchestra_mdx23c.yaml?download=true"
+ }
+ },
+ "mdxnet": {
+ "mdx_kim_inst": {
+ "category": "Инструментал",
+ "id": 600,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Kim Inst",
+ "stems": [
+ "instrumental",
+ "vocals"
+ ],
+ "target_instrument": "instrumental",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/Kim_Inst.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/Kim_Inst.yaml?download=true"
+ },
+ "mdx_kim_vocal1": {
+ "category": "Вокал",
+ "id": 601,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Kim Vocal 1",
+ "stems": [
+ "vocals",
+ "instrumental"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/Kim_Vocal_1.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/Kim_Vocal_1.yaml?download=true"
+ },
+ "mdx_kim_vocal2": {
+ "category": "Вокал",
+ "id": 602,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Kim Vocal 2",
+ "stems": [
+ "vocals",
+ "instrumental"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/Kim_Vocal_2.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/Kim_Vocal_2.yaml?download=true"
+ },
+ "mdx_kuielab_a_bass": {
+ "category": "Басс",
+ "id": 603,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Kuielab A Bass",
+ "stems": [
+ "bass",
+ "other"
+ ],
+ "target_instrument": "bass",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/kuielab_a_bass.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/kuielab_a_bass.yaml?download=true"
+ },
+ "mdx_kuielab_a_drums": {
+ "category": "Ударные",
+ "id": 604,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Kuielab A Drums",
+ "stems": [
+ "drums",
+ "other"
+ ],
+ "target_instrument": "drums",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/kuielab_a_drums.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/kuielab_a_drums.yaml?download=true"
+ },
+ "mdx_kuielab_a_other": {
+ "category": "Прочее",
+ "id": 605,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Kuielab A Other",
+ "stems": [
+ "other",
+ "no_other"
+ ],
+ "target_instrument": "other",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/kuielab_a_other.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/kuielab_a_other.yaml?download=true"
+ },
+ "mdx_kuielab_a_vocals": {
+ "category": "Вокал",
+ "id": 606,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Kuielab A Vocals",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/kuielab_a_vocals.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/kuielab_a_vocals.yaml?download=true"
+ },
+ "mdx_kuielab_b_bass": {
+ "category": "Басс",
+ "id": 607,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Kuielab B Bass",
+ "stems": [
+ "bass",
+ "other"
+ ],
+ "target_instrument": "bass",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/kuielab_b_bass.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/kuielab_b_bass.yaml?download=true"
+ },
+ "mdx_kuielab_b_drums": {
+ "category": "Ударные",
+ "id": 608,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Kuielab B Drums",
+ "stems": [
+ "drums",
+ "other"
+ ],
+ "target_instrument": "drums",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/kuielab_b_drums.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/kuielab_b_drums.yaml?download=true"
+ },
+ "mdx_kuielab_b_other": {
+ "category": "Прочее",
+ "id": 609,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Kuielab B Other",
+ "stems": [
+ "other",
+ "no_other"
+ ],
+ "target_instrument": "other",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/kuielab_b_other.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/kuielab_b_other.yaml?download=true"
+ },
+ "mdx_kuielab_b_vocals": {
+ "category": "Вокал",
+ "id": 610,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Kuielab B Vocals",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/kuielab_b_vocals.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/kuielab_b_vocals.yaml?download=true"
+ },
+ "mdx_reverb_hq_foxjoy": {
+ "category": "Реверб",
+ "id": 611,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Reverb HQ FoxJoy",
+ "stems": [
+ "reverb",
+ "other"
+ ],
+ "target_instrument": "reverb",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/Reverb_HQ_By_FoxJoy.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/Reverb_HQ_By_FoxJoy.yaml?download=true"
+ },
+ "mdx_inst1": {
+ "category": "Инструментал",
+ "id": 612,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Inst 1",
+ "stems": [
+ "instrumental",
+ "vocals"
+ ],
+ "target_instrument": "instrumental",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Inst_1.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Inst_1.yaml?download=true"
+ },
+ "mdx_inst2": {
+ "category": "Инструментал",
+ "id": 613,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Inst 2",
+ "stems": [
+ "instrumental",
+ "vocals"
+ ],
+ "target_instrument": "instrumental",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Inst_2.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Inst_2.yaml?download=true"
+ },
+ "mdx_inst3": {
+ "category": "Инструментал",
+ "id": 614,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Inst 3",
+ "stems": [
+ "instrumental",
+ "vocals"
+ ],
+ "target_instrument": "instrumental",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Inst_3.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Inst_3.yaml?download=true"
+ },
+ "mdx_inst_full_292": {
+ "category": "Инструментал",
+ "id": 615,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Inst Full 292",
+ "stems": [
+ "instrumental",
+ "vocals"
+ ],
+ "target_instrument": "instrumental",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Inst_full_292.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Inst_full_292.yaml?download=true"
+ },
+ "mdx_inst_hq1": {
+ "category": "Инструментал",
+ "id": 616,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Inst HQ 1",
+ "stems": [
+ "instrumental",
+ "vocals"
+ ],
+ "target_instrument": "instrumental",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Inst_HQ_1.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Inst_HQ_1.yaml?download=true"
+ },
+ "mdx_inst_hq2": {
+ "category": "Инструментал",
+ "id": 617,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Inst HQ 2",
+ "stems": [
+ "instrumental",
+ "vocals"
+ ],
+ "target_instrument": "instrumental",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Inst_HQ_2.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Inst_HQ_2.yaml?download=true"
+ },
+ "mdx_inst_hq3": {
+ "category": "Инструментал",
+ "id": 618,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Inst HQ 3",
+ "stems": [
+ "instrumental",
+ "vocals"
+ ],
+ "target_instrument": "instrumental",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Inst_HQ_3.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Inst_HQ_3.yaml?download=true"
+ },
+ "mdx_inst_hq4": {
+ "category": "Инструментал",
+ "id": 619,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Inst HQ 4",
+ "stems": [
+ "instrumental",
+ "vocals"
+ ],
+ "target_instrument": "instrumental",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Inst_HQ_4.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Inst_HQ_4.yaml?download=true"
+ },
+ "mdx_inst_hq5": {
+ "category": "Инструментал",
+ "id": 620,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Inst HQ 5",
+ "stems": [
+ "instrumental",
+ "vocals"
+ ],
+ "target_instrument": "instrumental",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Inst_HQ_5.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Inst_HQ_5.yaml?download=true"
+ },
+ "mdx_inst_main": {
+ "category": "Инструментал",
+ "id": 621,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Inst Main",
+ "stems": [
+ "instrumental",
+ "vocals"
+ ],
+ "target_instrument": "instrumental",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Inst_Main.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Inst_Main.yaml?download=true"
+ },
+ "mdx_vocft": {
+ "category": "Вокал",
+ "id": 622,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Voc FT",
+ "stems": [
+ "vocals",
+ "instrumental"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Voc_FT.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Voc_FT.yaml?download=true"
+ },
+ "mdx_crowd_hq1": {
+ "category": "Звуки толпы",
+ "id": 623,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Crowd HQ 1",
+ "stems": [
+ "other",
+ "crowd"
+ ],
+ "target_instrument": "other",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET_Crowd_HQ_1.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET_Crowd_HQ_1.yaml?download=true"
+ },
+ "mdx_inst_187_beta": {
+ "category": "Инструментал",
+ "id": 624,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Inst 187 beta",
+ "stems": [
+ "instrumental",
+ "vocals"
+ ],
+ "target_instrument": "instrumental",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET_Inst_187_beta.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET_Inst_187_beta.yaml?download=true"
+ },
+ "mdx_inst_82_beta": {
+ "category": "Инструментал",
+ "id": 625,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Inst 82 beta",
+ "stems": [
+ "instrumental",
+ "vocals"
+ ],
+ "target_instrument": "instrumental",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET_Inst_82_beta.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET_Inst_82_beta.yaml?download=true"
+ },
+ "mdx_inst_90_beta": {
+ "category": "Инструментал",
+ "id": 626,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Inst 90 beta",
+ "stems": [
+ "instrumental",
+ "vocals"
+ ],
+ "target_instrument": "instrumental",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET_Inst_90_beta.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET_Inst_90_beta.yaml?download=true"
+ },
+ "mdx_main_340": {
+ "category": "Вокал",
+ "id": 627,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Main 340",
+ "stems": [
+ "vocals",
+ "instrumental"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET_Main_340.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET_Main_340.yaml?download=true"
+ },
+ "mdx_main_390": {
+ "category": "Вокал",
+ "id": 628,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Main 390",
+ "stems": [
+ "vocals",
+ "instrumental"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET_Main_390.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET_Main_390.yaml?download=true"
+ },
+ "mdx_main_406": {
+ "category": "Вокал",
+ "id": 629,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Main 406",
+ "stems": [
+ "vocals",
+ "instrumental"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET_Main_406.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET_Main_406.yaml?download=true"
+ },
+ "mdx_main_427": {
+ "category": "Вокал",
+ "id": 630,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Main 427",
+ "stems": [
+ "vocals",
+ "instrumental"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET_Main_427.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET_Main_427.yaml?download=true"
+ },
+ "mdx_main_438": {
+ "category": "Вокал",
+ "id": 631,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Main 438",
+ "stems": [
+ "vocals",
+ "instrumental"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET_Main_438.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET_Main_438.yaml?download=true"
+ },
+ "mdx_1_9703": {
+ "category": "Вокал",
+ "id": 632,
+ "full_name": "MDX-Net Model: UVR-MDX-NET 1 9703",
+ "stems": [
+ "vocals",
+ "instrumental"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR_MDXNET_1_9703.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR_MDXNET_1_9703.yaml?download=true"
+ },
+ "mdx_2_9682": {
+ "category": "Вокал",
+ "id": 633,
+ "full_name": "MDX-Net Model: UVR-MDX-NET 2 9682",
+ "stems": [
+ "vocals",
+ "instrumental"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR_MDXNET_2_9682.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR_MDXNET_2_9682.yaml?download=true"
+ },
+ "mdx_3_9662": {
+ "category": "Вокал",
+ "id": 634,
+ "full_name": "MDX-Net Model: UVR-MDX-NET 3 9662",
+ "stems": [
+ "vocals",
+ "instrumental"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR_MDXNET_3_9662.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR_MDXNET_3_9662.yaml?download=true"
+ },
+ "mdx_9482": {
+ "category": "Вокал",
+ "id": 635,
+ "full_name": "MDX-Net Model: UVR-MDX-NET 9482",
+ "stems": [
+ "vocals",
+ "instrumental"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR_MDXNET_9482.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR_MDXNET_9482.yaml?download=true"
+ },
+ "mdx_karaoke1": {
+ "category": "Караоке",
+ "id": 636,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Karaoke 1",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR_MDXNET_KARA.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR_MDXNET_KARA.yaml?download=true"
+ },
+ "mdx_karaoke2": {
+ "category": "Караоке",
+ "id": 637,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Karaoke 2",
+ "stems": [
+ "other",
+ "vocals"
+ ],
+ "target_instrument": "other",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR_MDXNET_KARA_2.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR_MDXNET_KARA_2.yaml?download=true"
+ },
+ "mdx_main": {
+ "category": "Вокал",
+ "id": 638,
+ "full_name": "MDX-Net Model: UVR-MDX-NET Main",
+ "stems": [
+ "vocals",
+ "instrumental"
+ ],
+ "target_instrument": "vocals",
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR_MDXNET_Main.onnx?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR_MDXNET_Main.yaml?download=true"
+ }
+ },
+ "vr": {
+ "1_hp-uvr": {
+ "category": "Инструментал",
+ "id": 500,
+ "full_name": "VR Arch Single Model v5: 1_HP-UVR",
+ "stems": [
+ "Instrumental",
+ "Vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/1_HP-UVR.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/1_HP-UVR.yaml?download=true"
+ },
+ "2_hp-uvr": {
+ "category": "Инструментал",
+ "id": 501,
+ "full_name": "VR Arch Single Model v5: 2_HP-UVR",
+ "stems": [
+ "Instrumental",
+ "Vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/2_HP-UVR.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/2_HP-UVR.yaml?download=true"
+ },
+ "3_hp-vocal-uvr": {
+ "category": "Вокал",
+ "id": 502,
+ "full_name": "VR Arch Single Model v5: 3_HP-Vocal-UVR",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/3_HP-Vocal-UVR.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/3_HP-Vocal-UVR.yaml?download=true"
+ },
+ "4_hp-vocal-uvr": {
+ "category": "Вокал",
+ "id": 503,
+ "full_name": "VR Arch Single Model v5: 4_HP-Vocal-UVR",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/4_HP-Vocal-UVR.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/4_HP-Vocal-UVR.yaml?download=true"
+ },
+ "5_hp-karaoke-uvr": {
+ "category": "Караоке",
+ "id": 504,
+ "full_name": "VR Arch Single Model v5: 5_HP-Karaoke-UVR",
+ "stems": [
+ "Instrumental",
+ "Vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/5_HP-Karaoke-UVR.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/5_HP-Karaoke-UVR.yaml?download=true"
+ },
+ "6_hp-karaoke-uvr": {
+ "category": "Караоке",
+ "id": 505,
+ "full_name": "VR Arch Single Model v5: 6_HP-Karaoke-UVR",
+ "stems": [
+ "Instrumental",
+ "Vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/6_HP-Karaoke-UVR.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/6_HP-Karaoke-UVR.yaml?download=true"
+ },
+ "7_hp2-uvr": {
+ "category": "Инструментал и вокал",
+ "id": 506,
+ "full_name": "VR Arch Single Model v5: 7_HP2-UVR",
+ "stems": [
+ "Instrumental",
+ "Vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/7_HP2-UVR.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/7_HP2-UVR.yaml?download=true"
+ },
+ "8_hp2-uvr": {
+ "category": "Инструментал и вокал",
+ "id": 507,
+ "full_name": "VR Arch Single Model v5: 8_HP2-UVR",
+ "stems": [
+ "Instrumental",
+ "Vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/8_HP2-UVR.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/8_HP2-UVR.yaml?download=true"
+ },
+ "9_hp2-uvr": {
+ "category": "Инструментал и вокал",
+ "id": 508,
+ "full_name": "VR Arch Single Model v5: 9_HP2-UVR",
+ "stems": [
+ "Instrumental",
+ "Vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/9_HP2-UVR.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/9_HP2-UVR.yaml?download=true"
+ },
+ "10_sp-uvr-2b-32000-1": {
+ "category": "Инструментал и вокал",
+ "id": 509,
+ "full_name": "VR Arch Single Model v5: 10_SP-UVR-2B-32000-1",
+ "stems": [
+ "Instrumental",
+ "Vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/10_SP-UVR-2B-32000-1.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/10_SP-UVR-2B-32000-1.yaml?download=true"
+ },
+ "11_sp-uvr-2b-32000-2": {
+ "category": "Инструментал и вокал",
+ "id": 510,
+ "full_name": "VR Arch Single Model v5: 11_SP-UVR-2B-32000-2",
+ "stems": [
+ "Instrumental",
+ "Vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/11_SP-UVR-2B-32000-2.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/11_SP-UVR-2B-32000-2.yaml?download=true"
+ },
+ "12_sp-uvr-3b-44100": {
+ "category": "Инструментал и вокал",
+ "id": 511,
+ "full_name": "VR Arch Single Model v5: 12_SP-UVR-3B-44100",
+ "stems": [
+ "Instrumental",
+ "Vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/12_SP-UVR-3B-44100.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/12_SP-UVR-3B-44100.yaml?download=true"
+ },
+ "13_sp-uvr-4b-44100-1": {
+ "category": "Инструментал и вокал",
+ "id": 512,
+ "full_name": "VR Arch Single Model v5: 13_SP-UVR-4B-44100-1",
+ "stems": [
+ "Instrumental",
+ "Vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/13_SP-UVR-4B-44100-1.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/13_SP-UVR-4B-44100-1.yaml?download=true"
+ },
+ "14_sp-uvr-4b-44100-2": {
+ "category": "Инструментал и вокал",
+ "id": 513,
+ "full_name": "VR Arch Single Model v5: 14_SP-UVR-4B-44100-2",
+ "stems": [
+ "Instrumental",
+ "Vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/14_SP-UVR-4B-44100-2.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/14_SP-UVR-4B-44100-2.yaml?download=true"
+ },
+ "15_sp-uvr-mid-44100-1": {
+ "category": "Инструментал и вокал",
+ "id": 514,
+ "full_name": "VR Arch Single Model v5: 15_SP-UVR-MID-44100-1",
+ "stems": [
+ "Instrumental",
+ "Vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/15_SP-UVR-MID-44100-1.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/15_SP-UVR-MID-44100-1.yaml?download=true"
+ },
+ "16_sp-uvr-mid-44100-2": {
+ "category": "Инструментал и вокал",
+ "id": 515,
+ "full_name": "VR Arch Single Model v5: 16_SP-UVR-MID-44100-2",
+ "stems": [
+ "Instrumental",
+ "Vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/16_SP-UVR-MID-44100-2.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/16_SP-UVR-MID-44100-2.yaml?download=true"
+ },
+ "17_hp-wind_inst-uvr": {
+ "category": "Деревянные духовые",
+ "id": 516,
+ "full_name": "VR Arch Single Model v5: 17_HP-Wind_Inst-UVR",
+ "stems": [
+ "No Woodwinds",
+ "Woodwinds"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/17_HP-Wind_Inst-UVR.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/17_HP-Wind_Inst-UVR.yaml?download=true"
+ },
+ "uvr-de-echo-aggressive": {
+ "category": "Эхо",
+ "id": 517,
+ "full_name": "VR Arch Single Model v5: UVR-De-Echo-Aggressive by FoxJoy",
+ "stems": [
+ "No Echo",
+ "Echo"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/UVR-De-Echo-Aggressive.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/UVR-De-Echo-Aggressive.yaml?download=true"
+ },
+ "uvr-de-echo-normal": {
+ "category": "Эхо",
+ "id": 518,
+ "full_name": "VR Arch Single Model v5: UVR-De-Echo-Normal by FoxJoy",
+ "stems": [
+ "No Echo",
+ "Echo"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/UVR-De-Echo-Normal.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/UVR-De-Echo-Normal.yaml?download=true"
+ },
+ "uvr-deecho-dereverb": {
+ "category": "Реверб",
+ "id": 519,
+ "full_name": "VR Arch Single Model v5: UVR-DeEcho-DeReverb by FoxJoy",
+ "stems": [
+ "No Reverb",
+ "Reverb"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/UVR-DeEcho-DeReverb.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/UVR-DeEcho-DeReverb.yaml?download=true"
+ },
+ "uvr-denoise-lite": {
+ "category": "Шум",
+ "id": 520,
+ "full_name": "VR Arch Single Model v5: UVR-DeNoise-Lite by FoxJoy",
+ "stems": [
+ "Noise",
+ "No Noise"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/UVR-DeNoise-Lite.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/UVR-DeNoise-Lite.yaml?download=true"
+ },
+ "uvr-denoise": {
+ "category": "Шум",
+ "id": 521,
+ "full_name": "VR Arch Single Model v5: UVR-DeNoise by FoxJoy",
+ "stems": [
+ "Noise",
+ "No Noise"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/UVR-DeNoise.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/UVR-DeNoise.yaml?download=true"
+ },
+ "uvr-bve-4b_sn-44100-1": {
+ "category": "Караоке",
+ "id": 522,
+ "full_name": "VR Arch Single Model v5: UVR-BVE-4B_SN-44100",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/UVR-BVE-4B_SN-44100-1.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/UVR-BVE-4B_SN-44100-1.yaml?download=true"
+ },
+ "uvr-bve-v2-4b-sn-44100": {
+ "category": "Караоке",
+ "id": 523,
+ "full_name": "VR Arch Single Model v4: UVR-BVE-v2-4B-SN-44100",
+ "stems": [
+ "Vocals",
+ "Instrumental"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/RareSirMix/AIModelRehosting/resolve/main/UVR-5-1_4band_v4_ms_fullband_BVE_v2_by_aufr33.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/UVR-BVE-v2-4B-SN-44100.yaml?download=true"
+ },
+ "mgm-v5-karokee-32000-beta1": {
+ "category": "Караоке",
+ "id": 524,
+ "full_name": "VR Arch Single Model v5: MGM-v5-KAROKEE-32000-BETA1",
+ "stems": [
+ "Instrumental",
+ "Vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://github.com/lucassantilli/UVR-Colab-GUI/releases/download/m5.1/MGM-v5-KAROKEE-32000-BETA1.pth",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/MGM-v5-KAROKEE-32000-BETA1.yaml?download=true"
+ },
+ "mgm-v5-karokee-32000-beta2-agr": {
+ "category": "Караоке",
+ "id": 525,
+            "full_name": "VR Arch Single Model v5: MGM-v5-KAROKEE-32000-BETA2-AGR",
+ "stems": [
+ "Instrumental",
+ "Vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://github.com/lucassantilli/UVR-Colab-GUI/releases/download/m5.1/MGM-v5-KAROKEE-32000-BETA2-AGR.pth",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/MGM-v5-KAROKEE-32000-BETA2-AGR.yaml?download=true"
+ },
+ "mgm_highend_v4": {
+ "category": "Инструментал и вокал",
+ "id": 526,
+ "full_name": "VR Arch Single Model v4: MGM_HIGHEND_v4",
+ "stems": [
+ "Instrumental",
+ "Vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/MGM_HIGHEND_v4.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/MGM_HIGHEND_v4.yaml?download=true"
+ },
+ "mgm_lowend_a_v4": {
+ "category": "Инструментал и вокал",
+ "id": 527,
+ "full_name": "VR Arch Single Model v4: MGM_LOWEND_A_v4",
+ "stems": [
+ "Instrumental",
+ "Vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/MGM_LOWEND_A_v4.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/MGM_LOWEND_A_v4.yaml?download=true"
+ },
+ "mgm_lowend_b_v4": {
+ "category": "Инструментал и вокал",
+ "id": 528,
+ "full_name": "VR Arch Single Model v4: MGM_LOWEND_B_v4",
+ "stems": [
+ "Instrumental",
+ "Vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/MGM_LOWEND_B_v4.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/MGM_LOWEND_B_v4.yaml?download=true"
+ },
+ "mgm_main_v4": {
+ "category": "Инструментал и вокал",
+ "id": 529,
+ "full_name": "VR Arch Single Model v4: MGM_MAIN_v4",
+ "stems": [
+ "Instrumental",
+ "Vocals"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/MGM_MAIN_v4.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/MGM_MAIN_v4.yaml?download=true"
+ },
+ "uvr-de-reverb-aufr33-jarredou": {
+ "category": "Реверб",
+ "id": 530,
+ "full_name": "VR Arch Single Model v4: UVR-De-Reverb by aufr33-jarredou",
+ "stems": [
+ "Dry",
+ "No Dry"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/UVR-De-Reverb-aufr33-jarredou.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/UVR-De-Reverb-aufr33-jarredou.yaml?download=true"
+ },
+ "uvr-de-breath-sucial-v1": {
+ "category": "Дыхание",
+ "id": 531,
+ "full_name": "VR Arch Single Model v4: UVR-De-Breath v1 by Sucial",
+ "stems": [
+ "Breath",
+ "No Breath"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Sucial/De-Breathe-Models/resolve/main/UVR_De-Breathe_1band_sr44100_hl1024_Sucial_v1.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/UVR-De-Breath-sucial-v1.yaml?download=true"
+ },
+ "uvr-de-breath-sucial-v2": {
+ "category": "Дыхание",
+ "id": 532,
+ "full_name": "VR Arch Single Model v4: UVR-De-Breath v2 by Sucial",
+ "stems": [
+ "Breath",
+ "No Breath"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Sucial/De-Breathe-Models/resolve/main/UVR_De-Breathe_1band_sr44100_hl1024_Sucial_v2.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/UVR-De-Breath-sucial-v2.yaml?download=true"
+ },
+ "vr_harmonic_noise_sep": {
+ "category": "Дыхание",
+ "id": 533,
+ "full_name": "VR Arch Single Model v5: Harmonic_Noise_Sep",
+ "stems": [
+ "Noise",
+ "No Noise"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/Sucial/MSST-WebUI/resolve/main/All_Models/VR_Models/Harmonic_Noise_Separation_yxlllc.pth?download=true",
+ "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/VR_Harmonic_Noise_Sep.yaml?download=true"
+ }
+ },
+ "scnet": {
+ "scnet_4stem_zfturbo": {
+ "category": "4 стема",
+ "id": 400,
+ "full_name": "SCNet 4 Stems by ZFTurbo",
+ "stems": [
+ "vocals",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.9/SCNet-large_starrytong_fixed.ckpt",
+ "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.9/config_musdb18_scnet_large_starrytong.yaml"
+ },
+ "scnet_xl_ihf_4stem_zfturbo": {
+ "category": "4 стема",
+ "id": 401,
+ "full_name": "SCNet XL IHF 4 Stems by ZFTurbo",
+ "stems": [
+ "vocals",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.15/model_scnet_ep_36_sdr_10.0891.ckpt",
+ "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.15/config_musdb18_scnet_xl_more_wide_v5.yaml"
+ },
+ "scnet_xl_4stem_starrytong": {
+ "category": "4 стема",
+ "id": 402,
+ "full_name": "SCNet 4 Stems XL by StarryTong",
+ "stems": [
+ "vocals",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.13/model_scnet_ep_54_sdr_9.8051.ckpt",
+ "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.13/config_musdb18_scnet_xl.yaml"
+ },
+ "scnet_xl_4stem_zftrubo": {
+ "category": "4 стема",
+ "id": 403,
+ "full_name": "SCNet 4 Stems XL by ZFTurbo",
+ "stems": [
+ "vocals",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.6/scnet_checkpoint_musdb18.ckpt",
+ "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.6/config_musdb18_scnet.yaml"
+ },
+ "scnet_jazz_4stem_jorisvaneyghen": {
+ "category": "4 стема",
+ "id": 404,
+ "full_name": "SCNet Large Jazz model by Joris Vaneyghen",
+ "stems": [
+ "piano",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/jorisvaneyghen/SCNet/resolve/main/model_jazz_scnet_large.ckpt?download=true",
+ "config_url": "https://huggingface.co/spaces/jorisvaneyghen/jazz_playalong/resolve/main/configs/config_jazz_scnet_large.yaml?download=true"
+ },
+ "scnet_xl_jazz_4stem_jorisvaneyghen": {
+ "category": "4 стема",
+ "id": 405,
+ "full_name": "SCNet XL Jazz model by Joris Vaneyghen",
+ "stems": [
+ "piano",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/jorisvaneyghen/SCNet/resolve/main/model_jazz_scnet_xl.ckpt?download=true",
+ "config_url": "https://huggingface.co/spaces/jorisvaneyghen/jazz_playalong/resolve/main/configs/config_jazz_scnet_xl.yaml?download=true"
+ },
+ "scnet_choirsep_exp": {
+ "category": "Хор",
+ "id": 420,
+ "full_name": "SCNet Choirsep by concert.isolations.business@gmail.com",
+ "stems": [
+ "alto",
+ "bass",
+ "tenor",
+ "soprano"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/scnet_choirsep/model_scnet_ep_36_sdr_5.4596.ckpt?download=true",
+ "config_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/scnet_choirsep/config_scnet_choirsep.yaml?download=true"
+ }
+ },
+ "htdemucs": {
+ "demucs4_mvsep_vocals": {
+ "category": "Вокал",
+ "id": 2,
+ "full_name": "HTDemucs4 (MVSep finetuned)",
+ "stems": [
+ "vocals",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.0/model_vocals_htdemucs_sdr_8.78.ckpt",
+ "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_vocals_htdemucs.yaml"
+ },
+ "demucs4_4stem": {
+ "category": "4 стема",
+ "id": 0,
+ "full_name": "HTDemucs4",
+ "stems": [
+ "vocals",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/955717e8-8726e21a.th",
+ "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_htdemucs.yaml"
+ },
+ "demucs4_6stem": {
+ "category": "6 стемов",
+ "id": 1,
+ "full_name": "HTDemucs4 (6 stems)",
+ "stems": [
+ "vocals",
+ "drums",
+ "bass",
+ "other",
+ "guitar",
+ "piano"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/5c90dfd2-34c22ccb.th",
+ "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_htdemucs_6stems.yaml"
+ },
+ "demucs3_mmi": {
+ "category": "4 стема",
+ "id": 3,
+ "full_name": "Demucs3 mmi",
+ "stems": [
+ "vocals",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/75fc33f5-1941ce65.th",
+ "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_demucs3_mmi.yaml"
+ },
+ "demucs4_ft_bass": {
+ "category": "Басс",
+ "id": 4,
+ "full_name": "HTDemucs4 FT Bass",
+ "stems": [
+ "vocals",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/d12395a8-e57c48e6.th",
+ "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_htdemucs.yaml"
+ },
+ "demucs4_ft_drums": {
+ "category": "Ударные",
+ "id": 5,
+ "full_name": "HTDemucs4 FT Drums",
+ "stems": [
+ "vocals",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/f7e0c4bc-ba3fe64a.th",
+ "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_htdemucs.yaml"
+ },
+ "demucs4_ft_vocals": {
+ "category": "Вокал",
+ "id": 6,
+ "full_name": "HTDemucs4 FT Vocals",
+ "stems": [
+ "vocals",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/04573f0d-f3cf25b2.th",
+ "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_htdemucs.yaml"
+ },
+ "demucs4_ft_other": {
+ "category": "Прочее",
+ "id": 7,
+ "full_name": "HTDemucs4 FT Other",
+ "stems": [
+ "vocals",
+ "drums",
+ "bass",
+ "other"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/92cfc3b6-ef3bcb9c.th",
+ "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_htdemucs.yaml"
+ },
+ "demucs_mid_side_wesleyr36": {
+ "category": "Фантомный центр",
+ "id": 8,
+            "full_name": "HTDemucs4 Mid-Side by wesleyr36",
+ "stems": [
+ "similarity",
+ "difference"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/jarredou/HTDemucs_Similarity_Extractor_by_wesleyr36/resolve/main/model_htdemucs_ep_21_sdr_13.6970.ckpt?download=true",
+ "config_url": "https://huggingface.co/jarredou/HTDemucs_Similarity_Extractor_by_wesleyr36/resolve/main/config_htdemucs_similarity.yaml?download=true"
+ },
+ "demucs4_choirsep": {
+ "category": "Хор",
+ "id": 9,
+ "full_name": "HTDemucs4 Choirsep by concert.isolations.business@gmail.com",
+ "stems": [
+ "alto",
+ "bass",
+ "tenor",
+ "soprano"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/demucs_choirsep/model_htdemucs_ep_94_sdr_5.2474.ckpt?download=true",
+ "config_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/demucs_choirsep/config_htdemucs_choirsep.yaml?download=true"
+ }
+ },
+ "bandit": {
+ "bandit_plus": {
+ "category": "Кинематограф",
+ "id": 10,
+ "full_name": "Bandit Plus: Cinematic Bandit Plus (by kwatcharasupat)",
+ "stems": [
+ "speech",
+ "music",
+ "effects"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.3/model_bandit_plus_dnr_sdr_11.47.chpt",
+ "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.3/config_dnr_bandit_bsrnn_multi_mus64.yaml"
+ }
+ },
+ "bandit_v2": {
+ "bandit_v2_multi": {
+ "category": "Кинематограф",
+ "id": 11,
+ "full_name": "Bandit v2: Cinematic Bandit v2 Multilang (by kwatcharasupat)",
+ "stems": [
+ "speech",
+ "music",
+ "sfx"
+ ],
+ "target_instrument": null,
+ "checkpoint_url": "https://huggingface.co/jarredou/banditv2_state_dicts_only/resolve/main/checkpoint-multi_state_dict.ckpt",
+ "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/refs/heads/main/configs/config_dnr_bandit_v2_mus64.yaml"
+ }
+ }
+}
\ No newline at end of file
diff --git a/mvsepless/models/bandit/core/__init__.py b/mvsepless/models/bandit/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..662d4075fc46d93e919b5dcdd3c215bcb8f57ef0
--- /dev/null
+++ b/mvsepless/models/bandit/core/__init__.py
@@ -0,0 +1,691 @@
+import os.path
+from collections import defaultdict
+from itertools import chain, combinations
+from typing import Any, Dict, Iterator, Mapping, Optional, Tuple, Type, TypedDict
+
+import pytorch_lightning as pl
+import torch
+import torchaudio as ta
+import torchmetrics as tm
+from asteroid import losses as asteroid_losses
+
+# from deepspeed.ops.adam import DeepSpeedCPUAdam
+# from geoopt import optim as gooptim
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+from torch import nn, optim
+from torch.optim import lr_scheduler
+from torch.optim.lr_scheduler import LRScheduler
+
+from . import loss, metrics as metrics_, model
+from .data._types import BatchedDataDict
+from .data.augmentation import BaseAugmentor, StemAugmentor
+from .utils import audio as audio_
+from .utils.audio import BaseFader
+
+# from pandas.io.json._normalize import nested_to_record
+
+ConfigDict = TypedDict("ConfigDict", {"name": str, "kwargs": Dict[str, Any]})
+
+
+class SchedulerConfigDict(ConfigDict):
+ monitor: str
+
+
+OptimizerSchedulerConfigDict = TypedDict(
+ "OptimizerSchedulerConfigDict",
+ {"optimizer": ConfigDict, "scheduler": SchedulerConfigDict},
+ total=False,
+)
+
+
+class LRSchedulerReturnDict(TypedDict, total=False):
+ scheduler: LRScheduler
+ monitor: str
+
+
+class ConfigureOptimizerReturnDict(TypedDict, total=False):
+ optimizer: torch.optim.Optimizer
+ lr_scheduler: LRSchedulerReturnDict
+
+
+OutputType = Dict[str, Any]
+MetricsType = Dict[str, torch.Tensor]
+
+
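+# The helpers below build optimizers, schedulers, models, losses, metrics and
+# faders from {"name": ..., "kwargs": ...} config dictionaries by looking each
+# name up in the relevant module namespaces.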
+def get_optimizer_class(name: str) -> Type[optim.Optimizer]:
+
+    # DeepSpeedCPUAdam and the geoopt optimizers need the optional imports above,
+    # which are commented out in this copy, so only torch.optim is searched.
+    for module in [optim]:
+        if name in module.__dict__:
+            return module.__dict__[name]
+
+    raise NameError(f"Unknown optimizer: {name}")
+
+
+def parse_optimizer_config(
+ config: OptimizerSchedulerConfigDict, parameters: Iterator[nn.Parameter]
+) -> ConfigureOptimizerReturnDict:
+ optim_class = get_optimizer_class(config["optimizer"]["name"])
+ optimizer = optim_class(parameters, **config["optimizer"]["kwargs"])
+
+ optim_dict: ConfigureOptimizerReturnDict = {
+ "optimizer": optimizer,
+ }
+
+ if "scheduler" in config:
+
+ lr_scheduler_class_ = config["scheduler"]["name"]
+ lr_scheduler_class = lr_scheduler.__dict__[lr_scheduler_class_]
+ lr_scheduler_dict: LRSchedulerReturnDict = {
+ "scheduler": lr_scheduler_class(optimizer, **config["scheduler"]["kwargs"])
+ }
+
+ if lr_scheduler_class_ == "ReduceLROnPlateau":
+ lr_scheduler_dict["monitor"] = config["scheduler"]["monitor"]
+
+ optim_dict["lr_scheduler"] = lr_scheduler_dict
+
+ return optim_dict
+
+
+def parse_model_config(config: ConfigDict) -> Any:
+ name = config["name"]
+
+ for module in [model]:
+ if name in module.__dict__:
+ return module.__dict__[name](**config["kwargs"])
+
+ raise NameError
+
+
+_LEGACY_LOSS_NAMES = ["HybridL1Loss"]
+
+
+def _parse_legacy_loss_config(config: ConfigDict) -> nn.Module:
+ name = config["name"]
+
+ if name == "HybridL1Loss":
+ return loss.TimeFreqL1Loss(**config["kwargs"])
+
+ raise NameError
+
+
+def parse_loss_config(config: ConfigDict) -> nn.Module:
+ name = config["name"]
+
+ if name in _LEGACY_LOSS_NAMES:
+ return _parse_legacy_loss_config(config)
+
+ for module in [loss, nn.modules.loss, asteroid_losses]:
+ if name in module.__dict__:
+ # print(config["kwargs"])
+ return module.__dict__[name](**config["kwargs"])
+
+ raise NameError
+
+
+def get_metric(config: ConfigDict) -> tm.Metric:
+ name = config["name"]
+
+ for module in [tm, metrics_]:
+ if name in module.__dict__:
+ return module.__dict__[name](**config["kwargs"])
+ raise NameError
+
+
+def parse_metric_config(config: Dict[str, ConfigDict]) -> tm.MetricCollection:
+ metrics = {}
+
+ for metric in config:
+ metrics[metric] = get_metric(config[metric])
+
+ return tm.MetricCollection(metrics)
+
+
+def parse_fader_config(config: ConfigDict) -> BaseFader:
+ name = config["name"]
+
+ for module in [audio_]:
+ if name in module.__dict__:
+ return module.__dict__[name](**config["kwargs"])
+
+ raise NameError
+
+
+class LightningSystem(pl.LightningModule):
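+    # PyTorch Lightning wrapper around a Bandit separation model: builds the
+    # model, loss, and metric collections from a config dict and implements the
+    # training / validation / test / predict loops, including chunked inference
+    # through an optional "fader".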
+ _VOX_STEMS = ["speech", "vocals"]
+ _BG_STEMS = ["background", "effects", "mne"]
+
+ def __init__(
+ self, config: Dict, loss_adjustment: float = 1.0, attach_fader: bool = False
+ ) -> None:
+ super().__init__()
+ self.optimizer_config = config["optimizer"]
+ self.model = parse_model_config(config["model"])
+ self.loss = parse_loss_config(config["loss"])
+ self.metrics = nn.ModuleDict(
+ {
+ stem: parse_metric_config(config["metrics"]["dev"])
+ for stem in self.model.stems
+ }
+ )
+
+ self.metrics.disallow_fsdp = True
+
+ self.test_metrics = nn.ModuleDict(
+ {
+ stem: parse_metric_config(config["metrics"]["test"])
+ for stem in self.model.stems
+ }
+ )
+
+ self.test_metrics.disallow_fsdp = True
+
+ self.fs = config["model"]["kwargs"]["fs"]
+
+ self.fader_config = config["inference"]["fader"]
+ if attach_fader:
+ self.fader = parse_fader_config(config["inference"]["fader"])
+ else:
+ self.fader = None
+
+ self.augmentation: Optional[BaseAugmentor]
+ if config.get("augmentation", None) is not None:
+ self.augmentation = StemAugmentor(**config["augmentation"])
+ else:
+ self.augmentation = None
+
+ self.predict_output_path: Optional[str] = None
+ self.loss_adjustment = loss_adjustment
+
+ self.val_prefix = None
+ self.test_prefix = None
+
+ def configure_optimizers(self) -> Any:
+ return parse_optimizer_config(
+ self.optimizer_config, self.trainer.model.parameters()
+ )
+
+ def compute_loss(
+ self, batch: BatchedDataDict, output: OutputType
+ ) -> Dict[str, torch.Tensor]:
+ return {"loss": self.loss(output, batch)}
+
+ def update_metrics(
+ self, batch: BatchedDataDict, output: OutputType, mode: str
+ ) -> None:
+
+ if mode == "test":
+ metrics = self.test_metrics
+ else:
+ metrics = self.metrics
+
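+        # Reference stems can be named differently across datasets (e.g. "speech"
+        # vs "vocals", "effects" vs "background"); when the exact stem is missing
+        # from the batch, fall back to an equivalent one before updating metrics.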
+ for stem, metric in metrics.items():
+
+ if stem == "mne:+":
+ stem = "mne"
+
+ # print(f"matching for {stem}")
+ if mode == "train":
+ metric.update(
+ output["audio"][stem], # .cpu(),
+ batch["audio"][stem], # .cpu()
+ )
+ else:
+ if stem not in batch["audio"]:
+ matched = False
+ if stem in self._VOX_STEMS:
+ for bstem in self._VOX_STEMS:
+ if bstem in batch["audio"]:
+ batch["audio"][stem] = batch["audio"][bstem]
+ matched = True
+ break
+ elif stem in self._BG_STEMS:
+ for bstem in self._BG_STEMS:
+ if bstem in batch["audio"]:
+ batch["audio"][stem] = batch["audio"][bstem]
+ matched = True
+ break
+ else:
+ matched = True
+
+ # print(batch["audio"].keys())
+
+ if matched:
+ # print(f"matched {stem}!")
+ if stem == "mne" and "mne" not in output["audio"]:
+ output["audio"]["mne"] = (
+ output["audio"]["music"] + output["audio"]["effects"]
+ )
+
+ metric.update(
+ output["audio"][stem], # .cpu(),
+ batch["audio"][stem], # .cpu(),
+ )
+
+ # print(metric.compute())
+
+ def compute_metrics(self, mode: str = "dev") -> Dict[str, torch.Tensor]:
+
+ if mode == "test":
+ metrics = self.test_metrics
+ else:
+ metrics = self.metrics
+
+ metric_dict = {}
+
+ for stem, metric in metrics.items():
+ md = metric.compute()
+ metric_dict.update({f"{stem}/{k}": v for k, v in md.items()})
+
+ self.log_dict(metric_dict, prog_bar=True, logger=False)
+
+ return metric_dict
+
+ def reset_metrics(self, test_mode: bool = False) -> None:
+
+ if test_mode:
+ metrics = self.test_metrics
+ else:
+ metrics = self.metrics
+
+ for _, metric in metrics.items():
+ metric.reset()
+
+ def forward(self, batch: BatchedDataDict) -> Any:
+ batch, output = self.model(batch)
+
+ return batch, output
+
+ def common_step(self, batch: BatchedDataDict, mode: str) -> Any:
+ batch, output = self.forward(batch)
+ # print(batch)
+ # print(output)
+ loss_dict = self.compute_loss(batch, output)
+
+ with torch.no_grad():
+ self.update_metrics(batch, output, mode=mode)
+
+ if mode == "train":
+ self.log("loss", loss_dict["loss"], prog_bar=True)
+
+ return output, loss_dict
+
+ def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]:
+
+ if self.augmentation is not None:
+ with torch.no_grad():
+ batch = self.augmentation(batch)
+
+ _, loss_dict = self.common_step(batch, mode="train")
+
+ with torch.inference_mode():
+ self.log_dict_with_prefix(
+ loss_dict, "train", batch_size=batch["audio"]["mixture"].shape[0]
+ )
+
+ loss_dict["loss"] *= self.loss_adjustment
+
+ return loss_dict
+
+ def on_train_batch_end(
+ self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int
+ ) -> None:
+
+ metric_dict = self.compute_metrics()
+ self.log_dict_with_prefix(metric_dict, "train")
+ self.reset_metrics()
+
+ def validation_step(
+ self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
+ ) -> Dict[str, Any]:
+
+ with torch.inference_mode():
+ curr_val_prefix = f"val{dataloader_idx}" if dataloader_idx > 0 else "val"
+
+ if curr_val_prefix != self.val_prefix:
+ # print(f"Switching to validation dataloader {dataloader_idx}")
+ if self.val_prefix is not None:
+ self._on_validation_epoch_end()
+ self.val_prefix = curr_val_prefix
+ _, loss_dict = self.common_step(batch, mode="val")
+
+ self.log_dict_with_prefix(
+ loss_dict,
+ self.val_prefix,
+ batch_size=batch["audio"]["mixture"].shape[0],
+ prog_bar=True,
+ add_dataloader_idx=False,
+ )
+
+ return loss_dict
+
+ def on_validation_epoch_end(self) -> None:
+ self._on_validation_epoch_end()
+
+ def _on_validation_epoch_end(self) -> None:
+ metric_dict = self.compute_metrics()
+ self.log_dict_with_prefix(
+ metric_dict, self.val_prefix, prog_bar=True, add_dataloader_idx=False
+ )
+ # self.logger.save()
+ # print(self.val_prefix, "Validation metrics:", metric_dict)
+ self.reset_metrics()
+
+ def old_predtest_step(
+ self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
+ ) -> Tuple[BatchedDataDict, OutputType]:
+
+ audio_batch = batch["audio"]["mixture"]
+ track_batch = batch.get("track", ["" for _ in range(len(audio_batch))])
+
+ output_list_of_dicts = [
+ self.fader(audio[None, ...], lambda a: self.test_forward(a, track))
+ for audio, track in zip(audio_batch, track_batch)
+ ]
+
+ output_dict_of_lists = defaultdict(list)
+
+ for output_dict in output_list_of_dicts:
+ for stem, audio in output_dict.items():
+ output_dict_of_lists[stem].append(audio)
+
+ output = {
+ "audio": {
+ stem: torch.concat(output_list, dim=0)
+ for stem, output_list in output_dict_of_lists.items()
+ }
+ }
+
+ return batch, output
+
+ def predtest_step(
+ self, batch: BatchedDataDict, batch_idx: int = -1, dataloader_idx: int = 0
+ ) -> Tuple[BatchedDataDict, OutputType]:
+
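+        # Either run the model directly on the full input, or drive it through the
+        # fader, which processes long mixtures chunk by chunk and reassembles the
+        # separated outputs.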
+ if getattr(self.model, "bypass_fader", False):
+ batch, output = self.model(batch)
+ else:
+ audio_batch = batch["audio"]["mixture"]
+ output = self.fader(
+ audio_batch, lambda a: self.test_forward(a, "", batch=batch)
+ )
+
+ return batch, output
+
+    def test_forward(
+        self,
+        audio: torch.Tensor,
+        track: str = "",
+        batch: Optional[BatchedDataDict] = None,
+    ) -> Dict[str, torch.Tensor]:
+
+        if self.fader is None:
+            self.attach_fader()
+
+        # batch may be omitted by legacy callers; the condition is then unavailable.
+        cond = batch.get("condition", None) if batch is not None else None
+
+ if cond is not None and cond.shape[0] == 1:
+ cond = cond.repeat(audio.shape[0], 1)
+
+ _, output = self.forward(
+ {
+ "audio": {"mixture": audio},
+ "track": track,
+ "condition": cond,
+ }
+ ) # TODO: support track properly
+
+ return output["audio"]
+
+ def on_test_epoch_start(self) -> None:
+ self.attach_fader(force_reattach=True)
+
+ def test_step(
+ self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
+ ) -> Any:
+ curr_test_prefix = f"test{dataloader_idx}"
+
+ # print(batch["audio"].keys())
+
+ if curr_test_prefix != self.test_prefix:
+ # print(f"Switching to test dataloader {dataloader_idx}")
+ if self.test_prefix is not None:
+ self._on_test_epoch_end()
+ self.test_prefix = curr_test_prefix
+
+ with torch.inference_mode():
+ _, output = self.predtest_step(batch, batch_idx, dataloader_idx)
+ # print(output)
+ self.update_metrics(batch, output, mode="test")
+
+ return output
+
+ def on_test_epoch_end(self) -> None:
+ self._on_test_epoch_end()
+
+ def _on_test_epoch_end(self) -> None:
+ metric_dict = self.compute_metrics(mode="test")
+ self.log_dict_with_prefix(
+ metric_dict, self.test_prefix, prog_bar=True, add_dataloader_idx=False
+ )
+ # self.logger.save()
+ # print(self.test_prefix, "Test metrics:", metric_dict)
+ self.reset_metrics()
+
+ def predict_step(
+ self,
+ batch: BatchedDataDict,
+ batch_idx: int = 0,
+ dataloader_idx: int = 0,
+ include_track_name: Optional[bool] = None,
+ get_no_vox_combinations: bool = True,
+ get_residual: bool = False,
+ treat_batch_as_channels: bool = False,
+ fs: Optional[int] = None,
+ ) -> Any:
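+        # Separate a batch, optionally add a residual stem and all sums of
+        # non-vocal stems, then write every stem as a wav file under
+        # predict_output_path (logging per-stem metrics when references exist).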
+ assert self.predict_output_path is not None
+
+ batch_size = batch["audio"]["mixture"].shape[0]
+
+ if include_track_name is None:
+ include_track_name = batch_size > 1
+
+ with torch.inference_mode():
+ batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
+ print("Pred test finished...")
+ torch.cuda.empty_cache()
+ metric_dict = {}
+
+ if get_residual:
+ mixture = batch["audio"]["mixture"]
+ extracted = sum([output["audio"][stem] for stem in output["audio"]])
+ residual = mixture - extracted
+ print(extracted.shape, mixture.shape, residual.shape)
+
+ output["audio"]["residual"] = residual
+
+ if get_no_vox_combinations:
+ no_vox_stems = [
+ stem for stem in output["audio"] if stem not in self._VOX_STEMS
+ ]
+ no_vox_combinations = chain.from_iterable(
+ combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1)
+ )
+
+ for combination in no_vox_combinations:
+ combination_ = list(combination)
+ output["audio"]["+".join(combination_)] = sum(
+ [output["audio"][stem] for stem in combination_]
+ )
+
+ if treat_batch_as_channels:
+ for stem in output["audio"]:
+ output["audio"][stem] = output["audio"][stem].reshape(
+ 1, -1, output["audio"][stem].shape[-1]
+ )
+ batch_size = 1
+
+ for b in range(batch_size):
+ print("!!", b)
+ for stem in output["audio"]:
+ print(f"Saving audio for {stem} to {self.predict_output_path}")
+ track_name = batch["track"][b].split("/")[-1]
+
+ if batch.get("audio", {}).get(stem, None) is not None:
+ self.test_metrics[stem].reset()
+ metrics = self.test_metrics[stem](
+ batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...]
+ )
+ snr = metrics["snr"]
+ sisnr = metrics["sisnr"]
+ sdr = metrics["sdr"]
+ metric_dict[stem] = metrics
+ print(
+ track_name,
+ f"snr={snr:2.2f} dB",
+ f"sisnr={sisnr:2.2f}",
+ f"sdr={sdr:2.2f} dB",
+ )
+ filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
+ else:
+ filename = f"{stem}.wav"
+
+ if include_track_name:
+ output_dir = os.path.join(self.predict_output_path, track_name)
+ else:
+ output_dir = self.predict_output_path
+
+ os.makedirs(output_dir, exist_ok=True)
+
+ if fs is None:
+ fs = self.fs
+
+ ta.save(
+ os.path.join(output_dir, filename),
+ output["audio"][stem][b, ...].cpu(),
+ fs,
+ )
+
+ return metric_dict
+
+ def get_stems(
+ self,
+ batch: BatchedDataDict,
+ batch_idx: int = 0,
+ dataloader_idx: int = 0,
+ include_track_name: Optional[bool] = None,
+ get_no_vox_combinations: bool = True,
+ get_residual: bool = False,
+ treat_batch_as_channels: bool = False,
+ fs: Optional[int] = None,
+ ) -> Any:
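+        # Same pipeline as predict_step, but the separated stems are returned as
+        # numpy arrays instead of being written to wav files.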
+ assert self.predict_output_path is not None
+
+ batch_size = batch["audio"]["mixture"].shape[0]
+
+ if include_track_name is None:
+ include_track_name = batch_size > 1
+
+ with torch.inference_mode():
+ batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
+ torch.cuda.empty_cache()
+ metric_dict = {}
+
+ if get_residual:
+ mixture = batch["audio"]["mixture"]
+ extracted = sum([output["audio"][stem] for stem in output["audio"]])
+ residual = mixture - extracted
+ # print(extracted.shape, mixture.shape, residual.shape)
+
+ output["audio"]["residual"] = residual
+
+ if get_no_vox_combinations:
+ no_vox_stems = [
+ stem for stem in output["audio"] if stem not in self._VOX_STEMS
+ ]
+ no_vox_combinations = chain.from_iterable(
+ combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1)
+ )
+
+ for combination in no_vox_combinations:
+ combination_ = list(combination)
+ output["audio"]["+".join(combination_)] = sum(
+ [output["audio"][stem] for stem in combination_]
+ )
+
+ if treat_batch_as_channels:
+ for stem in output["audio"]:
+ output["audio"][stem] = output["audio"][stem].reshape(
+ 1, -1, output["audio"][stem].shape[-1]
+ )
+ batch_size = 1
+
+ result = {}
+ for b in range(batch_size):
+ for stem in output["audio"]:
+ track_name = batch["track"][b].split("/")[-1]
+
+ if batch.get("audio", {}).get(stem, None) is not None:
+ self.test_metrics[stem].reset()
+ metrics = self.test_metrics[stem](
+ batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...]
+ )
+ snr = metrics["snr"]
+ sisnr = metrics["sisnr"]
+ sdr = metrics["sdr"]
+ metric_dict[stem] = metrics
+ print(
+ track_name,
+ f"snr={snr:2.2f} dB",
+ f"sisnr={sisnr:2.2f}",
+ f"sdr={sdr:2.2f} dB",
+ )
+ filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
+ else:
+ filename = f"{stem}.wav"
+
+ if include_track_name:
+ output_dir = os.path.join(self.predict_output_path, track_name)
+ else:
+ output_dir = self.predict_output_path
+
+ os.makedirs(output_dir, exist_ok=True)
+
+ if fs is None:
+ fs = self.fs
+
+ result[stem] = output["audio"][stem][b, ...].cpu().numpy()
+
+ return result
+
+ def load_state_dict(
+ self, state_dict: Mapping[str, Any], strict: bool = False
+ ) -> Any:
+
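+        # Always load non-strictly so checkpoints whose keys only partially match
+        # this module still restore the overlapping weights.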
+ return super().load_state_dict(state_dict, strict=False)
+
+ def set_predict_output_path(self, path: str) -> None:
+ self.predict_output_path = path
+ os.makedirs(self.predict_output_path, exist_ok=True)
+
+ self.attach_fader()
+
+ def attach_fader(self, force_reattach=False) -> None:
+ if self.fader is None or force_reattach:
+ self.fader = parse_fader_config(self.fader_config)
+ self.fader.to(self.device)
+
+ def log_dict_with_prefix(
+ self,
+ dict_: Dict[str, torch.Tensor],
+ prefix: str,
+ batch_size: Optional[int] = None,
+ **kwargs: Any,
+ ) -> None:
+ self.log_dict(
+ {f"{prefix}/{k}": v for k, v in dict_.items()},
+ batch_size=batch_size,
+ logger=True,
+ sync_dist=True,
+ **kwargs,
+ )
diff --git a/mvsepless/models/bandit/core/data/__init__.py b/mvsepless/models/bandit/core/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9d4d672bd3b6ad90a26e19ee6c26e02ee3be84c
--- /dev/null
+++ b/mvsepless/models/bandit/core/data/__init__.py
@@ -0,0 +1,2 @@
+from .dnr.datamodule import DivideAndRemasterDataModule
+from .musdb.datamodule import MUSDB18DataModule
diff --git a/mvsepless/models/bandit/core/data/_types.py b/mvsepless/models/bandit/core/data/_types.py
new file mode 100644
index 0000000000000000000000000000000000000000..65e4607a558e6b6a65ee68de883b69e282f8fcf4
--- /dev/null
+++ b/mvsepless/models/bandit/core/data/_types.py
@@ -0,0 +1,17 @@
+from typing import Dict, Sequence, TypedDict
+
+import torch
+
+AudioDict = Dict[str, torch.Tensor]
+
+DataDict = TypedDict("DataDict", {"audio": AudioDict, "track": str})
+
+BatchedDataDict = TypedDict(
+ "BatchedDataDict", {"audio": AudioDict, "track": Sequence[str]}
+)
+
+
+class DataDictWithLanguage(TypedDict):
+ audio: AudioDict
+ track: str
+ language: str
diff --git a/mvsepless/models/bandit/core/data/augmentation.py b/mvsepless/models/bandit/core/data/augmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc2ec18ef256cf266a81d7efae9a5fb902ae1d5a
--- /dev/null
+++ b/mvsepless/models/bandit/core/data/augmentation.py
@@ -0,0 +1,102 @@
+from abc import ABC
+from typing import Any, Dict, Union
+
+import torch
+import torch_audiomentations as tam
+from torch import nn
+
+from ._types import BatchedDataDict, DataDict
+
+
+class BaseAugmentor(nn.Module, ABC):
+ def forward(
+ self, item: Union[DataDict, BatchedDataDict]
+ ) -> Union[DataDict, BatchedDataDict]:
+ raise NotImplementedError
+
+
+class StemAugmentor(BaseAugmentor):
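+    # Applies torch_audiomentations transforms per stem. The special config keys
+    # "[common]" and "[default]" act as a transform applied to every stem and a
+    # fallback for stems without their own entry; the mixture is re-summed from
+    # the augmented stems afterwards.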
+ def __init__(
+ self,
+ audiomentations: Dict[str, Dict[str, Any]],
+ fix_clipping: bool = True,
+ scaler_margin: float = 0.5,
+ apply_both_default_and_common: bool = False,
+ ) -> None:
+ super().__init__()
+
+ augmentations = {}
+
+ self.has_default = "[default]" in audiomentations
+ self.has_common = "[common]" in audiomentations
+ self.apply_both_default_and_common = apply_both_default_and_common
+
+ for stem in audiomentations:
+ if audiomentations[stem]["name"] == "Compose":
+ augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
+ [
+ getattr(tam, aug["name"])(**aug["kwargs"])
+ for aug in audiomentations[stem]["kwargs"]["transforms"]
+ ],
+ **audiomentations[stem]["kwargs"]["kwargs"],
+ )
+ else:
+ augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
+ **audiomentations[stem]["kwargs"]
+ )
+
+ self.augmentations = nn.ModuleDict(augmentations)
+ self.fix_clipping = fix_clipping
+ self.scaler_margin = scaler_margin
+
+ def check_and_fix_clipping(
+ self, item: Union[DataDict, BatchedDataDict]
+ ) -> Union[DataDict, BatchedDataDict]:
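+        # If any stem exceeds |1.0|, rescale every stem by the same (slightly
+        # randomized) factor so the relative balance is preserved while clipping
+        # is avoided.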
+ max_abs = []
+
+ for stem in item["audio"]:
+ max_abs.append(item["audio"][stem].abs().max().item())
+
+ if max(max_abs) > 1.0:
+ scaler = 1.0 / (
+ max(max_abs)
+ + torch.rand((1,), device=item["audio"]["mixture"].device)
+ * self.scaler_margin
+ )
+
+ for stem in item["audio"]:
+ item["audio"][stem] *= scaler
+
+ return item
+
+ def forward(
+ self, item: Union[DataDict, BatchedDataDict]
+ ) -> Union[DataDict, BatchedDataDict]:
+
+ for stem in item["audio"]:
+ if stem == "mixture":
+ continue
+
+ if self.has_common:
+ item["audio"][stem] = self.augmentations["[common]"](
+ item["audio"][stem]
+ ).samples
+
+ if stem in self.augmentations:
+ item["audio"][stem] = self.augmentations[stem](
+ item["audio"][stem]
+ ).samples
+ elif self.has_default:
+ if not self.has_common or self.apply_both_default_and_common:
+ item["audio"][stem] = self.augmentations["[default]"](
+ item["audio"][stem]
+ ).samples
+
+ item["audio"]["mixture"] = sum(
+ [item["audio"][stem] for stem in item["audio"] if stem != "mixture"]
+ ) # type: ignore[call-overload, assignment]
+
+ if self.fix_clipping:
+ item = self.check_and_fix_clipping(item)
+
+ return item
diff --git a/mvsepless/models/bandit/core/data/augmented.py b/mvsepless/models/bandit/core/data/augmented.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c0524409bf99b009605989eba5d5f46f0560f2e
--- /dev/null
+++ b/mvsepless/models/bandit/core/data/augmented.py
@@ -0,0 +1,34 @@
+import warnings
+from typing import Dict, Optional, Union
+
+import torch
+from torch import nn
+from torch.utils import data
+
+
+class AugmentedDataset(data.Dataset):
+ def __init__(
+ self,
+ dataset: data.Dataset,
+ augmentation: nn.Module = nn.Identity(),
+ target_length: Optional[int] = None,
+ ) -> None:
+ warnings.warn(
+ "This class is no longer used. Attach augmentation to "
+ "the LightningSystem instead.",
+ DeprecationWarning,
+ )
+
+ self.dataset = dataset
+ self.augmentation = augmentation
+
+ self.ds_length: int = len(dataset) # type: ignore[arg-type]
+ self.length = target_length if target_length is not None else self.ds_length
+
+ def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str, torch.Tensor]]]:
+ item = self.dataset[index % self.ds_length]
+ item = self.augmentation(item)
+ return item
+
+ def __len__(self) -> int:
+ return self.length
diff --git a/mvsepless/models/bandit/core/data/base.py b/mvsepless/models/bandit/core/data/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c85194f613d2043a26406624da5cb83cb04033a
--- /dev/null
+++ b/mvsepless/models/bandit/core/data/base.py
@@ -0,0 +1,60 @@
+import os
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import pedalboard as pb
+import torch
+import torchaudio as ta
+from torch.utils import data
+
+from ._types import AudioDict, DataDict
+
+
+class BaseSourceSeparationDataset(data.Dataset, ABC):
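+    # Minimal source-separation dataset interface: subclasses implement
+    # get_stem() and get_identifier(); the mixture is either loaded alongside the
+    # stems or recomputed as their sum when recompute_mixture is set.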
+ def __init__(
+ self,
+ split: str,
+ stems: List[str],
+ files: List[str],
+ data_path: str,
+ fs: int,
+ npy_memmap: bool,
+ recompute_mixture: bool,
+ ):
+ self.split = split
+ self.stems = stems
+ self.stems_no_mixture = [s for s in stems if s != "mixture"]
+ self.files = files
+ self.data_path = data_path
+ self.fs = fs
+ self.npy_memmap = npy_memmap
+ self.recompute_mixture = recompute_mixture
+
+ @abstractmethod
+ def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
+ raise NotImplementedError
+
+ def _get_audio(self, stems, identifier: Dict[str, Any]):
+ audio = {}
+ for stem in stems:
+ audio[stem] = self.get_stem(stem=stem, identifier=identifier)
+
+ return audio
+
+ def get_audio(self, identifier: Dict[str, Any]) -> AudioDict:
+
+ if self.recompute_mixture:
+ audio = self._get_audio(self.stems_no_mixture, identifier=identifier)
+ audio["mixture"] = self.compute_mixture(audio)
+ return audio
+ else:
+ return self._get_audio(self.stems, identifier=identifier)
+
+ @abstractmethod
+ def get_identifier(self, index: int) -> Dict[str, Any]:
+ pass
+
+ def compute_mixture(self, audio: AudioDict) -> torch.Tensor:
+
+ return sum(audio[stem] for stem in audio if stem != "mixture")
diff --git a/mvsepless/models/bandit/core/data/dnr/__init__.py b/mvsepless/models/bandit/core/data/dnr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mvsepless/models/bandit/core/data/dnr/datamodule.py b/mvsepless/models/bandit/core/data/dnr/datamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..2971d419d433e335668f9e52cf54afda55a48f88
--- /dev/null
+++ b/mvsepless/models/bandit/core/data/dnr/datamodule.py
@@ -0,0 +1,68 @@
+import os
+from typing import Mapping, Optional
+
+import pytorch_lightning as pl
+
+from .dataset import (
+ DivideAndRemasterDataset,
+ DivideAndRemasterDeterministicChunkDataset,
+ DivideAndRemasterRandomChunkDataset,
+ DivideAndRemasterRandomChunkDatasetWithSpeechReverb,
+)
+
+
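+# Factory returning a LightningDataModule over the DnR v2 dataset: random chunks
+# for training (optionally with speech reverb), deterministic chunks for
+# validation, and full tracks for test/predict.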
+def DivideAndRemasterDataModule(
+ data_root: str = "$DATA_ROOT/DnR/v2",
+ batch_size: int = 2,
+ num_workers: int = 8,
+ train_kwargs: Optional[Mapping] = None,
+ val_kwargs: Optional[Mapping] = None,
+ test_kwargs: Optional[Mapping] = None,
+ datamodule_kwargs: Optional[Mapping] = None,
+ use_speech_reverb: bool = False,
+ # augmentor=None
+) -> pl.LightningDataModule:
+ if train_kwargs is None:
+ train_kwargs = {}
+
+ if val_kwargs is None:
+ val_kwargs = {}
+
+ if test_kwargs is None:
+ test_kwargs = {}
+
+ if datamodule_kwargs is None:
+ datamodule_kwargs = {}
+
+ if num_workers is None:
+ num_workers = os.cpu_count()
+
+ if num_workers is None:
+ num_workers = 32
+
+ num_workers = min(num_workers, 64)
+
+ if use_speech_reverb:
+ train_cls = DivideAndRemasterRandomChunkDatasetWithSpeechReverb
+ else:
+ train_cls = DivideAndRemasterRandomChunkDataset
+
+ train_dataset = train_cls(data_root, "train", **train_kwargs)
+
+ # if augmentor is not None:
+ # train_dataset = AugmentedDataset(train_dataset, augmentor)
+
+ datamodule = pl.LightningDataModule.from_datasets(
+ train_dataset=train_dataset,
+ val_dataset=DivideAndRemasterDeterministicChunkDataset(
+ data_root, "val", **val_kwargs
+ ),
+ test_dataset=DivideAndRemasterDataset(data_root, "test", **test_kwargs),
+ batch_size=batch_size,
+ num_workers=num_workers,
+ **datamodule_kwargs
+ )
+
+ datamodule.predict_dataloader = datamodule.test_dataloader # type: ignore[method-assign]
+
+ return datamodule
diff --git a/mvsepless/models/bandit/core/data/dnr/dataset.py b/mvsepless/models/bandit/core/data/dnr/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ae6dba590312c83b61addec26284c8a9979c1a0
--- /dev/null
+++ b/mvsepless/models/bandit/core/data/dnr/dataset.py
@@ -0,0 +1,366 @@
+import os
+from abc import ABC
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import pedalboard as pb
+import torch
+import torchaudio as ta
+from torch.utils import data
+
+from .._types import AudioDict, DataDict
+from ..base import BaseSourceSeparationDataset
+
+
+class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC):
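+    # Base dataset for the Divide and Remaster (DnR) v2 layout: each track folder
+    # contains mix/speech/music/sfx audio; "mne" is a virtual music+effects stem
+    # computed on the fly.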
+ ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"]
+ STEM_NAME_MAP = {
+ "mixture": "mix",
+ "speech": "speech",
+ "music": "music",
+ "effects": "sfx",
+ }
+ SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"}
+
+ FULL_TRACK_LENGTH_SECOND = 60
+ FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100
+
+ def __init__(
+ self,
+ split: str,
+ stems: List[str],
+ files: List[str],
+ data_path: str,
+ fs: int = 44100,
+ npy_memmap: bool = True,
+ recompute_mixture: bool = False,
+ ) -> None:
+ super().__init__(
+ split=split,
+ stems=stems,
+ files=files,
+ data_path=data_path,
+ fs=fs,
+ npy_memmap=npy_memmap,
+ recompute_mixture=recompute_mixture,
+ )
+
+ def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
+
+ if stem == "mne":
+ return self.get_stem(stem="music", identifier=identifier) + self.get_stem(
+ stem="effects", identifier=identifier
+ )
+
+ track = identifier["track"]
+ path = os.path.join(self.data_path, track)
+
+ if self.npy_memmap:
+ audio = np.load(
+ os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"), mmap_mode="r"
+ )
+ else:
+ # noinspection PyUnresolvedReferences
+ audio, _ = ta.load(os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav"))
+
+ return audio
+
+ def get_identifier(self, index):
+ return dict(track=self.files[index])
+
+ def __getitem__(self, index: int) -> DataDict:
+ identifier = self.get_identifier(index)
+ audio = self.get_audio(identifier)
+
+ return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
+
+
+class DivideAndRemasterDataset(DivideAndRemasterBaseDataset):
+ def __init__(
+ self,
+ data_root: str,
+ split: str,
+ stems: Optional[List[str]] = None,
+ fs: int = 44100,
+ npy_memmap: bool = True,
+ ) -> None:
+
+ if stems is None:
+ stems = self.ALLOWED_STEMS
+ self.stems = stems
+
+ data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
+
+ files = sorted(os.listdir(data_path))
+ files = [
+ f
+ for f in files
+ if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
+ ]
+ # pprint(list(enumerate(files)))
+ if split == "train":
+ assert len(files) == 3406, len(files)
+ elif split == "val":
+ assert len(files) == 487, len(files)
+ elif split == "test":
+ assert len(files) == 973, len(files)
+
+ self.n_tracks = len(files)
+
+ super().__init__(
+ data_path=data_path,
+ split=split,
+ stems=stems,
+ files=files,
+ fs=fs,
+ npy_memmap=npy_memmap,
+ )
+
+ def __len__(self) -> int:
+ return self.n_tracks
+
+
+class DivideAndRemasterRandomChunkDataset(DivideAndRemasterBaseDataset):
+ def __init__(
+ self,
+ data_root: str,
+ split: str,
+ target_length: int,
+ chunk_size_second: float,
+ stems: Optional[List[str]] = None,
+ fs: int = 44100,
+ npy_memmap: bool = True,
+ ) -> None:
+
+ if stems is None:
+ stems = self.ALLOWED_STEMS
+ self.stems = stems
+
+ data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
+
+ files = sorted(os.listdir(data_path))
+ files = [
+ f
+ for f in files
+ if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
+ ]
+
+ if split == "train":
+ assert len(files) == 3406, len(files)
+ elif split == "val":
+ assert len(files) == 487, len(files)
+ elif split == "test":
+ assert len(files) == 973, len(files)
+
+ self.n_tracks = len(files)
+
+ self.target_length = target_length
+ self.chunk_size = int(chunk_size_second * fs)
+
+ super().__init__(
+ data_path=data_path,
+ split=split,
+ stems=stems,
+ files=files,
+ fs=fs,
+ npy_memmap=npy_memmap,
+ )
+
+ def __len__(self) -> int:
+ return self.target_length
+
+ def get_identifier(self, index):
+ return super().get_identifier(index % self.n_tracks)
+
+ def get_stem(
+ self,
+ *,
+ stem: str,
+ identifier: Dict[str, Any],
+ chunk_here: bool = False,
+ ) -> torch.Tensor:
+
+ stem = super().get_stem(stem=stem, identifier=identifier)
+
+ if chunk_here:
+ start = np.random.randint(
+ 0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
+ )
+ end = start + self.chunk_size
+
+ stem = stem[:, start:end]
+
+ return stem
+
+ def __getitem__(self, index: int) -> DataDict:
+ identifier = self.get_identifier(index)
+ # self.index_lock = index
+ audio = self.get_audio(identifier)
+ # self.index_lock = None
+
+ start = np.random.randint(0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size)
+ end = start + self.chunk_size
+
+ audio = {k: v[:, start:end] for k, v in audio.items()}
+
+ return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
+
+
+class DivideAndRemasterDeterministicChunkDataset(DivideAndRemasterBaseDataset):
+ def __init__(
+ self,
+ data_root: str,
+ split: str,
+ chunk_size_second: float,
+ hop_size_second: float,
+ stems: Optional[List[str]] = None,
+ fs: int = 44100,
+ npy_memmap: bool = True,
+ ) -> None:
+
+ if stems is None:
+ stems = self.ALLOWED_STEMS
+ self.stems = stems
+
+ data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
+
+ files = sorted(os.listdir(data_path))
+ files = [
+ f
+ for f in files
+ if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
+ ]
+ # pprint(list(enumerate(files)))
+ if split == "train":
+ assert len(files) == 3406, len(files)
+ elif split == "val":
+ assert len(files) == 487, len(files)
+ elif split == "test":
+ assert len(files) == 973, len(files)
+
+ self.n_tracks = len(files)
+
+ self.chunk_size = int(chunk_size_second * fs)
+ self.hop_size = int(hop_size_second * fs)
+ self.n_chunks_per_track = int(
+ (self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second
+ )
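+        # Each track contributes n_chunks_per_track windows of chunk_size spaced hop_size
+        # apart, so the dataset length is deterministic and covers every track uniformly.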
+
+ self.length = self.n_tracks * self.n_chunks_per_track
+
+ super().__init__(
+ data_path=data_path,
+ split=split,
+ stems=stems,
+ files=files,
+ fs=fs,
+ npy_memmap=npy_memmap,
+ )
+
+ def get_identifier(self, index):
+ return super().get_identifier(index % self.n_tracks)
+
+ def __len__(self) -> int:
+ return self.length
+
+ def __getitem__(self, item: int) -> DataDict:
+
+ index = item % self.n_tracks
+ chunk = item // self.n_tracks
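+        # Flat indices enumerate tracks first and chunk offsets second, so consecutive
+        # items come from different tracks rather than adjacent windows of the same track.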
+
+ data_ = super().__getitem__(index)
+
+ audio = data_["audio"]
+
+ start = chunk * self.hop_size
+ end = start + self.chunk_size
+
+ for stem in self.stems:
+ data_["audio"][stem] = audio[stem][:, start:end]
+
+ return data_
+
+
+class DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
+ DivideAndRemasterRandomChunkDataset
+):
+ def __init__(
+ self,
+ data_root: str,
+ split: str,
+ target_length: int,
+ chunk_size_second: float,
+ stems: Optional[List[str]] = None,
+ fs: int = 44100,
+ npy_memmap: bool = True,
+ ) -> None:
+
+ if stems is None:
+ stems = self.ALLOWED_STEMS
+
+ stems_no_mixture = [s for s in stems if s != "mixture"]
+
+ super().__init__(
+ data_root=data_root,
+ split=split,
+ target_length=target_length,
+ chunk_size_second=chunk_size_second,
+ stems=stems_no_mixture,
+ fs=fs,
+ npy_memmap=npy_memmap,
+ )
+
+ self.stems = stems
+ self.stems_no_mixture = stems_no_mixture
+
+ def __getitem__(self, index: int) -> DataDict:
+
+ data_ = super().__getitem__(index)
+
+ dry = data_["audio"]["speech"][:]
+ n_samples = dry.shape[-1]
+
+ wet_level = np.random.rand()
+
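+        # A random room reverb is applied to the speech stem only; wet and dry levels are
+        # complementary, and the mixture is re-summed below from the processed stems.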
+ speech = pb.Reverb(
+ room_size=np.random.rand(),
+ damping=np.random.rand(),
+ wet_level=wet_level,
+ dry_level=(1 - wet_level),
+ width=np.random.rand(),
+ ).process(dry, self.fs, buffer_size=8192 * 4)[..., :n_samples]
+
+ data_["audio"]["speech"] = speech
+
+ data_["audio"]["mixture"] = sum(
+ [data_["audio"][s] for s in self.stems_no_mixture]
+ )
+
+ return data_
+
+ def __len__(self) -> int:
+ return super().__len__()
+
+
+if __name__ == "__main__":
+
+ from pprint import pprint
+ from tqdm.auto import tqdm
+
+ for split_ in ["train", "val", "test"]:
+ ds = DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
+ data_root="$DATA_ROOT/DnR/v2np",
+ split=split_,
+ target_length=100,
+ chunk_size_second=6.0,
+ )
+
+ print(split_, len(ds))
+
+ for track_ in tqdm(ds): # type: ignore
+ pprint(track_)
+ track_["audio"] = {k: v.shape for k, v in track_["audio"].items()}
+ pprint(track_)
+ # break
+
+ break
diff --git a/mvsepless/models/bandit/core/data/dnr/preprocess.py b/mvsepless/models/bandit/core/data/dnr/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..18d68b18fbe963647df1253190625ea639035572
--- /dev/null
+++ b/mvsepless/models/bandit/core/data/dnr/preprocess.py
@@ -0,0 +1,51 @@
+import glob
+import os
+from typing import Tuple
+
+import numpy as np
+import torchaudio as ta
+from tqdm.contrib.concurrent import process_map
+
+
+def process_one(inputs: Tuple[str, str, int]) -> None:
+ infile, outfile, target_fs = inputs
+
+    out_dir = os.path.dirname(outfile)
+    os.makedirs(out_dir, exist_ok=True)
+
+ data, fs = ta.load(infile)
+
+ if fs != target_fs:
+ data = ta.functional.resample(
+ data, fs, target_fs, resampling_method="sinc_interp_kaiser"
+ )
+ fs = target_fs
+
+ data = data.numpy()
+ data = data.astype(np.float32)
+
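+    # Skip re-writing when an identical array is already on disk, making the
+    # preprocessing step safely resumable.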
+ if os.path.exists(outfile):
+ data_ = np.load(outfile)
+ if np.allclose(data, data_):
+ return
+
+ np.save(outfile, data)
+
+
+def preprocess(data_path: str, output_path: str, fs: int) -> None:
+ files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
+ print(files)
+ outfiles = [
+ f.replace(data_path, output_path).replace(".wav", ".npy") for f in files
+ ]
+
+ os.makedirs(output_path, exist_ok=True)
+ inputs = list(zip(files, outfiles, [fs] * len(files)))
+
+ process_map(process_one, inputs, chunksize=32)
+
+
+if __name__ == "__main__":
+ import fire
+
+ fire.Fire()
diff --git a/mvsepless/models/bandit/core/data/musdb/__init__.py b/mvsepless/models/bandit/core/data/musdb/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mvsepless/models/bandit/core/data/musdb/datamodule.py b/mvsepless/models/bandit/core/data/musdb/datamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ba2cb25e3fd7bbaf120b0d1cde90adc26d9d1df
--- /dev/null
+++ b/mvsepless/models/bandit/core/data/musdb/datamodule.py
@@ -0,0 +1,75 @@
+import os.path
+from typing import Mapping, Optional
+
+import pytorch_lightning as pl
+
+from .dataset import (
+ MUSDB18BaseDataset,
+ MUSDB18FullTrackDataset,
+ MUSDB18SadDataset,
+ MUSDB18SadOnTheFlyAugmentedDataset,
+)
+
+
+def MUSDB18DataModule(
+ data_root: str = "$DATA_ROOT/MUSDB18/HQ",
+ target_stem: str = "vocals",
+ batch_size: int = 2,
+ num_workers: int = 8,
+ train_kwargs: Optional[Mapping] = None,
+ val_kwargs: Optional[Mapping] = None,
+ test_kwargs: Optional[Mapping] = None,
+ datamodule_kwargs: Optional[Mapping] = None,
+ use_on_the_fly: bool = True,
+ npy_memmap: bool = True,
+) -> pl.LightningDataModule:
+ if train_kwargs is None:
+ train_kwargs = {}
+
+ if val_kwargs is None:
+ val_kwargs = {}
+
+ if test_kwargs is None:
+ test_kwargs = {}
+
+ if datamodule_kwargs is None:
+ datamodule_kwargs = {}
+
+ train_dataset: MUSDB18BaseDataset
+
+ if use_on_the_fly:
+ train_dataset = MUSDB18SadOnTheFlyAugmentedDataset(
+ data_root=os.path.join(data_root, "saded-np"),
+ split="train",
+ target_stem=target_stem,
+ **train_kwargs
+ )
+ else:
+ train_dataset = MUSDB18SadDataset(
+ data_root=os.path.join(data_root, "saded-np"),
+ split="train",
+ target_stem=target_stem,
+ **train_kwargs
+ )
+
+ datamodule = pl.LightningDataModule.from_datasets(
+ train_dataset=train_dataset,
+ val_dataset=MUSDB18SadDataset(
+ data_root=os.path.join(data_root, "saded-np"),
+ split="val",
+ target_stem=target_stem,
+ **val_kwargs
+ ),
+ test_dataset=MUSDB18FullTrackDataset(
+ data_root=os.path.join(data_root, "canonical"), split="test", **test_kwargs
+ ),
+ batch_size=batch_size,
+ num_workers=num_workers,
+ **datamodule_kwargs
+ )
+
+ datamodule.predict_dataloader = ( # type: ignore[method-assign]
+ datamodule.test_dataloader
+ )
+
+ return datamodule
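+
+
+# Usage sketch (the data_root path is an assumption for illustration):
+#     dm = MUSDB18DataModule(data_root="/data/MUSDB18/HQ", target_stem="vocals")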
diff --git a/mvsepless/models/bandit/core/data/musdb/dataset.py b/mvsepless/models/bandit/core/data/musdb/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..82eb33bb32a6e2172ace1bb8ae1a9649041af5bf
--- /dev/null
+++ b/mvsepless/models/bandit/core/data/musdb/dataset.py
@@ -0,0 +1,273 @@
+import os
+from abc import ABC
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+import torchaudio as ta
+from torch.utils import data
+
+from .._types import AudioDict, DataDict
+from ..base import BaseSourceSeparationDataset
+
+
+class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC):
+
+ ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"]
+
+ def __init__(
+ self,
+ split: str,
+ stems: List[str],
+ files: List[str],
+ data_path: str,
+ fs: int = 44100,
+ npy_memmap=False,
+ ) -> None:
+ super().__init__(
+ split=split,
+ stems=stems,
+ files=files,
+ data_path=data_path,
+ fs=fs,
+ npy_memmap=npy_memmap,
+ recompute_mixture=False,
+ )
+
+ def get_stem(self, *, stem: str, identifier) -> torch.Tensor:
+ track = identifier["track"]
+ path = os.path.join(self.data_path, track)
+ # noinspection PyUnresolvedReferences
+
+ if self.npy_memmap:
+ audio = np.load(os.path.join(path, f"{stem}.wav.npy"), mmap_mode="r")
+ else:
+ audio, _ = ta.load(os.path.join(path, f"{stem}.wav"))
+
+ return audio
+
+ def get_identifier(self, index):
+ return dict(track=self.files[index])
+
+ def __getitem__(self, index: int) -> DataDict:
+ identifier = self.get_identifier(index)
+ audio = self.get_audio(identifier)
+
+ return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
+
+
+class MUSDB18FullTrackDataset(MUSDB18BaseDataset):
+
+ N_TRAIN_TRACKS = 100
+ N_TEST_TRACKS = 50
+ VALIDATION_FILES = [
+ "Actions - One Minute Smile",
+ "Clara Berry And Wooldog - Waltz For My Victims",
+ "Johnny Lokke - Promises & Lies",
+ "Patrick Talbot - A Reason To Leave",
+ "Triviul - Angelsaint",
+ "Alexander Ross - Goodbye Bolero",
+ "Fergessen - Nos Palpitants",
+ "Leaf - Summerghost",
+ "Skelpolu - Human Mistakes",
+ "Young Griffo - Pennies",
+ "ANiMAL - Rockshow",
+ "James May - On The Line",
+ "Meaxic - Take A Step",
+ "Traffic Experiment - Sirens",
+ ]
+
+ def __init__(
+ self, data_root: str, split: str, stems: Optional[List[str]] = None
+ ) -> None:
+
+ if stems is None:
+ stems = self.ALLOWED_STEMS
+ self.stems = stems
+
+ if split == "test":
+ subset = "test"
+ elif split in ["train", "val"]:
+ subset = "train"
+ else:
+            raise NameError(f"Unknown split: {split}")
+
+ data_path = os.path.join(data_root, subset)
+
+ files = sorted(os.listdir(data_path))
+ files = [f for f in files if not f.startswith(".")]
+ # pprint(list(enumerate(files)))
+ if subset == "train":
+ assert len(files) == 100, len(files)
+ if split == "train":
+ files = [f for f in files if f not in self.VALIDATION_FILES]
+ assert len(files) == 100 - len(self.VALIDATION_FILES)
+ else:
+ files = [f for f in files if f in self.VALIDATION_FILES]
+ assert len(files) == len(self.VALIDATION_FILES)
+ else:
+ split = "test"
+ assert len(files) == 50
+
+ self.n_tracks = len(files)
+
+ super().__init__(data_path=data_path, split=split, stems=stems, files=files)
+
+ def __len__(self) -> int:
+ return self.n_tracks
+
+
+class MUSDB18SadDataset(MUSDB18BaseDataset):
+ def __init__(
+ self,
+ data_root: str,
+ split: str,
+ target_stem: str,
+ stems: Optional[List[str]] = None,
+ target_length: Optional[int] = None,
+ npy_memmap=False,
+ ) -> None:
+
+ if stems is None:
+ stems = self.ALLOWED_STEMS
+
+ data_path = os.path.join(data_root, target_stem, split)
+
+ files = sorted(os.listdir(data_path))
+ files = [f for f in files if not f.startswith(".")]
+
+ super().__init__(
+ data_path=data_path,
+ split=split,
+ stems=stems,
+ files=files,
+ npy_memmap=npy_memmap,
+ )
+ self.n_segments = len(files)
+ self.target_stem = target_stem
+ self.target_length = (
+ target_length if target_length is not None else self.n_segments
+ )
+
+ def __len__(self) -> int:
+ return self.target_length
+
+ def __getitem__(self, index: int) -> DataDict:
+
+ index = index % self.n_segments
+
+ return super().__getitem__(index)
+
+ def get_identifier(self, index):
+ return super().get_identifier(index % self.n_segments)
+
+
+class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset):
+ def __init__(
+ self,
+ data_root: str,
+ split: str,
+ target_stem: str,
+ stems: Optional[List[str]] = None,
+ target_length: int = 20000,
+ apply_probability: Optional[float] = None,
+ chunk_size_second: float = 3.0,
+ random_scale_range_db: Tuple[float, float] = (-10, 10),
+ drop_probability: float = 0.1,
+ rescale: bool = True,
+ ) -> None:
+ super().__init__(data_root, split, target_stem, stems)
+
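+        # By default, apply_probability is chosen so that the expected number of
+        # unaugmented draws per epoch matches the number of real segments; all remaining
+        # draws get randomly recombined stems.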
+ if apply_probability is None:
+ apply_probability = (target_length - self.n_segments) / target_length
+
+ self.apply_probability = apply_probability
+ self.drop_probability = drop_probability
+ self.chunk_size_second = chunk_size_second
+ self.random_scale_range_db = random_scale_range_db
+ self.rescale = rescale
+
+ self.chunk_size_sample = int(self.chunk_size_second * self.fs)
+ self.target_length = target_length
+
+ def __len__(self) -> int:
+ return self.target_length
+
+ def __getitem__(self, index: int) -> DataDict:
+
+ index = index % self.n_segments
+
+ # if np.random.rand() > self.apply_probability:
+ # return super().__getitem__(index)
+
+ audio = {}
+ identifier = self.get_identifier(index)
+
+ # assert self.target_stem in self.stems_no_mixture
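+        # The target stem always comes from the indexed segment; with probability
+        # apply_probability each other stem is swapped for one from a random segment,
+        # creating new mixtures on the fly.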
+ for stem in self.stems_no_mixture:
+ if stem == self.target_stem:
+ identifier_ = identifier
+ else:
+ if np.random.rand() < self.apply_probability:
+ index_ = np.random.randint(self.n_segments)
+ identifier_ = self.get_identifier(index_)
+ else:
+ identifier_ = identifier
+
+ audio[stem] = self.get_stem(stem=stem, identifier=identifier_)
+
+ # if stem == self.target_stem:
+
+ if self.chunk_size_sample < audio[stem].shape[-1]:
+ chunk_start = np.random.randint(
+ audio[stem].shape[-1] - self.chunk_size_sample
+ )
+ else:
+ chunk_start = 0
+
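+            # A random chunk of the stem is either muted entirely (drop) or rescaled by a
+            # random gain drawn in dB and converted to linear: scale = 10 ** (dB / 20).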
+ if np.random.rand() < self.drop_probability:
+ # db_scale = "-inf"
+ linear_scale = 0.0
+ else:
+ db_scale = np.random.uniform(*self.random_scale_range_db)
+ linear_scale = np.power(10, db_scale / 20)
+ # db_scale = f"{db_scale:+2.1f}"
+ # print(linear_scale)
+ audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample] = (
+ linear_scale
+ * audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample]
+ )
+
+ audio["mixture"] = self.compute_mixture(audio)
+
+ if self.rescale:
+ max_abs_val = max(
+ [torch.max(torch.abs(audio[stem])) for stem in self.stems]
+ ) # type: ignore[type-var]
+ if max_abs_val > 1:
+ audio = {k: v / max_abs_val for k, v in audio.items()}
+
+ track = identifier["track"]
+
+ return {"audio": audio, "track": f"{self.split}/{track}"}
+
+
+# if __name__ == "__main__":
+#
+# from pprint import pprint
+# from tqdm.auto import tqdm
+#
+# for split_ in ["train", "val", "test"]:
+# ds = MUSDB18SadOnTheFlyAugmentedDataset(
+# data_root="$DATA_ROOT/MUSDB18/HQ/saded",
+# split=split_,
+# target_stem="vocals"
+# )
+#
+# print(split_, len(ds))
+#
+# for track_ in tqdm(ds):
+# track_["audio"] = {
+# k: v.shape for k, v in track_["audio"].items()
+# }
+# pprint(track_)
diff --git a/mvsepless/models/bandit/core/data/musdb/preprocess.py b/mvsepless/models/bandit/core/data/musdb/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad95caeb1ebcbac261ca10cd333896f54bad927b
--- /dev/null
+++ b/mvsepless/models/bandit/core/data/musdb/preprocess.py
@@ -0,0 +1,226 @@
+import glob
+import os
+
+import numpy as np
+import torch
+import torchaudio as ta
+from torch import nn
+from torch.nn import functional as F
+from tqdm.auto import tqdm
+from tqdm.contrib.concurrent import process_map
+
+from .._types import DataDict
+from .dataset import MUSDB18FullTrackDataset
+import pyloudnorm as pyln
+
+
+class SourceActivityDetector(nn.Module):
+ def __init__(
+ self,
+ analysis_stem: str,
+ output_path: str,
+ fs: int = 44100,
+ segment_length_second: float = 6.0,
+ hop_length_second: float = 3.0,
+ n_chunks: int = 10,
+ chunk_epsilon: float = 1e-5,
+ energy_threshold_quantile: float = 0.15,
+ segment_epsilon: float = 1e-3,
+ salient_proportion_threshold: float = 0.5,
+ target_lufs: float = -24,
+ ) -> None:
+ super().__init__()
+
+ self.fs = fs
+ self.segment_length = int(segment_length_second * self.fs)
+ self.hop_length = int(hop_length_second * self.fs)
+ self.n_chunks = n_chunks
+ assert self.segment_length % self.n_chunks == 0
+ self.chunk_size = self.segment_length // self.n_chunks
+ self.chunk_epsilon = chunk_epsilon
+ self.energy_threshold_quantile = energy_threshold_quantile
+ self.segment_epsilon = segment_epsilon
+ self.salient_proportion_threshold = salient_proportion_threshold
+ self.analysis_stem = analysis_stem
+
+ self.meter = pyln.Meter(self.fs)
+ self.target_lufs = target_lufs
+
+ self.output_path = output_path
+
+ def forward(self, data: DataDict) -> None:
+
+ stem_ = self.analysis_stem if (self.analysis_stem != "none") else "mixture"
+
+ x = data["audio"][stem_]
+
+ xnp = x.numpy()
+ loudness = self.meter.integrated_loudness(xnp.T)
+
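+        # Every stem is rescaled by the same factor (derived from the analysis stem's
+        # integrated loudness), so the target LUFS is reached without changing the
+        # balance between stems.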
+ for stem in data["audio"]:
+ s = data["audio"][stem]
+ s = pyln.normalize.loudness(s.numpy().T, loudness, self.target_lufs).T
+ s = torch.as_tensor(s)
+ data["audio"][stem] = s
+
+ if x.ndim == 3:
+ assert x.shape[0] == 1
+ x = x[0]
+
+ n_chan, n_samples = x.shape
+
+ n_segments = (
+ int(np.ceil((n_samples - self.segment_length) / self.hop_length)) + 1
+ )
+
+ segments = torch.zeros((n_segments, n_chan, self.segment_length))
+ for i in range(n_segments):
+ start = i * self.hop_length
+ end = start + self.segment_length
+ end = min(end, n_samples)
+
+ xseg = x[:, start:end]
+
+ if end - start < self.segment_length:
+ xseg = F.pad(
+ xseg, pad=(0, self.segment_length - (end - start)), value=torch.nan
+ )
+
+ segments[i, :, :] = xseg
+
+ chunks = segments.reshape((n_segments, n_chan, self.n_chunks, self.chunk_size))
+
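+        # A segment counts as salient when more than salient_proportion_threshold of its
+        # chunks exceed an energy threshold (a quantile of chunk energies, floored at
+        # segment_epsilon); only salient segments are exported below.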
+ if self.analysis_stem != "none":
+ chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3))
+ chunk_energies = torch.nan_to_num(chunk_energies, nan=0)
+ chunk_energies[chunk_energies == 0] = self.chunk_epsilon
+
+ energy_threshold = torch.nanquantile(
+ chunk_energies, q=self.energy_threshold_quantile
+ )
+
+ if energy_threshold < self.segment_epsilon:
+ energy_threshold = self.segment_epsilon # type: ignore[assignment]
+
+ chunks_above_threshold = chunk_energies > energy_threshold
+ n_chunks_above_threshold = torch.mean(
+ chunks_above_threshold.to(torch.float), dim=-1
+ )
+
+ segment_above_threshold = (
+ n_chunks_above_threshold > self.salient_proportion_threshold
+ )
+
+ if torch.sum(segment_above_threshold) == 0:
+ return
+
+ else:
+ segment_above_threshold = torch.ones((n_segments,))
+
+ for i in range(n_segments):
+ if not segment_above_threshold[i]:
+ continue
+
+ outpath = os.path.join(
+ self.output_path,
+ self.analysis_stem,
+ f"{data['track']} - {self.analysis_stem}{i:03d}",
+ )
+ os.makedirs(outpath, exist_ok=True)
+
+ for stem in data["audio"]:
+ if stem == self.analysis_stem:
+ segment = torch.nan_to_num(segments[i, :, :], nan=0)
+ else:
+ start = i * self.hop_length
+ end = start + self.segment_length
+ end = min(n_samples, end)
+
+ segment = data["audio"][stem][:, start:end]
+
+ if end - start < self.segment_length:
+ segment = F.pad(
+ segment, (0, self.segment_length - (end - start))
+ )
+
+ assert segment.shape[-1] == self.segment_length, segment.shape
+
+ # ta.save(os.path.join(outpath, f"{stem}.wav"), segment, self.fs)
+
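+                # np.save appends ".npy", so the file is written as "{stem}.wav.npy",
+                # matching the name expected by MUSDB18BaseDataset when npy_memmap is on.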
+ np.save(os.path.join(outpath, f"{stem}.wav"), segment)
+
+
+def preprocess(
+ analysis_stem: str,
+ output_path: str = "/data/MUSDB18/HQ/saded-np",
+ fs: int = 44100,
+ segment_length_second: float = 6.0,
+ hop_length_second: float = 3.0,
+ n_chunks: int = 10,
+ chunk_epsilon: float = 1e-5,
+ energy_threshold_quantile: float = 0.15,
+ segment_epsilon: float = 1e-3,
+ salient_proportion_threshold: float = 0.5,
+) -> None:
+
+ sad = SourceActivityDetector(
+ analysis_stem=analysis_stem,
+ output_path=output_path,
+ fs=fs,
+ segment_length_second=segment_length_second,
+ hop_length_second=hop_length_second,
+ n_chunks=n_chunks,
+ chunk_epsilon=chunk_epsilon,
+ energy_threshold_quantile=energy_threshold_quantile,
+ segment_epsilon=segment_epsilon,
+ salient_proportion_threshold=salient_proportion_threshold,
+ )
+
+ for split in ["train", "val", "test"]:
+ ds = MUSDB18FullTrackDataset(
+ data_root="/data/MUSDB18/HQ/canonical",
+ split=split,
+ )
+
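+        # Full tracks are buffered and handed to worker processes in batches of 32 so
+        # memory use stays bounded while the source activity detector runs in parallel.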
+ tracks = []
+ for i, track in enumerate(tqdm(ds, total=len(ds))):
+ if i % 32 == 0 and tracks:
+ process_map(sad, tracks, max_workers=8)
+ tracks = []
+ tracks.append(track)
+ process_map(sad, tracks, max_workers=8)
+
+
+def loudness_norm_one(inputs):
+ infile, outfile, target_lufs = inputs
+
+ audio, fs = ta.load(infile)
+ audio = audio.mean(dim=0, keepdim=True).numpy().T
+
+ meter = pyln.Meter(fs)
+ loudness = meter.integrated_loudness(audio)
+ audio = pyln.normalize.loudness(audio, loudness, target_lufs)
+
+ os.makedirs(os.path.dirname(outfile), exist_ok=True)
+ np.save(outfile, audio.T)
+
+
+def loudness_norm(
+ data_path: str,
+ # output_path: str,
+ target_lufs=-17.0,
+):
+ files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
+
+ outfiles = [f.replace(".wav", ".npy").replace("saded", "saded-np") for f in files]
+
+ files = [(f, o, target_lufs) for f, o in zip(files, outfiles)]
+
+ process_map(loudness_norm_one, files, chunksize=2)
+
+
+if __name__ == "__main__":
+
+ from tqdm.auto import tqdm
+ import fire
+
+ fire.Fire()
diff --git a/mvsepless/models/bandit/core/data/musdb/validation.yaml b/mvsepless/models/bandit/core/data/musdb/validation.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2f8752478d285d1d13d5e842225af1de95cae57a
--- /dev/null
+++ b/mvsepless/models/bandit/core/data/musdb/validation.yaml
@@ -0,0 +1,15 @@
+validation:
+ - 'Actions - One Minute Smile'
+ - 'Clara Berry And Wooldog - Waltz For My Victims'
+ - 'Johnny Lokke - Promises & Lies'
+ - 'Patrick Talbot - A Reason To Leave'
+ - 'Triviul - Angelsaint'
+ - 'Alexander Ross - Goodbye Bolero'
+ - 'Fergessen - Nos Palpitants'
+ - 'Leaf - Summerghost'
+ - 'Skelpolu - Human Mistakes'
+ - 'Young Griffo - Pennies'
+ - 'ANiMAL - Rockshow'
+ - 'James May - On The Line'
+ - 'Meaxic - Take A Step'
+ - 'Traffic Experiment - Sirens'
\ No newline at end of file
diff --git a/mvsepless/models/bandit/core/loss/__init__.py b/mvsepless/models/bandit/core/loss/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..993be521fa7ab8f06a2a012beabdb9fdd6cd0a80
--- /dev/null
+++ b/mvsepless/models/bandit/core/loss/__init__.py
@@ -0,0 +1,8 @@
+from ._multistem import MultiStemWrapperFromConfig
+from ._timefreq import (
+ ReImL1Loss,
+ ReImL2Loss,
+ TimeFreqL1Loss,
+ TimeFreqL2Loss,
+ TimeFreqSignalNoisePNormRatioLoss,
+)
diff --git a/mvsepless/models/bandit/core/loss/_complex.py b/mvsepless/models/bandit/core/loss/_complex.py
new file mode 100644
index 0000000000000000000000000000000000000000..68c82f204709d07cba013f1582ca985bcf66dde6
--- /dev/null
+++ b/mvsepless/models/bandit/core/loss/_complex.py
@@ -0,0 +1,27 @@
+from typing import Any
+
+import torch
+from torch import nn
+from torch.nn.modules import loss as _loss
+from torch.nn.modules.loss import _Loss
+
+
+class ReImLossWrapper(_Loss):
+ def __init__(self, module: _Loss) -> None:
+ super().__init__()
+ self.module = module
+
+ def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ return self.module(torch.view_as_real(preds), torch.view_as_real(target))
+
+
+class ReImL1Loss(ReImLossWrapper):
+ def __init__(self, **kwargs: Any) -> None:
+ l1_loss = _loss.L1Loss(**kwargs)
+ super().__init__(module=(l1_loss))
+
+
+class ReImL2Loss(ReImLossWrapper):
+ def __init__(self, **kwargs: Any) -> None:
+ l2_loss = _loss.MSELoss(**kwargs)
+ super().__init__(module=(l2_loss))
diff --git a/mvsepless/models/bandit/core/loss/_multistem.py b/mvsepless/models/bandit/core/loss/_multistem.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9c4a4f776a318b4450cc98c820a057983e2f9a3
--- /dev/null
+++ b/mvsepless/models/bandit/core/loss/_multistem.py
@@ -0,0 +1,43 @@
+from typing import Any, Dict
+
+import torch
+from asteroid import losses as asteroid_losses
+from torch import nn
+from torch.nn.modules.loss import _Loss
+
+from . import snr
+
+
+def parse_loss(name: str, kwargs: Dict[str, Any]) -> _Loss:
+
+ for module in [nn.modules.loss, snr, asteroid_losses, asteroid_losses.sdr]:
+ if name in module.__dict__:
+ return module.__dict__[name](**kwargs)
+
+    raise NameError(f"Unknown loss function: {name}")
+
+
+class MultiStemWrapper(_Loss):
+ def __init__(self, module: _Loss, modality: str = "audio") -> None:
+ super().__init__()
+ self.loss = module
+ self.modality = modality
+
+ def forward(
+ self,
+ preds: Dict[str, Dict[str, torch.Tensor]],
+ target: Dict[str, Dict[str, torch.Tensor]],
+ ) -> torch.Tensor:
+ loss = {
+ stem: self.loss(preds[self.modality][stem], target[self.modality][stem])
+ for stem in preds[self.modality]
+ if stem in target[self.modality]
+ }
+
+ return sum(list(loss.values()))
+
+
+class MultiStemWrapperFromConfig(MultiStemWrapper):
+ def __init__(self, name: str, kwargs: Any, modality: str = "audio") -> None:
+ loss = parse_loss(name, kwargs)
+ super().__init__(module=loss, modality=modality)
diff --git a/mvsepless/models/bandit/core/loss/_timefreq.py b/mvsepless/models/bandit/core/loss/_timefreq.py
new file mode 100644
index 0000000000000000000000000000000000000000..edf848e05650b7ce833f6d642a8a53896cd2d8fd
--- /dev/null
+++ b/mvsepless/models/bandit/core/loss/_timefreq.py
@@ -0,0 +1,95 @@
+from typing import Any, Dict, Optional
+
+import torch
+from torch import nn
+from torch.nn.modules.loss import _Loss
+
+from ._multistem import MultiStemWrapper
+from ._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper
+from .snr import SignalNoisePNormRatio
+
+
+class TimeFreqWrapper(_Loss):
+ def __init__(
+ self,
+ time_module: _Loss,
+ freq_module: Optional[_Loss] = None,
+ time_weight: float = 1.0,
+ freq_weight: float = 1.0,
+ multistem: bool = True,
+ ) -> None:
+ super().__init__()
+
+ if freq_module is None:
+ freq_module = time_module
+
+ if multistem:
+ time_module = MultiStemWrapper(time_module, modality="audio")
+ freq_module = MultiStemWrapper(freq_module, modality="spectrogram")
+
+ self.time_module = time_module
+ self.freq_module = freq_module
+
+ self.time_weight = time_weight
+ self.freq_weight = freq_weight
+
+ # TODO: add better type hints
+ def forward(self, preds: Any, target: Any) -> torch.Tensor:
+
+ return self.time_weight * self.time_module(
+ preds, target
+ ) + self.freq_weight * self.freq_module(preds, target)
+
+
+class TimeFreqL1Loss(TimeFreqWrapper):
+ def __init__(
+ self,
+ time_weight: float = 1.0,
+ freq_weight: float = 1.0,
+ tkwargs: Optional[Dict[str, Any]] = None,
+ fkwargs: Optional[Dict[str, Any]] = None,
+ multistem: bool = True,
+ ) -> None:
+ if tkwargs is None:
+ tkwargs = {}
+ if fkwargs is None:
+ fkwargs = {}
+ time_module = nn.L1Loss(**tkwargs)
+ freq_module = ReImL1Loss(**fkwargs)
+ super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
+
+
+class TimeFreqL2Loss(TimeFreqWrapper):
+ def __init__(
+ self,
+ time_weight: float = 1.0,
+ freq_weight: float = 1.0,
+ tkwargs: Optional[Dict[str, Any]] = None,
+ fkwargs: Optional[Dict[str, Any]] = None,
+ multistem: bool = True,
+ ) -> None:
+ if tkwargs is None:
+ tkwargs = {}
+ if fkwargs is None:
+ fkwargs = {}
+ time_module = nn.MSELoss(**tkwargs)
+ freq_module = ReImL2Loss(**fkwargs)
+ super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
+
+
+class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper):
+ def __init__(
+ self,
+ time_weight: float = 1.0,
+ freq_weight: float = 1.0,
+ tkwargs: Optional[Dict[str, Any]] = None,
+ fkwargs: Optional[Dict[str, Any]] = None,
+ multistem: bool = True,
+ ) -> None:
+ if tkwargs is None:
+ tkwargs = {}
+ if fkwargs is None:
+ fkwargs = {}
+ time_module = SignalNoisePNormRatio(**tkwargs)
+ freq_module = SignalNoisePNormRatio(**fkwargs)
+ super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
diff --git a/mvsepless/models/bandit/core/loss/snr.py b/mvsepless/models/bandit/core/loss/snr.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d712a525027417198c7072ef571166ef0c02afa
--- /dev/null
+++ b/mvsepless/models/bandit/core/loss/snr.py
@@ -0,0 +1,139 @@
+import torch
+from torch.nn.modules.loss import _Loss
+from torch.nn import functional as F
+
+
+class SignalNoisePNormRatio(_Loss):
+ def __init__(
+ self,
+ p: float = 1.0,
+ scale_invariant: bool = False,
+ zero_mean: bool = False,
+ take_log: bool = True,
+ reduction: str = "mean",
+ EPS: float = 1e-3,
+ ) -> None:
+ assert reduction != "sum", NotImplementedError
+ super().__init__(reduction=reduction)
+ assert not zero_mean
+
+ self.p = p
+
+ self.EPS = EPS
+ self.take_log = take_log
+
+ self.scale_invariant = scale_invariant
+
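+    # The loss is the ratio of the mean p-th-power error to the mean p-th-power target
+    # energy (optionally after a scale-invariant projection), expressed in dB when
+    # take_log is True; lower is better.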
+ def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+
+ target_ = target
+ if self.scale_invariant:
+ ndim = target.ndim
+ dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True)
+ s_target_energy = torch.sum(
+ target * torch.conj(target), dim=-1, keepdim=True
+ )
+
+ if ndim > 2:
+ dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True)
+ s_target_energy = torch.sum(
+ s_target_energy, dim=list(range(1, ndim)), keepdim=True
+ )
+
+ target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8)
+ target = target_ * target_scaler
+
+ if torch.is_complex(est_target):
+ est_target = torch.view_as_real(est_target)
+ target = torch.view_as_real(target)
+
+ batch_size = est_target.shape[0]
+ est_target = est_target.reshape(batch_size, -1)
+ target = target.reshape(batch_size, -1)
+ # target_ = target_.reshape(batch_size, -1)
+
+ if self.p == 1:
+ e_error = torch.abs(est_target - target).mean(dim=-1)
+ e_target = torch.abs(target).mean(dim=-1)
+ elif self.p == 2:
+ e_error = torch.square(est_target - target).mean(dim=-1)
+ e_target = torch.square(target).mean(dim=-1)
+ else:
+ raise NotImplementedError
+
+ if self.take_log:
+ loss = 10 * (
+ torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS)
+ )
+ else:
+ loss = (e_error + self.EPS) / (e_target + self.EPS)
+
+ if self.reduction == "mean":
+ loss = loss.mean()
+ elif self.reduction == "sum":
+ loss = loss.sum()
+
+ return loss
+
+
+class MultichannelSingleSrcNegSDR(_Loss):
+ def __init__(
+ self,
+ sdr_type: str,
+ p: float = 2.0,
+ zero_mean: bool = True,
+ take_log: bool = True,
+ reduction: str = "mean",
+ EPS: float = 1e-8,
+ ) -> None:
+ assert reduction != "sum", NotImplementedError
+ super().__init__(reduction=reduction)
+
+ assert sdr_type in ["snr", "sisdr", "sdsdr"]
+ self.sdr_type = sdr_type
+ self.zero_mean = zero_mean
+ self.take_log = take_log
+        self.EPS = EPS
+
+ self.p = p
+
+ def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ if target.size() != est_target.size() or target.ndim != 3:
+ raise TypeError(
+ f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead"
+ )
+ # Step 1. Zero-mean norm
+ if self.zero_mean:
+ mean_source = torch.mean(target, dim=[1, 2], keepdim=True)
+ mean_estimate = torch.mean(est_target, dim=[1, 2], keepdim=True)
+ target = target - mean_source
+ est_target = est_target - mean_estimate
+ # Step 2. Pair-wise SI-SDR.
+ if self.sdr_type in ["sisdr", "sdsdr"]:
+ # [batch, 1]
+ dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True)
+ # [batch, 1]
+ s_target_energy = torch.sum(target**2, dim=[1, 2], keepdim=True) + self.EPS
+ # [batch, time]
+ scaled_target = dot * target / s_target_energy
+ else:
+ # [batch, time]
+ scaled_target = target
+ if self.sdr_type in ["sdsdr", "snr"]:
+ e_noise = est_target - target
+ else:
+ e_noise = est_target - scaled_target
+ # [batch]
+
+ if self.p == 2.0:
+ losses = torch.sum(scaled_target**2, dim=[1, 2]) / (
+ torch.sum(e_noise**2, dim=[1, 2]) + self.EPS
+ )
+ else:
+ losses = torch.norm(scaled_target, p=self.p, dim=[1, 2]) / (
+ torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS
+ )
+ if self.take_log:
+ losses = 10 * torch.log10(losses + self.EPS)
+ losses = losses.mean() if self.reduction == "mean" else losses
+ return -losses
diff --git a/mvsepless/models/bandit/core/metrics/__init__.py b/mvsepless/models/bandit/core/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c638b4df585ad6c3c6490d9e67b7fc197f0d06f4
--- /dev/null
+++ b/mvsepless/models/bandit/core/metrics/__init__.py
@@ -0,0 +1,9 @@
+from .snr import (
+ ChunkMedianScaleInvariantSignalDistortionRatio,
+ ChunkMedianScaleInvariantSignalNoiseRatio,
+ ChunkMedianSignalDistortionRatio,
+ ChunkMedianSignalNoiseRatio,
+ SafeSignalDistortionRatio,
+)
+
+# from .mushra import EstimatedMushraScore
diff --git a/mvsepless/models/bandit/core/metrics/_squim.py b/mvsepless/models/bandit/core/metrics/_squim.py
new file mode 100644
index 0000000000000000000000000000000000000000..71c993a2b6cb3da36849c2f87ef7bb7443a9095c
--- /dev/null
+++ b/mvsepless/models/bandit/core/metrics/_squim.py
@@ -0,0 +1,443 @@
+from dataclasses import dataclass
+
+from torchaudio._internal import load_state_dict_from_url
+
+import math
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def transform_wb_pesq_range(x: float) -> float:
+ """The metric defined by ITU-T P.862 is often called 'PESQ score', which is defined
+ for narrow-band signals and has a value range of [-0.5, 4.5] exactly. Here, we use the metric
+ defined by ITU-T P.862.2, commonly known as 'wide-band PESQ' and will be referred to as "PESQ score".
+
+ Args:
+ x (float): Narrow-band PESQ score.
+
+ Returns:
+ (float): Wide-band PESQ score.
+ """
+ return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224))
+
+
+PESQRange: Tuple[float, float] = (
+ 1.0, # P.862.2 uses a different input filter than P.862, and the lower bound of
+ # the raw score is not -0.5 anymore. It's hard to figure out the true lower bound.
+ # We are using 1.0 as a reasonable approximation.
+ transform_wb_pesq_range(4.5),
+)
+
+
+class RangeSigmoid(nn.Module):
+ def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
+ super(RangeSigmoid, self).__init__()
+ assert isinstance(val_range, tuple) and len(val_range) == 2
+ self.val_range: Tuple[float, float] = val_range
+ self.sigmoid: nn.modules.Module = nn.Sigmoid()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ out = (
+ self.sigmoid(x) * (self.val_range[1] - self.val_range[0])
+ + self.val_range[0]
+ )
+ return out
+
+
+class Encoder(nn.Module):
+ """Encoder module that transform 1D waveform to 2D representations.
+
+ Args:
+ feat_dim (int, optional): The feature dimension after Encoder module. (Default: 512)
+ win_len (int, optional): kernel size in the Conv1D layer. (Default: 32)
+ """
+
+ def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
+ super(Encoder, self).__init__()
+
+ self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Apply waveforms to convolutional layer and ReLU layer.
+
+ Args:
+ x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
+
+ Returns:
+            (torch.Tensor): Feature Tensor with dimensions `(batch, channel, frame)`.
+ """
+ out = x.unsqueeze(dim=1)
+ out = F.relu(self.conv1d(out))
+ return out
+
+
+class SingleRNN(nn.Module):
+ def __init__(
+ self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0
+ ) -> None:
+ super(SingleRNN, self).__init__()
+
+ self.rnn_type = rnn_type
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+
+ self.rnn: nn.modules.Module = getattr(nn, rnn_type)(
+ input_size,
+ hidden_size,
+ 1,
+ dropout=dropout,
+ batch_first=True,
+ bidirectional=True,
+ )
+
+ self.proj = nn.Linear(hidden_size * 2, input_size)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # input shape: batch, seq, dim
+ out, _ = self.rnn(x)
+ out = self.proj(out)
+ return out
+
+
+class DPRNN(nn.Module):
+ """*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`.
+
+ Args:
+ feat_dim (int, optional): The feature dimension after Encoder module. (Default: 64)
+ hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128)
+ num_blocks (int, optional): Number of DPRNN layers. (Default: 6)
+ rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM")
+ d_model (int, optional): The number of expected features in the input. (Default: 256)
+ chunk_size (int, optional): Chunk size of input for DPRNN. (Default: 100)
+ chunk_stride (int, optional): Stride of chunk input for DPRNN. (Default: 50)
+ """
+
+ def __init__(
+ self,
+ feat_dim: int = 64,
+ hidden_dim: int = 128,
+ num_blocks: int = 6,
+ rnn_type: str = "LSTM",
+ d_model: int = 256,
+ chunk_size: int = 100,
+ chunk_stride: int = 50,
+ ) -> None:
+ super(DPRNN, self).__init__()
+
+ self.num_blocks = num_blocks
+
+ self.row_rnn = nn.ModuleList([])
+ self.col_rnn = nn.ModuleList([])
+ self.row_norm = nn.ModuleList([])
+ self.col_norm = nn.ModuleList([])
+ for _ in range(num_blocks):
+ self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
+ self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
+ self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
+ self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
+ self.conv = nn.Sequential(
+ nn.Conv2d(feat_dim, d_model, 1),
+ nn.PReLU(),
+ )
+ self.chunk_size = chunk_size
+ self.chunk_stride = chunk_stride
+
+ def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
+ # input shape: (B, N, T)
+ seq_len = x.shape[-1]
+
+ rest = (
+ self.chunk_size
+ - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
+ )
+ out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])
+
+ return out, rest
+
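+    # chunking() splits the padded sequence into overlapping chunks (50% overlap with the
+    # default stride) laid out as two interleaved segment sets; merging() undoes this with
+    # an overlap-add and trims the padding recorded in `rest`.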
+ def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
+ out, rest = self.pad_chunk(x)
+ batch_size, feat_dim, seq_len = out.shape
+
+ segments1 = (
+ out[:, :, : -self.chunk_stride]
+ .contiguous()
+ .view(batch_size, feat_dim, -1, self.chunk_size)
+ )
+ segments2 = (
+ out[:, :, self.chunk_stride :]
+ .contiguous()
+ .view(batch_size, feat_dim, -1, self.chunk_size)
+ )
+ out = torch.cat([segments1, segments2], dim=3)
+ out = (
+ out.view(batch_size, feat_dim, -1, self.chunk_size)
+ .transpose(2, 3)
+ .contiguous()
+ )
+
+ return out, rest
+
+ def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
+ batch_size, dim, _, _ = x.shape
+ out = (
+ x.transpose(2, 3)
+ .contiguous()
+ .view(batch_size, dim, -1, self.chunk_size * 2)
+ )
+ out1 = (
+ out[:, :, :, : self.chunk_size]
+ .contiguous()
+ .view(batch_size, dim, -1)[:, :, self.chunk_stride :]
+ )
+ out2 = (
+ out[:, :, :, self.chunk_size :]
+ .contiguous()
+ .view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
+ )
+ out = out1 + out2
+ if rest > 0:
+ out = out[:, :, :-rest]
+ out = out.contiguous()
+ return out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x, rest = self.chunking(x)
+ batch_size, _, dim1, dim2 = x.shape
+ out = x
+ for row_rnn, row_norm, col_rnn, col_norm in zip(
+ self.row_rnn, self.row_norm, self.col_rnn, self.col_norm
+ ):
+ row_in = (
+ out.permute(0, 3, 2, 1)
+ .contiguous()
+ .view(batch_size * dim2, dim1, -1)
+ .contiguous()
+ )
+ row_out = row_rnn(row_in)
+ row_out = (
+ row_out.view(batch_size, dim2, dim1, -1)
+ .permute(0, 3, 2, 1)
+ .contiguous()
+ )
+ row_out = row_norm(row_out)
+ out = out + row_out
+
+ col_in = (
+ out.permute(0, 2, 3, 1)
+ .contiguous()
+ .view(batch_size * dim1, dim2, -1)
+ .contiguous()
+ )
+ col_out = col_rnn(col_in)
+ col_out = (
+ col_out.view(batch_size, dim1, dim2, -1)
+ .permute(0, 3, 1, 2)
+ .contiguous()
+ )
+ col_out = col_norm(col_out)
+ out = out + col_out
+ out = self.conv(out)
+ out = self.merging(out, rest)
+ out = out.transpose(1, 2).contiguous()
+ return out
+
+
+class AutoPool(nn.Module):
+ def __init__(self, pool_dim: int = 1) -> None:
+ super(AutoPool, self).__init__()
+ self.pool_dim: int = pool_dim
+ self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
+ self.register_parameter("alpha", nn.Parameter(torch.ones(1)))
+
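+    # alpha is a learned pooling temperature: alpha = 0 reduces to mean pooling, and
+    # large alpha approaches max pooling along pool_dim.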
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ weight = self.softmax(torch.mul(x, self.alpha))
+ out = torch.sum(torch.mul(x, weight), dim=self.pool_dim)
+ return out
+
+
+class SquimObjective(nn.Module):
+ """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
+ for speech enhancement (e.g., STOI, PESQ, and SI-SDR).
+
+ Args:
+ encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation.
+ dprnn (torch.nn.Module): DPRNN module to model sequential feature.
+        branches (torch.nn.ModuleList): Transformer branches in which each branch estimates one objective metric score.
+ """
+
+ def __init__(
+ self,
+ encoder: nn.Module,
+ dprnn: nn.Module,
+ branches: nn.ModuleList,
+ ):
+ super(SquimObjective, self).__init__()
+ self.encoder = encoder
+ self.dprnn = dprnn
+ self.branches = branches
+
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ """
+ Args:
+ x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
+
+ Returns:
+            List(torch.Tensor): List of score Tensors. Each Tensor has dimension `(batch,)`.
+ """
+ if x.ndim != 2:
+ raise ValueError(
+ f"The input must be a 2D Tensor. Found dimension {x.ndim}."
+ )
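+        # Normalise each waveform to a fixed RMS of 1/20 before feature extraction.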
+ x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
+ out = self.encoder(x)
+ out = self.dprnn(out)
+ scores = []
+ for branch in self.branches:
+ scores.append(branch(out).squeeze(dim=1))
+ return scores
+
+
+def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
+ """Create branch module after DPRNN model for predicting metric score.
+
+ Args:
+ d_model (int): The number of expected features in the input.
+ nhead (int): Number of heads in the multi-head attention model.
+ metric (str): The metric name to predict.
+
+ Returns:
+ (nn.Module): Returned module to predict corresponding metric score.
+ """
+ layer1 = nn.TransformerEncoderLayer(
+ d_model, nhead, d_model * 4, dropout=0.0, batch_first=True
+ )
+ layer2 = AutoPool()
+ if metric == "stoi":
+ layer3 = nn.Sequential(
+ nn.Linear(d_model, d_model),
+ nn.PReLU(),
+ nn.Linear(d_model, 1),
+ RangeSigmoid(),
+ )
+ elif metric == "pesq":
+ layer3 = nn.Sequential(
+ nn.Linear(d_model, d_model),
+ nn.PReLU(),
+ nn.Linear(d_model, 1),
+ RangeSigmoid(val_range=PESQRange),
+ )
+ else:
+ layer3: nn.modules.Module = nn.Sequential(
+ nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1)
+ )
+ return nn.Sequential(layer1, layer2, layer3)
+
+
+def squim_objective_model(
+ feat_dim: int,
+ win_len: int,
+ d_model: int,
+ nhead: int,
+ hidden_dim: int,
+ num_blocks: int,
+ rnn_type: str,
+ chunk_size: int,
+ chunk_stride: Optional[int] = None,
+) -> SquimObjective:
+ """Build a custome :class:`torchaudio.prototype.models.SquimObjective` model.
+
+ Args:
+        feat_dim (int): The feature dimension after Encoder module.
+ win_len (int): Kernel size in the Encoder module.
+ d_model (int): The number of expected features in the input.
+ nhead (int): Number of heads in the multi-head attention model.
+ hidden_dim (int): Hidden dimension in the RNN layer of DPRNN.
+ num_blocks (int): Number of DPRNN layers.
+ rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"].
+ chunk_size (int): Chunk size of input for DPRNN.
+ chunk_stride (int or None, optional): Stride of chunk input for DPRNN.
+ """
+ if chunk_stride is None:
+ chunk_stride = chunk_size // 2
+ encoder = Encoder(feat_dim, win_len)
+ dprnn = DPRNN(
+ feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride
+ )
+ branches = nn.ModuleList(
+ [
+ _create_branch(d_model, nhead, "stoi"),
+ _create_branch(d_model, nhead, "pesq"),
+ _create_branch(d_model, nhead, "sisdr"),
+ ]
+ )
+ return SquimObjective(encoder, dprnn, branches)
+
+
+def squim_objective_base() -> SquimObjective:
+ """Build :class:`torchaudio.prototype.models.SquimObjective` model with default arguments."""
+ return squim_objective_model(
+ feat_dim=256,
+ win_len=64,
+ d_model=256,
+ nhead=4,
+ hidden_dim=256,
+ num_blocks=2,
+ rnn_type="LSTM",
+ chunk_size=71,
+ )
+
+
+@dataclass
+class SquimObjectiveBundle:
+
+ _path: str
+ _sample_rate: float
+
+ def _get_state_dict(self, dl_kwargs):
+ url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
+ dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+ state_dict = load_state_dict_from_url(url, **dl_kwargs)
+ return state_dict
+
+ def get_model(self, *, dl_kwargs=None) -> SquimObjective:
+ """Construct the SquimObjective model, and load the pretrained weight.
+
+ The weight file is downloaded from the internet and cached with
+ :func:`torch.hub.load_state_dict_from_url`
+
+ Args:
+ dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
+
+ Returns:
+ Variation of :py:class:`~torchaudio.models.SquimObjective`.
+ """
+ model = squim_objective_base()
+ model.load_state_dict(self._get_state_dict(dl_kwargs))
+ model.eval()
+ return model
+
+ @property
+ def sample_rate(self):
+ """Sample rate of the audio that the model is trained on.
+
+ :type: float
+ """
+ return self._sample_rate
+
+
+SQUIM_OBJECTIVE = SquimObjectiveBundle(
+ "squim_objective_dns2020.pth",
+ _sample_rate=16000,
+)
+SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
+ :cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.
+
+ The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
+    The weights are under the `Creative Commons Attribution 4.0 International License
+    <https://creativecommons.org/licenses/by/4.0/>`__.
+
+ Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
+ """
diff --git a/mvsepless/models/bandit/core/metrics/snr.py b/mvsepless/models/bandit/core/metrics/snr.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b7a168756204ddd8044c7c74e31c15dc1643aa1
--- /dev/null
+++ b/mvsepless/models/bandit/core/metrics/snr.py
@@ -0,0 +1,127 @@
+from typing import Any, Callable
+
+import numpy as np
+import torch
+import torchmetrics as tm
+from torch._C import _LinAlgError
+from torchmetrics import functional as tmF
+
+
+class SafeSignalDistortionRatio(tm.SignalDistortionRatio):
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+
+ def update(self, *args, **kwargs) -> Any:
+ try:
+ super().update(*args, **kwargs)
+        except Exception:
+            # Skip batches where the underlying SDR computation fails (e.g. numerical issues).
+            pass
+
+ def compute(self) -> Any:
+ if self.total == 0:
+ return torch.tensor(torch.nan)
+ return super().compute()
+
+
+class BaseChunkMedianSignalRatio(tm.Metric):
+ def __init__(
+ self,
+ func: Callable,
+ window_size: int,
+ hop_size: int = None,
+ zero_mean: bool = False,
+ ) -> None:
+ super().__init__()
+
+ # self.zero_mean = zero_mean
+ self.func = func
+ self.window_size = window_size
+ if hop_size is None:
+ hop_size = window_size
+ self.hop_size = hop_size
+
+ self.add_state("sum_snr", default=torch.tensor(0.0), dist_reduce_fx="sum")
+ self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
+ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
+
+ n_samples = target.shape[-1]
+
+ n_chunks = int(np.ceil((n_samples - self.window_size) / self.hop_size) + 1)
+
+ snr_chunk = []
+
+ for i in range(n_chunks):
+ start = i * self.hop_size
+
+ if n_samples - start < self.window_size:
+ continue
+
+ end = start + self.window_size
+
+ try:
+ chunk_snr = self.func(preds[..., start:end], target[..., start:end])
+
+ # print(preds.shape, chunk_snr.shape)
+
+ if torch.all(torch.isfinite(chunk_snr)):
+ snr_chunk.append(chunk_snr)
+ except _LinAlgError:
+ pass
+
+ snr_chunk = torch.stack(snr_chunk, dim=-1)
+ snr_batch, _ = torch.nanmedian(snr_chunk, dim=-1)
+
+ self.sum_snr += snr_batch.sum()
+ self.total += snr_batch.numel()
+
+ def compute(self) -> Any:
+ return self.sum_snr / self.total
+
+
+class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio):
+ def __init__(
+ self, window_size: int, hop_size: int = None, zero_mean: bool = False
+ ) -> None:
+ super().__init__(
+ func=tmF.signal_noise_ratio,
+ window_size=window_size,
+ hop_size=hop_size,
+ zero_mean=zero_mean,
+ )
+
+
+class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio):
+ def __init__(
+ self, window_size: int, hop_size: int = None, zero_mean: bool = False
+ ) -> None:
+ super().__init__(
+ func=tmF.scale_invariant_signal_noise_ratio,
+ window_size=window_size,
+ hop_size=hop_size,
+ zero_mean=zero_mean,
+ )
+
+
+class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio):
+ def __init__(
+ self, window_size: int, hop_size: int = None, zero_mean: bool = False
+ ) -> None:
+ super().__init__(
+ func=tmF.signal_distortion_ratio,
+ window_size=window_size,
+ hop_size=hop_size,
+ zero_mean=zero_mean,
+ )
+
+
+class ChunkMedianScaleInvariantSignalDistortionRatio(BaseChunkMedianSignalRatio):
+ def __init__(
+ self, window_size: int, hop_size: int = None, zero_mean: bool = False
+ ) -> None:
+ super().__init__(
+ func=tmF.scale_invariant_signal_distortion_ratio,
+ window_size=window_size,
+ hop_size=hop_size,
+ zero_mean=zero_mean,
+ )
diff --git a/mvsepless/models/bandit/core/model/__init__.py b/mvsepless/models/bandit/core/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..54ac48eb69d6f844ba5b73b213eae4cfab157cac
--- /dev/null
+++ b/mvsepless/models/bandit/core/model/__init__.py
@@ -0,0 +1,3 @@
+from .bsrnn.wrapper import (
+ MultiMaskMultiSourceBandSplitRNNSimple,
+)
diff --git a/mvsepless/models/bandit/core/model/_spectral.py b/mvsepless/models/bandit/core/model/_spectral.py
new file mode 100644
index 0000000000000000000000000000000000000000..6af5cbd0dcb6ed0a4babd6b8554184d91c406655
--- /dev/null
+++ b/mvsepless/models/bandit/core/model/_spectral.py
@@ -0,0 +1,54 @@
+from typing import Dict, Optional
+
+import torch
+import torchaudio as ta
+from torch import nn
+
+
+class _SpectralComponent(nn.Module):
+ def __init__(
+ self,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+
+ assert power is None
+
+ window_fn = torch.__dict__[window_fn]
+
+ self.stft = ta.transforms.Spectrogram(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ pad_mode=pad_mode,
+ pad=0,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ normalized=normalized,
+ center=center,
+ onesided=onesided,
+ )
+
+ self.istft = ta.transforms.InverseSpectrogram(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ pad_mode=pad_mode,
+ pad=0,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ normalized=normalized,
+ center=center,
+ onesided=onesided,
+ )
diff --git a/mvsepless/models/bandit/core/model/bsrnn/__init__.py b/mvsepless/models/bandit/core/model/bsrnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4736e9dec4f2e731ebf34bd51bf4ce1af728b269
--- /dev/null
+++ b/mvsepless/models/bandit/core/model/bsrnn/__init__.py
@@ -0,0 +1,23 @@
+from abc import ABC
+from typing import Iterable, Mapping, Union
+
+from torch import nn
+
+from .bandsplit import BandSplitModule
+from .tfmodel import (
+ SeqBandModellingModule,
+ TransformerTimeFreqModule,
+)
+
+
+class BandsplitCoreBase(nn.Module, ABC):
+ band_split: nn.Module
+ tf_model: nn.Module
+ mask_estim: Union[nn.Module, Mapping[str, nn.Module], Iterable[nn.Module]]
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ @staticmethod
+ def mask(x, m):
+ return x * m
diff --git a/mvsepless/models/bandit/core/model/bsrnn/bandsplit.py b/mvsepless/models/bandit/core/model/bsrnn/bandsplit.py
new file mode 100644
index 0000000000000000000000000000000000000000..171c8800e5b45fa6bb8dff800d6ba61be8328375
--- /dev/null
+++ b/mvsepless/models/bandit/core/model/bsrnn/bandsplit.py
@@ -0,0 +1,135 @@
+from typing import List, Tuple
+
+import torch
+from torch import nn
+
+from .utils import (
+ band_widths_from_specs,
+ check_no_gap,
+ check_no_overlap,
+ check_nonzero_bandwidth,
+)
+
+
+class NormFC(nn.Module):
+ def __init__(
+ self,
+ emb_dim: int,
+ bandwidth: int,
+ in_channel: int,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.treat_channel_as_feature = treat_channel_as_feature
+
+ if normalize_channel_independently:
+ raise NotImplementedError
+
+ reim = 2
+
+ self.norm = nn.LayerNorm(in_channel * bandwidth * reim)
+
+ fc_in = bandwidth * reim
+
+ if treat_channel_as_feature:
+ fc_in *= in_channel
+ else:
+ assert emb_dim % in_channel == 0
+ emb_dim = emb_dim // in_channel
+
+ self.fc = nn.Linear(fc_in, emb_dim)
+
+ def forward(self, xb):
+ # xb = (batch, n_time, in_chan, reim * band_width)
+
+ batch, n_time, in_chan, ribw = xb.shape
+ xb = self.norm(xb.reshape(batch, n_time, in_chan * ribw))
+ # (batch, n_time, in_chan * reim * band_width)
+
+ if not self.treat_channel_as_feature:
+ xb = xb.reshape(batch, n_time, in_chan, ribw)
+ # (batch, n_time, in_chan, reim * band_width)
+
+ zb = self.fc(xb)
+ # (batch, n_time, emb_dim)
+ # OR
+ # (batch, n_time, in_chan, emb_dim_per_chan)
+
+ if not self.treat_channel_as_feature:
+ batch, n_time, in_chan, emb_dim_per_chan = zb.shape
+ # (batch, n_time, in_chan, emb_dim_per_chan)
+ zb = zb.reshape((batch, n_time, in_chan * emb_dim_per_chan))
+
+ return zb # (batch, n_time, emb_dim)
+
+
+class BandSplitModule(nn.Module):
+ def __init__(
+ self,
+ band_specs: List[Tuple[float, float]],
+ emb_dim: int,
+ in_channel: int,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ ) -> None:
+ super().__init__()
+
+ check_nonzero_bandwidth(band_specs)
+
+ if require_no_gap:
+ check_no_gap(band_specs)
+
+ if require_no_overlap:
+ check_no_overlap(band_specs)
+
+ self.band_specs = band_specs
+ # list of [fstart, fend) in index.
+ # Note that fend is exclusive.
+ self.band_widths = band_widths_from_specs(band_specs)
+ self.n_bands = len(band_specs)
+ self.emb_dim = emb_dim
+
+ self.norm_fc_modules = nn.ModuleList(
+ [ # type: ignore
+ (
+ NormFC(
+ emb_dim=emb_dim,
+ bandwidth=bw,
+ in_channel=in_channel,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ )
+ )
+ for bw in self.band_widths
+ ]
+ )
+
+ def forward(self, x: torch.Tensor):
+ # x = complex spectrogram (batch, in_chan, n_freq, n_time)
+
+ batch, in_chan, _, n_time = x.shape
+
+ z = torch.zeros(
+ size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
+ )
+
+ xr = torch.view_as_real(x) # batch, in_chan, n_freq, n_time, 2
+ xr = torch.permute(xr, (0, 3, 1, 4, 2)) # batch, n_time, in_chan, 2, n_freq
+ batch, n_time, in_chan, reim, band_width = xr.shape
+ for i, nfm in enumerate(self.norm_fc_modules):
+ # print(f"bandsplit/band{i:02d}")
+ fstart, fend = self.band_specs[i]
+ xb = xr[..., fstart:fend]
+ # (batch, n_time, in_chan, reim, band_width)
+ xb = torch.reshape(xb, (batch, n_time, in_chan, -1))
+ # (batch, n_time, in_chan, reim * band_width)
+ # z.append(nfm(xb)) # (batch, n_time, emb_dim)
+ z[:, i, :, :] = nfm(xb.contiguous())
+
+ # z = torch.stack(z, dim=1)
+
+ return z
diff --git a/mvsepless/models/bandit/core/model/bsrnn/core.py b/mvsepless/models/bandit/core/model/bsrnn/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..412b37dc6994ce18f227021c5458c12e06e600a5
--- /dev/null
+++ b/mvsepless/models/bandit/core/model/bsrnn/core.py
@@ -0,0 +1,651 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from . import BandsplitCoreBase
+from .bandsplit import BandSplitModule
+from .maskestim import (
+ MaskEstimationModule,
+ OverlappingMaskEstimationModule,
+)
+from .tfmodel import (
+ ConvolutionalTimeFreqModule,
+ SeqBandModellingModule,
+ TransformerTimeFreqModule,
+)
+
+
+class MultiMaskBandSplitCoreBase(BandsplitCoreBase):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def forward(self, x, cond=None, compute_residual: bool = True):
+ # x = complex spectrogram (batch, in_chan, n_freq, n_time)
+ # print(x.shape)
+ batch, in_chan, n_freq, n_time = x.shape
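+        # Fold the channel axis into the batch so each channel is processed as an
+        # independent mono spectrogram; the estimate is reshaped back below.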
+ x = torch.reshape(x, (-1, 1, n_freq, n_time))
+
+ z = self.band_split(x) # (batch, emb_dim, n_band, n_time)
+
+ # if torch.any(torch.isnan(z)):
+ # raise ValueError("z nan")
+
+ # print(z)
+ q = self.tf_model(z) # (batch, emb_dim, n_band, n_time)
+ # print(q)
+
+ # if torch.any(torch.isnan(q)):
+ # raise ValueError("q nan")
+
+ out = {}
+
+ for stem, mem in self.mask_estim.items():
+ m = mem(q, cond=cond)
+
+ # if torch.any(torch.isnan(m)):
+ # raise ValueError("m nan", stem)
+
+ s = self.mask(x, m)
+ s = torch.reshape(s, (batch, in_chan, n_freq, n_time))
+ out[stem] = s
+
+ return {"spectrogram": out}
+
+ def instantiate_mask_estim(
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: List[Tuple[float, float]],
+ emb_dim: int,
+ mlp_dim: int,
+ cond_dim: int,
+ hidden_activation: str,
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ overlapping_band: bool = False,
+ freq_weights: Optional[List[torch.Tensor]] = None,
+ n_freq: Optional[int] = None,
+ use_freq_weights: bool = True,
+ mult_add_mask: bool = False,
+ ):
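+ # Chooses the mask estimator: overlapping bands (e.g. mel/bark/ERB splits)
+ # use the overlap-aware modules with per-band frequency weights, optionally
+ # in multiplicative+additive form; contiguous non-overlapping bands use the
+ # plain MaskEstimationModule.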
+ if hidden_activation_kwargs is None:
+ hidden_activation_kwargs = {}
+
+ if "mne:+" in stems:
+ stems = [s for s in stems if s != "mne:+"]
+
+ if overlapping_band:
+ assert freq_weights is not None
+ assert n_freq is not None
+
+ if mult_add_mask:
+
+ self.mask_estim = nn.ModuleDict(
+ {
+ stem: MultAddMaskEstimationModule(
+ band_specs=band_specs,
+ freq_weights=freq_weights,
+ n_freq=n_freq,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ use_freq_weights=use_freq_weights,
+ )
+ for stem in stems
+ }
+ )
+ else:
+ self.mask_estim = nn.ModuleDict(
+ {
+ stem: OverlappingMaskEstimationModule(
+ band_specs=band_specs,
+ freq_weights=freq_weights,
+ n_freq=n_freq,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ use_freq_weights=use_freq_weights,
+ )
+ for stem in stems
+ }
+ )
+ else:
+ self.mask_estim = nn.ModuleDict(
+ {
+ stem: MaskEstimationModule(
+ band_specs=band_specs,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
+ for stem in stems
+ }
+ )
+
+ def instantiate_bandsplit(
+ self,
+ in_channel: int,
+ band_specs: List[Tuple[float, float]],
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ emb_dim: int = 128,
+ ):
+ self.band_split = BandSplitModule(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
+ )
+
+
+class SingleMaskBandsplitCoreBase(BandsplitCoreBase):
+ def __init__(self, **kwargs) -> None:
+ super().__init__()
+
+ def forward(self, x):
+ # x = complex spectrogram (batch, in_chan, n_freq, n_time)
+ z = self.band_split(x) # (batch, n_bands, n_time, emb_dim)
+ q = self.tf_model(z) # (batch, n_bands, n_time, emb_dim)
+ m = self.mask_estim(q) # (batch, in_chan, n_freq, n_time)
+
+ s = self.mask(x, m)
+
+ return s
+
+
+class SingleMaskBandsplitCoreRNN(
+ SingleMaskBandsplitCoreBase,
+):
+ def __init__(
+ self,
+ in_channel: int,
+ band_specs: List[Tuple[float, float]],
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ ) -> None:
+ super().__init__()
+ self.band_split = BandSplitModule(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
+ )
+ self.tf_model = SeqBandModellingModule(
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ )
+ self.mask_estim = MaskEstimationModule(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
+
+
+class SingleMaskBandsplitCoreTransformer(
+ SingleMaskBandsplitCoreBase,
+):
+ def __init__(
+ self,
+ in_channel: int,
+ band_specs: List[Tuple[float, float]],
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ tf_dropout: float = 0.0,
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ ) -> None:
+ super().__init__()
+ self.band_split = BandSplitModule(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
+ )
+ self.tf_model = TransformerTimeFreqModule(
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ dropout=tf_dropout,
+ )
+ self.mask_estim = MaskEstimationModule(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
+
+
+class MultiSourceMultiMaskBandSplitCoreRNN(MultiMaskBandSplitCoreBase):
+ def __init__(
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: List[Tuple[float, float]],
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ cond_dim: int = 0,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ overlapping_band: bool = False,
+ freq_weights: Optional[List[torch.Tensor]] = None,
+ n_freq: Optional[int] = None,
+ use_freq_weights: bool = True,
+ mult_add_mask: bool = False,
+ ) -> None:
+
+ super().__init__()
+ self.instantiate_bandsplit(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
+ )
+
+ self.tf_model = SeqBandModellingModule(
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ )
+
+ self.mult_add_mask = mult_add_mask
+
+ self.instantiate_mask_estim(
+ in_channel=in_channel,
+ stems=stems,
+ band_specs=band_specs,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=overlapping_band,
+ freq_weights=freq_weights,
+ n_freq=n_freq,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask,
+ )
+
+ @staticmethod
+ def _mult_add_mask(x, m):
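+ # m stacks two masks on the last axis: m[..., 0] is multiplicative,
+ # m[..., 1] is additive.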
+
+ assert m.ndim == 5
+
+ mm = m[..., 0]
+ am = m[..., 1]
+
+ # print(mm.shape, am.shape, x.shape, m.shape)
+
+ return x * mm + am
+
+ def mask(self, x, m):
+ if self.mult_add_mask:
+
+ return self._mult_add_mask(x, m)
+ else:
+ return super().mask(x, m)
+
+
+class MultiSourceMultiMaskBandSplitCoreTransformer(
+ MultiMaskBandSplitCoreBase,
+):
+ def __init__(
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: List[Tuple[float, float]],
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ tf_dropout: float = 0.0,
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ overlapping_band: bool = False,
+ freq_weights: Optional[List[torch.Tensor]] = None,
+ n_freq: Optional[int] = None,
+ use_freq_weights: bool = True,
+ rnn_type: str = "LSTM",
+ cond_dim: int = 0,
+ mult_add_mask: bool = False,
+ ) -> None:
+ super().__init__()
+ self.instantiate_bandsplit(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
+ )
+ self.tf_model = TransformerTimeFreqModule(
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ dropout=tf_dropout,
+ )
+
+ self.instantiate_mask_estim(
+ in_channel=in_channel,
+ stems=stems,
+ band_specs=band_specs,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=overlapping_band,
+ freq_weights=freq_weights,
+ n_freq=n_freq,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask,
+ )
+
+
+class MultiSourceMultiMaskBandSplitCoreConv(
+ MultiMaskBandSplitCoreBase,
+):
+ def __init__(
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: List[Tuple[float, float]],
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ tf_dropout: float = 0.0,
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ overlapping_band: bool = False,
+ freq_weights: Optional[List[torch.Tensor]] = None,
+ n_freq: Optional[int] = None,
+ use_freq_weights: bool = True,
+ rnn_type: str = "LSTM",
+ cond_dim: int = 0,
+ mult_add_mask: bool = False,
+ ) -> None:
+ super().__init__()
+ self.instantiate_bandsplit(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
+ )
+ self.tf_model = ConvolutionalTimeFreqModule(
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ dropout=tf_dropout,
+ )
+
+ self.instantiate_mask_estim(
+ in_channel=in_channel,
+ stems=stems,
+ band_specs=band_specs,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=overlapping_band,
+ freq_weights=freq_weights,
+ n_freq=n_freq,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask,
+ )
+
+
+class PatchingMaskBandsplitCoreBase(MultiMaskBandSplitCoreBase):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def mask(self, x, m):
+ # x.shape = (batch, n_channel, n_freq, n_time)
+ # m.shape = (batch, n_channel, kernel_freq, kernel_time, n_freq, n_time)
+
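+ # Patching mask: unfold gathers the (kernel_freq x kernel_time) neighbourhood
+ # of every T-F bin, the per-patch mask is applied, and fold sums the
+ # overlapping contributions back onto the (n_freq, n_time) grid.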
+ _, n_channel, kernel_freq, kernel_time, n_freq, n_time = m.shape
+ padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2)
+
+ xf = F.unfold(
+ x,
+ kernel_size=(kernel_freq, kernel_time),
+ padding=padding,
+ stride=(1, 1),
+ )
+
+ xf = xf.view(
+ -1,
+ n_channel,
+ kernel_freq,
+ kernel_time,
+ n_freq,
+ n_time,
+ )
+
+ sf = xf * m
+
+ sf = sf.view(
+ -1,
+ n_channel * kernel_freq * kernel_time,
+ n_freq * n_time,
+ )
+
+ s = F.fold(
+ sf,
+ output_size=(n_freq, n_time),
+ kernel_size=(kernel_freq, kernel_time),
+ padding=padding,
+ stride=(1, 1),
+ ).view(
+ -1,
+ n_channel,
+ n_freq,
+ n_time,
+ )
+
+ return s
+
+ def old_mask(self, x, m):
+ # x.shape = (batch, n_channel, n_freq, n_time)
+ # m.shape = (kernel_freq, kernel_time, batch, n_channel, n_freq, n_time)
+
+ s = torch.zeros_like(x)
+
+ _, n_channel, n_freq, n_time = x.shape
+ kernel_freq, kernel_time, _, _, _, _ = m.shape
+
+ # print(x.shape, m.shape)
+
+ kernel_freq_half = (kernel_freq - 1) // 2
+ kernel_time_half = (kernel_time - 1) // 2
+
+ for ifreq in range(kernel_freq):
+ for itime in range(kernel_time):
+ df, dt = kernel_freq_half - ifreq, kernel_time_half - itime
+ x = x.roll(shifts=(df, dt), dims=(2, 3))
+
+ # if df > 0:
+ # x[:, :, :df, :] = 0
+ # elif df < 0:
+ # x[:, :, df:, :] = 0
+
+ # if dt > 0:
+ # x[:, :, :, :dt] = 0
+ # elif dt < 0:
+ # x[:, :, :, dt:] = 0
+
+ fslice = slice(max(0, df), min(n_freq, n_freq + df))
+ tslice = slice(max(0, dt), min(n_time, n_time + dt))
+
+ s[:, :, fslice, tslice] += (
+ x[:, :, fslice, tslice] * m[ifreq, itime, :, :, fslice, tslice]
+ )
+
+ return s
+
+
+class MultiSourceMultiPatchingMaskBandSplitCoreRNN(PatchingMaskBandsplitCoreBase):
+ def __init__(
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: List[Tuple[float, float]],
+ mask_kernel_freq: int,
+ mask_kernel_time: int,
+ conv_kernel_freq: int,
+ conv_kernel_time: int,
+ kernel_norm_mlp_version: int,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ overlapping_band: bool = False,
+ freq_weights: Optional[List[torch.Tensor]] = None,
+ n_freq: Optional[int] = None,
+ ) -> None:
+
+ super().__init__()
+ self.band_split = BandSplitModule(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
+ )
+
+ self.tf_model = SeqBandModellingModule(
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ )
+
+ if hidden_activation_kwargs is None:
+ hidden_activation_kwargs = {}
+
+ if overlapping_band:
+ assert freq_weights is not None
+ assert n_freq is not None
+ self.mask_estim = nn.ModuleDict(
+ {
+ stem: PatchingMaskEstimationModule(
+ band_specs=band_specs,
+ freq_weights=freq_weights,
+ n_freq=n_freq,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ mask_kernel_freq=mask_kernel_freq,
+ mask_kernel_time=mask_kernel_time,
+ conv_kernel_freq=conv_kernel_freq,
+ conv_kernel_time=conv_kernel_time,
+ kernel_norm_mlp_version=kernel_norm_mlp_version,
+ )
+ for stem in stems
+ }
+ )
+ else:
+ raise NotImplementedError
diff --git a/mvsepless/models/bandit/core/model/bsrnn/maskestim.py b/mvsepless/models/bandit/core/model/bsrnn/maskestim.py
new file mode 100644
index 0000000000000000000000000000000000000000..067f79ce28b5e39de1259710b3444e837c4eb960
--- /dev/null
+++ b/mvsepless/models/bandit/core/model/bsrnn/maskestim.py
@@ -0,0 +1,351 @@
+import warnings
+from typing import Dict, List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch.nn.modules import activation
+
+from .utils import (
+ band_widths_from_specs,
+ check_no_gap,
+ check_no_overlap,
+ check_nonzero_bandwidth,
+)
+
+
+class BaseNormMLP(nn.Module):
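+ # LayerNorm + Linear + activation stack shared by the per-band mask MLPs;
+ # reim is 2 when a complex mask is predicted, 1 for a real-valued mask.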
+ def __init__(
+ self,
+ emb_dim: int,
+ mlp_dim: int,
+ bandwidth: int,
+ in_channel: Optional[int],
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs=None,
+ complex_mask: bool = True,
+ ):
+
+ super().__init__()
+ if hidden_activation_kwargs is None:
+ hidden_activation_kwargs = {}
+ self.hidden_activation_kwargs = hidden_activation_kwargs
+ self.norm = nn.LayerNorm(emb_dim)
+ self.hidden = torch.jit.script(
+ nn.Sequential(
+ nn.Linear(in_features=emb_dim, out_features=mlp_dim),
+ activation.__dict__[hidden_activation](**self.hidden_activation_kwargs),
+ )
+ )
+
+ self.bandwidth = bandwidth
+ self.in_channel = in_channel
+
+ self.complex_mask = complex_mask
+ self.reim = 2 if complex_mask else 1
+ self.glu_mult = 2
+
+
+class NormMLP(BaseNormMLP):
+ def __init__(
+ self,
+ emb_dim: int,
+ mlp_dim: int,
+ bandwidth: int,
+ in_channel: Optional[int],
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs=None,
+ complex_mask: bool = True,
+ ) -> None:
+ super().__init__(
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ bandwidth=bandwidth,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
+
+ self.output = torch.jit.script(
+ nn.Sequential(
+ nn.Linear(
+ in_features=mlp_dim,
+ out_features=bandwidth * in_channel * self.reim * 2,
+ ),
+ nn.GLU(dim=-1),
+ )
+ )
+
+ def reshape_output(self, mb):
+ # print(mb.shape)
+ batch, n_time, _ = mb.shape
+ if self.complex_mask:
+ mb = mb.reshape(
+ batch, n_time, self.in_channel, self.bandwidth, self.reim
+ ).contiguous()
+ # print(mb.shape)
+ mb = torch.view_as_complex(mb) # (batch, n_time, in_channel, bandwidth)
+ else:
+ mb = mb.reshape(batch, n_time, self.in_channel, self.bandwidth)
+
+ mb = torch.permute(mb, (0, 2, 3, 1)) # (batch, in_channel, bandwidth, n_time)
+
+ return mb
+
+ def forward(self, qb):
+ # qb = (batch, n_time, emb_dim)
+
+ # if torch.any(torch.isnan(qb)):
+ # raise ValueError("qb0")
+
+ qb = self.norm(qb) # (batch, n_time, emb_dim)
+
+ # if torch.any(torch.isnan(qb)):
+ # raise ValueError("qb1")
+
+ qb = self.hidden(qb) # (batch, n_time, mlp_dim)
+ # if torch.any(torch.isnan(qb)):
+ # raise ValueError("qb2")
+ mb = self.output(qb) # (batch, n_time, bandwidth * in_channel * reim)
+ # if torch.any(torch.isnan(qb)):
+ # raise ValueError("mb")
+ mb = self.reshape_output(mb) # (batch, in_channel, bandwidth, n_time)
+
+ return mb
+
+
+class MultAddNormMLP(NormMLP):
+ def __init__(
+ self,
+ emb_dim: int,
+ mlp_dim: int,
+ bandwidth: int,
+ in_channel: "int | None",
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs=None,
+ complex_mask: bool = True,
+ ) -> None:
+ super().__init__(
+ emb_dim,
+ mlp_dim,
+ bandwidth,
+ in_channel,
+ hidden_activation,
+ hidden_activation_kwargs,
+ complex_mask,
+ )
+
+ self.output2 = torch.jit.script(
+ nn.Sequential(
+ nn.Linear(
+ in_features=mlp_dim,
+ out_features=bandwidth * in_channel * self.reim * 2,
+ ),
+ nn.GLU(dim=-1),
+ )
+ )
+
+ def forward(self, qb):
+
+ qb = self.norm(qb) # (batch, n_time, emb_dim)
+ qb = self.hidden(qb) # (batch, n_time, mlp_dim)
+ mmb = self.output(qb) # (batch, n_time, bandwidth * in_channel * reim)
+ mmb = self.reshape_output(mmb) # (batch, in_channel, bandwidth, n_time)
+ amb = self.output2(qb) # (batch, n_time, bandwidth * in_channel * reim)
+ amb = self.reshape_output(amb) # (batch, in_channel, bandwidth, n_time)
+
+ return mmb, amb
+
+
+class MaskEstimationModuleSuperBase(nn.Module):
+ pass
+
+
+class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
+ def __init__(
+ self,
+ band_specs: List[Tuple[float, float]],
+ emb_dim: int,
+ mlp_dim: int,
+ in_channel: Optional[int],
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Dict = None,
+ complex_mask: bool = True,
+ norm_mlp_cls: Type[nn.Module] = NormMLP,
+ norm_mlp_kwargs: Dict = None,
+ ) -> None:
+ super().__init__()
+
+ self.band_widths = band_widths_from_specs(band_specs)
+ self.n_bands = len(band_specs)
+
+ if hidden_activation_kwargs is None:
+ hidden_activation_kwargs = {}
+
+ if norm_mlp_kwargs is None:
+ norm_mlp_kwargs = {}
+
+ self.norm_mlp = nn.ModuleList(
+ [
+ (
+ norm_mlp_cls(
+ bandwidth=self.band_widths[b],
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ **norm_mlp_kwargs,
+ )
+ )
+ for b in range(self.n_bands)
+ ]
+ )
+
+ def compute_masks(self, q):
+ batch, n_bands, n_time, emb_dim = q.shape
+
+ masks = []
+
+ for b, nmlp in enumerate(self.norm_mlp):
+ # print(f"maskestim/{b:02d}")
+ qb = q[:, b, :, :]
+ mb = nmlp(qb)
+ masks.append(mb)
+
+ return masks
+
+
+class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
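+ # Estimates one mask per band and scatter-adds it back onto the full
+ # frequency axis; with use_freq_weights each band's contribution is weighted
+ # by its (normalised) filterbank response, so overlapping bands blend.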
+ def __init__(
+ self,
+ in_channel: int,
+ band_specs: List[Tuple[float, float]],
+ freq_weights: List[torch.Tensor],
+ n_freq: int,
+ emb_dim: int,
+ mlp_dim: int,
+ cond_dim: int = 0,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Dict = None,
+ complex_mask: bool = True,
+ norm_mlp_cls: Type[nn.Module] = NormMLP,
+ norm_mlp_kwargs: Dict = None,
+ use_freq_weights: bool = True,
+ ) -> None:
+ check_nonzero_bandwidth(band_specs)
+ check_no_gap(band_specs)
+
+ # if cond_dim > 0:
+ # raise NotImplementedError
+
+ super().__init__(
+ band_specs=band_specs,
+ emb_dim=emb_dim + cond_dim,
+ mlp_dim=mlp_dim,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ norm_mlp_cls=norm_mlp_cls,
+ norm_mlp_kwargs=norm_mlp_kwargs,
+ )
+
+ self.n_freq = n_freq
+ self.band_specs = band_specs
+ self.in_channel = in_channel
+
+ if freq_weights is not None:
+ for i, fw in enumerate(freq_weights):
+ self.register_buffer(f"freq_weights/{i}", fw)
+
+ self.use_freq_weights = use_freq_weights
+ else:
+ self.use_freq_weights = False
+
+ self.cond_dim = cond_dim
+
+ def forward(self, q, cond=None):
+ # q = (batch, n_bands, n_time, emb_dim)
+
+ batch, n_bands, n_time, emb_dim = q.shape
+
+ if cond is not None:
+ # print(cond)
+ if cond.ndim == 2:
+ cond = cond[:, None, None, :].expand(-1, n_bands, n_time, -1)
+ elif cond.ndim == 3:
+ assert cond.shape[1] == n_time
+ else:
+ raise ValueError(f"Invalid cond shape: {cond.shape}")
+
+ q = torch.cat([q, cond], dim=-1)
+ elif self.cond_dim > 0:
+ cond = torch.ones(
+ (batch, n_bands, n_time, self.cond_dim),
+ device=q.device,
+ dtype=q.dtype,
+ )
+ q = torch.cat([q, cond], dim=-1)
+ else:
+ pass
+
+ mask_list = self.compute_masks(
+ q
+ ) # [n_bands * (batch, in_channel, bandwidth, n_time)]
+
+ masks = torch.zeros(
+ (batch, self.in_channel, self.n_freq, n_time),
+ device=q.device,
+ dtype=mask_list[0].dtype,
+ )
+
+ for im, mask in enumerate(mask_list):
+ fstart, fend = self.band_specs[im]
+ if self.use_freq_weights:
+ fw = self.get_buffer(f"freq_weights/{im}")[:, None]
+ mask = mask * fw
+ masks[:, :, fstart:fend, :] += mask
+
+ return masks
+
+
+class MaskEstimationModule(OverlappingMaskEstimationModule):
+ def __init__(
+ self,
+ band_specs: List[Tuple[float, float]],
+ emb_dim: int,
+ mlp_dim: int,
+ in_channel: Optional[int],
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Dict = None,
+ complex_mask: bool = True,
+ **kwargs,
+ ) -> None:
+ check_nonzero_bandwidth(band_specs)
+ check_no_gap(band_specs)
+ check_no_overlap(band_specs)
+ super().__init__(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ freq_weights=None,
+ n_freq=None,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
+
+ def forward(self, q, cond=None):
+ # q = (batch, n_bands, n_time, emb_dim)
+
+ masks = self.compute_masks(
+ q
+ ) # [n_bands * (batch, in_channel, bandwidth, n_time)]
+
+ # TODO: currently this requires band specs to have no gap and no overlap
+ masks = torch.concat(masks, dim=2) # (batch, in_channel, n_freq, n_time)
+
+ return masks
diff --git a/mvsepless/models/bandit/core/model/bsrnn/tfmodel.py b/mvsepless/models/bandit/core/model/bsrnn/tfmodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..f482a118f5a9ac7f9f2d5bc36725155b3d8049db
--- /dev/null
+++ b/mvsepless/models/bandit/core/model/bsrnn/tfmodel.py
@@ -0,0 +1,320 @@
+import warnings
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.modules import rnn
+
+import torch.backends.cuda
+
+
+class TimeFrequencyModellingModule(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+
+
+class ResidualRNN(nn.Module):
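+ # Residual block: norm -> RNN over the "across" axis -> linear back to
+ # emb_dim, added to the input. The batch trick folds the "uncrossed" axis
+ # into the batch so a single RNN call covers every row.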
+ def __init__(
+ self,
+ emb_dim: int,
+ rnn_dim: int,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ use_batch_trick: bool = True,
+ use_layer_norm: bool = True,
+ ) -> None:
+ # n_group is the size of the 2nd dim
+ super().__init__()
+
+ self.use_layer_norm = use_layer_norm
+ if use_layer_norm:
+ self.norm = nn.LayerNorm(emb_dim)
+ else:
+ self.norm = nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim)
+
+ self.rnn = rnn.__dict__[rnn_type](
+ input_size=emb_dim,
+ hidden_size=rnn_dim,
+ num_layers=1,
+ batch_first=True,
+ bidirectional=bidirectional,
+ )
+
+ self.fc = nn.Linear(
+ in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim
+ )
+
+ self.use_batch_trick = use_batch_trick
+ if not self.use_batch_trick:
+ warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")
+
+ def forward(self, z):
+ # z = (batch, n_uncrossed, n_across, emb_dim)
+
+ z0 = torch.clone(z)
+
+ # print(z.device)
+
+ if self.use_layer_norm:
+ z = self.norm(z) # (batch, n_uncrossed, n_across, emb_dim)
+ else:
+ z = torch.permute(
+ z, (0, 3, 1, 2)
+ ) # (batch, emb_dim, n_uncrossed, n_across)
+
+ z = self.norm(z) # (batch, emb_dim, n_uncrossed, n_across)
+
+ z = torch.permute(
+ z, (0, 2, 3, 1)
+ ) # (batch, n_uncrossed, n_across, emb_dim)
+
+ batch, n_uncrossed, n_across, emb_dim = z.shape
+
+ if self.use_batch_trick:
+ z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
+
+ z = self.rnn(z.contiguous())[
+ 0
+ ] # (batch * n_uncrossed, n_across, dir_rnn_dim)
+
+ z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
+ # (batch, n_uncrossed, n_across, dir_rnn_dim)
+ else:
+ # Note: this is EXTREMELY SLOW
+ zlist = []
+ for i in range(n_uncrossed):
+ zi = self.rnn(z[:, i, :, :])[0] # (batch, n_across, emb_dim)
+ zlist.append(zi)
+
+ z = torch.stack(zlist, dim=1) # (batch, n_uncrossed, n_across, dir_rnn_dim)
+
+ z = self.fc(z) # (batch, n_uncrossed, n_across, emb_dim)
+
+ z = z + z0
+
+ return z
+
+
+class SeqBandModellingModule(TimeFrequencyModellingModule):
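+ # Stacks 2 * n_modules ResidualRNNs and transposes the band/time axes after
+ # each one, so the model alternates RNNs across time with RNNs across bands.
+ # In parallel_mode each pair processes both orientations of the same input
+ # and sums the results instead.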
+ def __init__(
+ self,
+ n_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ parallel_mode=False,
+ ) -> None:
+ super().__init__()
+ self.seqband = nn.ModuleList([])
+
+ if parallel_mode:
+ for _ in range(n_modules):
+ self.seqband.append(
+ nn.ModuleList(
+ [
+ ResidualRNN(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ ),
+ ResidualRNN(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ ),
+ ]
+ )
+ )
+ else:
+
+ for _ in range(2 * n_modules):
+ self.seqband.append(
+ ResidualRNN(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ )
+ )
+
+ self.parallel_mode = parallel_mode
+
+ def forward(self, z):
+ # z = (batch, n_bands, n_time, emb_dim)
+
+ if self.parallel_mode:
+ for sbm_pair in self.seqband:
+ # z: (batch, n_bands, n_time, emb_dim)
+ sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
+ zt = sbm_t(z) # (batch, n_bands, n_time, emb_dim)
+ zf = sbm_f(z.transpose(1, 2)) # (batch, n_time, n_bands, emb_dim)
+ z = zt + zf.transpose(1, 2)
+ else:
+ for sbm in self.seqband:
+ z = sbm(z)
+ z = z.transpose(1, 2)
+
+ # (batch, n_bands, n_time, emb_dim)
+ # --> (batch, n_time, n_bands, emb_dim)
+ # OR
+ # (batch, n_time, n_bands, emb_dim)
+ # --> (batch, n_bands, n_time, emb_dim)
+
+ q = z
+ return q # (batch, n_bands, n_time, emb_dim)
+
+
+class ResidualTransformer(nn.Module):
+ def __init__(
+ self,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ dropout: float = 0.0,
+ ) -> None:
+ # n_group is the size of the 2nd dim
+ super().__init__()
+
+ self.tf = nn.TransformerEncoderLayer(
+ d_model=emb_dim, nhead=4, dim_feedforward=rnn_dim, batch_first=True
+ )
+
+ self.is_causal = not bidirectional
+ self.dropout = dropout
+
+ def forward(self, z):
+ batch, n_uncrossed, n_across, emb_dim = z.shape
+ z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
+ z = self.tf(
+ z, is_causal=self.is_causal
+ ) # (batch * n_uncrossed, n_across, emb_dim)
+ z = torch.reshape(z, (batch, n_uncrossed, n_across, emb_dim))
+
+ return z
+
+
+class TransformerTimeFreqModule(TimeFrequencyModellingModule):
+ def __init__(
+ self,
+ n_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ dropout: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.norm = nn.LayerNorm(emb_dim)
+ self.seqband = nn.ModuleList([])
+
+ for _ in range(2 * n_modules):
+ self.seqband.append(
+ ResidualTransformer(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ dropout=dropout,
+ )
+ )
+
+ def forward(self, z):
+ # z = (batch, n_bands, n_time, emb_dim)
+ z = self.norm(z) # (batch, n_bands, n_time, emb_dim)
+
+ for sbm in self.seqband:
+ z = sbm(z)
+ z = z.transpose(1, 2)
+
+ # (batch, n_bands, n_time, emb_dim)
+ # --> (batch, n_time, n_bands, emb_dim)
+ # OR
+ # (batch, n_time, n_bands, emb_dim)
+ # --> (batch, n_bands, n_time, emb_dim)
+
+ q = z
+ return q # (batch, n_bands, n_time, emb_dim)
+
+
+class ResidualConvolution(nn.Module):
+ def __init__(
+ self,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ dropout: float = 0.0,
+ ) -> None:
+ # n_group is the size of the 2nd dim
+ super().__init__()
+ self.norm = nn.InstanceNorm2d(emb_dim, affine=True)
+
+ self.conv = nn.Sequential(
+ nn.Conv2d(
+ in_channels=emb_dim,
+ out_channels=rnn_dim,
+ kernel_size=(3, 3),
+ padding="same",
+ stride=(1, 1),
+ ),
+ nn.Tanhshrink(),
+ )
+
+ self.is_causal = not bidirectional
+ self.dropout = dropout
+
+ self.fc = nn.Conv2d(
+ in_channels=rnn_dim,
+ out_channels=emb_dim,
+ kernel_size=(1, 1),
+ padding="same",
+ stride=(1, 1),
+ )
+
+ def forward(self, z):
+ # z = (batch, emb_dim, n_bands, n_time) -- channels-first for the 2D convs
+
+ z0 = torch.clone(z)
+
+ z = self.norm(z) # (batch, emb_dim, n_bands, n_time)
+ z = self.conv(z) # (batch, rnn_dim, n_bands, n_time)
+ z = self.fc(z) # (batch, emb_dim, n_bands, n_time)
+ z = z + z0
+
+ return z
+
+
+class ConvolutionalTimeFreqModule(TimeFrequencyModellingModule):
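+ # Same role as the RNN/Transformer variants, but models the (band, time)
+ # plane with residual 3x3 convolution blocks in channels-first layout.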
+ def __init__(
+ self,
+ n_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ dropout: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.seqband = torch.jit.script(
+ nn.Sequential(
+ *[
+ ResidualConvolution(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ dropout=dropout,
+ )
+ for _ in range(2 * n_modules)
+ ]
+ )
+ )
+
+ def forward(self, z):
+ # z = (batch, n_bands, n_time, emb_dim)
+
+ z = torch.permute(z, (0, 3, 1, 2)) # (batch, emb_dim, n_bands, n_time)
+
+ z = self.seqband(z) # (batch, emb_dim, n_bands, n_time)
+
+ z = torch.permute(z, (0, 2, 3, 1)) # (batch, n_bands, n_time, emb_dim)
+
+ return z
diff --git a/mvsepless/models/bandit/core/model/bsrnn/utils.py b/mvsepless/models/bandit/core/model/bsrnn/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f470b180303c48e2e5a02b1319bff743c0c83db
--- /dev/null
+++ b/mvsepless/models/bandit/core/model/bsrnn/utils.py
@@ -0,0 +1,525 @@
+import os
+from abc import abstractmethod
+from typing import Any, Callable
+
+import numpy as np
+import torch
+from librosa import hz_to_midi, midi_to_hz
+from torch import Tensor
+from torchaudio import functional as taF
+from spafe.fbanks import bark_fbanks
+from spafe.utils.converters import erb2hz, hz2bark, hz2erb
+from torchaudio.functional.functional import _create_triangular_filterbank
+
+
+def band_widths_from_specs(band_specs):
+ return [e - i for i, e in band_specs]
+
+
+def check_nonzero_bandwidth(band_specs):
+ # pprint(band_specs)
+ for fstart, fend in band_specs:
+ if fend - fstart <= 0:
+ raise ValueError("Bands cannot be zero-width")
+
+
+def check_no_overlap(band_specs):
+ # band ends are exclusive, so adjacent bands (fstart_curr == fend_prev) are allowed
+ fend_prev = -1
+ for fstart_curr, fend_curr in band_specs:
+ if fstart_curr < fend_prev:
+ raise ValueError("Bands cannot overlap")
+ fend_prev = fend_curr
+
+
+def check_no_gap(band_specs):
+ fstart, _ = band_specs[0]
+ assert fstart == 0
+
+ fend_prev = -1
+ for fstart_curr, fend_curr in band_specs:
+ if fstart_curr - fend_prev > 1:
+ raise ValueError("Bands cannot leave gap")
+ fend_prev = fend_curr
+
+
+class BandsplitSpecification:
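+ # Converts frequency boundaries in Hz to STFT bin indices for the given
+ # nfft/fs and builds lists of [fstart, fend) bin ranges with a requested
+ # bandwidth per band.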
+ def __init__(self, nfft: int, fs: int) -> None:
+ self.fs = fs
+ self.nfft = nfft
+ self.nyquist = fs / 2
+ self.max_index = nfft // 2 + 1
+
+ self.split500 = self.hertz_to_index(500)
+ self.split1k = self.hertz_to_index(1000)
+ self.split2k = self.hertz_to_index(2000)
+ self.split4k = self.hertz_to_index(4000)
+ self.split8k = self.hertz_to_index(8000)
+ self.split16k = self.hertz_to_index(16000)
+ self.split20k = self.hertz_to_index(20000)
+
+ self.above20k = [(self.split20k, self.max_index)]
+ self.above16k = [(self.split16k, self.split20k)] + self.above20k
+
+ def index_to_hertz(self, index: int):
+ return index * self.fs / self.nfft
+
+ def hertz_to_index(self, hz: float, round: bool = True):
+ index = hz * self.nfft / self.fs
+
+ if round:
+ index = int(np.round(index))
+
+ return index
+
+ def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
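+ # For example, with nfft=2048 and fs=44100 a 100 Hz bandwidth is roughly
+ # 100 * 2048 / 44100 ≈ 4.6, i.e. about 5 bins per band (approximate).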
+ band_specs = []
+ lower = start_index
+
+ while lower < end_index:
+ upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
+ upper = min(upper, end_index)
+
+ band_specs.append((lower, upper))
+ lower = upper
+
+ return band_specs
+
+ @abstractmethod
+ def get_band_specs(self):
+ raise NotImplementedError
+
+
+class VocalBandsplitSpecification(BandsplitSpecification):
+ def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
+ super().__init__(nfft=nfft, fs=fs)
+
+ self.version = version
+
+ def get_band_specs(self):
+ return getattr(self, f"version{self.version}")()
+
+ def version1(self):
+ return self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.max_index, bandwidth_hz=1000
+ )
+
+ def version2(self):
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split16k, bandwidth_hz=1000
+ )
+ below20k = self.get_band_specs_with_bandwidth(
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+ )
+
+ return below16k + below20k + self.above20k
+
+ def version3(self):
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split8k, bandwidth_hz=1000
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+ )
+
+ return below8k + below16k + self.above16k
+
+ def version4(self):
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
+ )
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+ )
+
+ return below1k + below8k + below16k + self.above16k
+
+ def version5(self):
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
+ )
+ below20k = self.get_band_specs_with_bandwidth(
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+ )
+ return below1k + below16k + below20k + self.above20k
+
+ def version6(self):
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
+ )
+ below4k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
+ )
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+ )
+ return below1k + below4k + below8k + below16k + self.above16k
+
+ def version7(self):
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
+ )
+ below4k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
+ )
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
+ )
+ below20k = self.get_band_specs_with_bandwidth(
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+ )
+ return below1k + below4k + below8k + below16k + below20k + self.above20k
+
+
+class OtherBandsplitSpecification(VocalBandsplitSpecification):
+ def __init__(self, nfft: int, fs: int) -> None:
+ super().__init__(nfft=nfft, fs=fs, version="7")
+
+
+class BassBandsplitSpecification(BandsplitSpecification):
+ def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
+ super().__init__(nfft=nfft, fs=fs)
+
+ def get_band_specs(self):
+ below500 = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split500, bandwidth_hz=50
+ )
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=self.split500, end_index=self.split1k, bandwidth_hz=100
+ )
+ below4k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
+ )
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+ )
+ above16k = [(self.split16k, self.max_index)]
+
+ return below500 + below1k + below4k + below8k + below16k + above16k
+
+
+class DrumBandsplitSpecification(BandsplitSpecification):
+ def __init__(self, nfft: int, fs: int) -> None:
+ super().__init__(nfft=nfft, fs=fs)
+
+ def get_band_specs(self):
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split1k, bandwidth_hz=50
+ )
+ below2k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100
+ )
+ below4k = self.get_band_specs_with_bandwidth(
+ start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250
+ )
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
+ )
+ above16k = [(self.split16k, self.max_index)]
+
+ return below1k + below2k + below4k + below8k + below16k + above16k
+
+
+class PerceptualBandsplitSpecification(BandsplitSpecification):
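+ # Derives band specs from a perceptual filterbank (mel, bark, ERB, ...): each
+ # band spans the filter's non-zero bins and the per-bin weights come from the
+ # column-normalised filterbank, so neighbouring bands overlap.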
+ def __init__(
+ self,
+ nfft: int,
+ fs: int,
+ fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
+ n_bands: int,
+ f_min: float = 0.0,
+ f_max: float = None,
+ ) -> None:
+ super().__init__(nfft=nfft, fs=fs)
+ self.n_bands = n_bands
+ if f_max is None:
+ f_max = fs / 2
+
+ self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)
+
+ weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True) # (1, n_freqs)
+ normalized_mel_fb = self.filterbank / weight_per_bin # (n_mels, n_freqs)
+
+ freq_weights = []
+ band_specs = []
+ for i in range(self.n_bands):
+ active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
+ if isinstance(active_bins, int):
+ active_bins = (active_bins, active_bins)
+ if len(active_bins) == 0:
+ continue
+ start_index = active_bins[0]
+ end_index = active_bins[-1] + 1
+ band_specs.append((start_index, end_index))
+ freq_weights.append(normalized_mel_fb[i, start_index:end_index])
+
+ self.freq_weights = freq_weights
+ self.band_specs = band_specs
+
+ def get_band_specs(self):
+ return self.band_specs
+
+ def get_freq_weights(self):
+ return self.freq_weights
+
+ def save_to_file(self, dir_path: str) -> None:
+
+ os.makedirs(dir_path, exist_ok=True)
+
+ import pickle
+
+ with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
+ pickle.dump(
+ {
+ "band_specs": self.band_specs,
+ "freq_weights": self.freq_weights,
+ "filterbank": self.filterbank,
+ },
+ f,
+ )
+
+
+def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
+ fb = taF.melscale_fbanks(
+ n_mels=n_bands,
+ sample_rate=fs,
+ f_min=f_min,
+ f_max=f_max,
+ n_freqs=n_freqs,
+ ).T
+
+ fb[0, 0] = 1.0
+
+ return fb
+
+
+class MelBandsplitSpecification(PerceptualBandsplitSpecification):
+ def __init__(
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+ ) -> None:
+ super().__init__(
+ fbank_fn=mel_filterbank,
+ nfft=nfft,
+ fs=fs,
+ n_bands=n_bands,
+ f_min=f_min,
+ f_max=f_max,
+ )
+
+
+def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
+
+ nfft = 2 * (n_freqs - 1)
+ df = fs / nfft
+ # init freqs
+ f_max = f_max or fs / 2
+ f_min = f_min or 0
+ f_min = fs / nfft
+
+ n_octaves = np.log2(f_max / f_min)
+ n_octaves_per_band = n_octaves / n_bands
+ bandwidth_mult = np.power(2.0, n_octaves_per_band)
+
+ low_midi = max(0, hz_to_midi(f_min))
+ high_midi = hz_to_midi(f_max)
+ midi_points = np.linspace(low_midi, high_midi, n_bands)
+ hz_pts = midi_to_hz(midi_points)
+
+ low_pts = hz_pts / bandwidth_mult
+ high_pts = hz_pts * bandwidth_mult
+
+ low_bins = np.floor(low_pts / df).astype(int)
+ high_bins = np.ceil(high_pts / df).astype(int)
+
+ fb = np.zeros((n_bands, n_freqs))
+
+ for i in range(n_bands):
+ fb[i, low_bins[i] : high_bins[i] + 1] = 1.0
+
+ fb[0, : low_bins[0]] = 1.0
+ fb[-1, high_bins[-1] + 1 :] = 1.0
+
+ return torch.as_tensor(fb)
+
+
+class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
+ def __init__(
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+ ) -> None:
+ super().__init__(
+ fbank_fn=musical_filterbank,
+ nfft=nfft,
+ fs=fs,
+ n_bands=n_bands,
+ f_min=f_min,
+ f_max=f_max,
+ )
+
+
+def bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
+ nfft = 2 * (n_freqs - 1)
+ fb, _ = bark_fbanks.bark_filter_banks(
+ nfilts=n_bands,
+ nfft=nfft,
+ fs=fs,
+ low_freq=f_min,
+ high_freq=f_max,
+ scale="constant",
+ )
+
+ return torch.as_tensor(fb)
+
+
+class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
+ def __init__(
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+ ) -> None:
+ super().__init__(
+ fbank_fn=bark_filterbank,
+ nfft=nfft,
+ fs=fs,
+ n_bands=n_bands,
+ f_min=f_min,
+ f_max=f_max,
+ )
+
+
+def triangular_bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
+
+ all_freqs = torch.linspace(0, fs // 2, n_freqs)
+
+ # calculate mel freq bins
+ m_min = hz2bark(f_min)
+ m_max = hz2bark(f_max)
+
+ m_pts = torch.linspace(m_min, m_max, n_bands + 2)
+ f_pts = 600 * torch.sinh(m_pts / 6)
+
+ # create filterbank
+ fb = _create_triangular_filterbank(all_freqs, f_pts)
+
+ fb = fb.T
+
+ first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
+ first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
+
+ fb[first_active_band, :first_active_bin] = 1.0
+
+ return fb
+
+
+class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
+ def __init__(
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+ ) -> None:
+ super().__init__(
+ fbank_fn=triangular_bark_filterbank,
+ nfft=nfft,
+ fs=fs,
+ n_bands=n_bands,
+ f_min=f_min,
+ f_max=f_max,
+ )
+
+
+def minibark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
+ fb = bark_filterbank(n_bands, fs, f_min, f_max, n_freqs)
+
+ fb[fb < np.sqrt(0.5)] = 0.0
+
+ return fb
+
+
+class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
+ def __init__(
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+ ) -> None:
+ super().__init__(
+ fbank_fn=minibark_filterbank,
+ nfft=nfft,
+ fs=fs,
+ n_bands=n_bands,
+ f_min=f_min,
+ f_max=f_max,
+ )
+
+
+def erb_filterbank(
+ n_bands: int,
+ fs: int,
+ f_min: float,
+ f_max: float,
+ n_freqs: int,
+) -> Tensor:
+ # freq bins
+ A = (1000 * np.log(10)) / (24.7 * 4.37)
+ all_freqs = torch.linspace(0, fs // 2, n_freqs)
+
+ # calculate mel freq bins
+ m_min = hz2erb(f_min)
+ m_max = hz2erb(f_max)
+
+ m_pts = torch.linspace(m_min, m_max, n_bands + 2)
+ f_pts = (torch.pow(10, (m_pts / A)) - 1) / 0.00437
+
+ # create filterbank
+ fb = _create_triangular_filterbank(all_freqs, f_pts)
+
+ fb = fb.T
+
+ first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
+ first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
+
+ fb[first_active_band, :first_active_bin] = 1.0
+
+ return fb
+
+
+class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
+ def __init__(
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+ ) -> None:
+ super().__init__(
+ fbank_fn=erb_filterbank,
+ nfft=nfft,
+ fs=fs,
+ n_bands=n_bands,
+ f_min=f_min,
+ f_max=f_max,
+ )
+
+
+if __name__ == "__main__":
+ import pandas as pd
+
+ band_defs = []
+
+ for bands in [VocalBandsplitSpecification]:
+ band_name = bands.__name__.replace("BandsplitSpecification", "")
+
+ mbs = bands(nfft=2048, fs=44100).get_band_specs()
+
+ for i, (f_min, f_max) in enumerate(mbs):
+ band_defs.append(
+ {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
+ )
+
+ df = pd.DataFrame(band_defs)
+ df.to_csv("vox7bands.csv", index=False)
diff --git a/mvsepless/models/bandit/core/model/bsrnn/wrapper.py b/mvsepless/models/bandit/core/model/bsrnn/wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..b077b03b1712f186bdb51764537a13a3731c0aa1
--- /dev/null
+++ b/mvsepless/models/bandit/core/model/bsrnn/wrapper.py
@@ -0,0 +1,829 @@
+from pprint import pprint
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from .._spectral import _SpectralComponent
+from .utils import (
+ BarkBandsplitSpecification,
+ BassBandsplitSpecification,
+ DrumBandsplitSpecification,
+ EquivalentRectangularBandsplitSpecification,
+ MelBandsplitSpecification,
+ MusicalBandsplitSpecification,
+ OtherBandsplitSpecification,
+ TriangularBarkBandsplitSpecification,
+ VocalBandsplitSpecification,
+)
+from .core import (
+ MultiSourceMultiMaskBandSplitCoreConv,
+ MultiSourceMultiMaskBandSplitCoreRNN,
+ MultiSourceMultiMaskBandSplitCoreTransformer,
+ MultiSourceMultiPatchingMaskBandSplitCoreRNN,
+ SingleMaskBandsplitCoreRNN,
+ SingleMaskBandsplitCoreTransformer,
+)
+
+import pytorch_lightning as pl
+
+
+def get_band_specs(band_specs, n_fft, fs, n_bands=None):
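+ # Maps a band-spec name to (list of [fstart, fend) bin ranges, optional
+ # per-band frequency weights, whether the bands overlap).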
+ if band_specs in ["dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"]:
+ bsm = VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs()
+ freq_weights = None
+ overlapping_band = False
+ elif "tribark" in band_specs:
+ assert n_bands is not None
+ specs = TriangularBarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
+ bsm = specs.get_band_specs()
+ freq_weights = specs.get_freq_weights()
+ overlapping_band = True
+ elif "bark" in band_specs:
+ assert n_bands is not None
+ specs = BarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
+ bsm = specs.get_band_specs()
+ freq_weights = specs.get_freq_weights()
+ overlapping_band = True
+ elif "erb" in band_specs:
+ assert n_bands is not None
+ specs = EquivalentRectangularBandsplitSpecification(
+ nfft=n_fft, fs=fs, n_bands=n_bands
+ )
+ bsm = specs.get_band_specs()
+ freq_weights = specs.get_freq_weights()
+ overlapping_band = True
+ elif "musical" in band_specs:
+ assert n_bands is not None
+ specs = MusicalBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
+ bsm = specs.get_band_specs()
+ freq_weights = specs.get_freq_weights()
+ overlapping_band = True
+ elif band_specs == "dnr:mel" or "mel" in band_specs:
+ assert n_bands is not None
+ specs = MelBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
+ bsm = specs.get_band_specs()
+ freq_weights = specs.get_freq_weights()
+ overlapping_band = True
+ else:
+ raise NameError
+
+ return bsm, freq_weights, overlapping_band
+
+
+def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None):
+ if band_specs_map == "musdb:all":
+ bsm = {
+ "vocals": VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
+ "drums": DrumBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
+ "bass": BassBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
+ "other": OtherBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
+ }
+ freq_weights = None
+ overlapping_band = False
+ elif band_specs_map == "dnr:vox7":
+ bsm_, freq_weights, overlapping_band = get_band_specs(
+ "dnr:speech", n_fft, fs, n_bands
+ )
+ bsm = {"speech": bsm_, "music": bsm_, "effects": bsm_}
+ elif "dnr:vox7:" in band_specs_map:
+ stem = band_specs_map.split(":")[-1]
+ bsm_, freq_weights, overlapping_band = get_band_specs(
+ "dnr:speech", n_fft, fs, n_bands
+ )
+ bsm = {stem: bsm_}
+ else:
+ raise NameError
+
+ return bsm, freq_weights, overlapping_band
+
+
+class BandSplitWrapperBase(pl.LightningModule):
+ bsrnn: nn.Module
+
+ def __init__(self, **kwargs):
+ super().__init__()
+
+
+class SingleMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
+ def __init__(
+ self,
+ band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
+ fs: int = 44100,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ n_bands: int = None,
+ ) -> None:
+ super().__init__(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ )
+
+ if isinstance(band_specs_map, str):
+ self.band_specs_map, self.freq_weights, self.overlapping_band = (
+ get_band_specs_map(band_specs_map, n_fft, fs, n_bands=n_bands)
+ )
+ else:
+ # an explicit per-stem dict of band specs was passed in
+ self.band_specs_map = band_specs_map
+ self.freq_weights = None
+ self.overlapping_band = False
+
+ self.stems = list(self.band_specs_map.keys())
+
+ def forward(self, batch):
+ audio = batch["audio"]
+
+ with torch.no_grad():
+ batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}
+
+ X = batch["spectrogram"]["mixture"]
+ length = batch["audio"]["mixture"].shape[-1]
+
+ output = {"spectrogram": {}, "audio": {}}
+
+ for stem, bsrnn in self.bsrnn.items():
+ S = bsrnn(X)
+ s = self.istft(S, length)
+ output["spectrogram"][stem] = S
+ output["audio"][stem] = s
+
+ return batch, output
+
+
+class MultiMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
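+ # Wrapper that feeds the mixture spectrogram through a single multi-stem core
+ # and returns, per stem, both the masked spectrogram and its iSTFT audio.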
+ def __init__(
+ self,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ fs: int = 44100,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ n_bands: int = None,
+ ) -> None:
+ super().__init__(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ )
+
+ if isinstance(band_specs, str):
+ self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
+ band_specs, n_fft, fs, n_bands
+ )
+ else:
+ # an explicit list of (fstart, fend) bin ranges was passed in
+ self.band_specs = band_specs
+ self.freq_weights = None
+ self.overlapping_band = False
+
+ self.stems = stems
+
+ def forward(self, batch):
+ # with torch.no_grad():
+ audio = batch["audio"]
+ cond = batch.get("condition", None)
+ with torch.no_grad():
+ batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}
+
+ X = batch["spectrogram"]["mixture"]
+ length = batch["audio"]["mixture"].shape[-1]
+
+ output = self.bsrnn(X, cond=cond)
+ output["audio"] = {}
+
+ for stem, S in output["spectrogram"].items():
+ s = self.istft(S, length)
+ output["audio"][stem] = s
+
+ return batch, output
+
+
+class MultiMaskMultiSourceBandSplitBaseSimple(BandSplitWrapperBase, _SpectralComponent):
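+ # Inference-only variant: takes a raw waveform tensor, runs STFT -> core ->
+ # iSTFT and returns the stems stacked along dim=1.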
+ def __init__(
+ self,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ fs: int = 44100,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ n_bands: int = None,
+ ) -> None:
+ super().__init__(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ )
+
+ if isinstance(band_specs, str):
+ self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
+ band_specs, n_fft, fs, n_bands
+ )
+ else:
+ # an explicit list of (fstart, fend) bin ranges was passed in
+ self.band_specs = band_specs
+ self.freq_weights = None
+ self.overlapping_band = False
+
+ self.stems = stems
+
+ def forward(self, batch):
+ with torch.no_grad():
+ X = self.stft(batch)
+ length = batch.shape[-1]
+ output = self.bsrnn(X, cond=None)
+ res = []
+ for stem, S in output["spectrogram"].items():
+ s = self.istft(S, length)
+ res.append(s)
+ res = torch.stack(res, dim=1)
+ return res
+
+
+class SingleMaskMultiSourceBandSplitRNN(SingleMaskMultiSourceBandSplitBase):
+ def __init__(
+ self,
+ in_channel: int,
+ band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ ) -> None:
+ super().__init__(
+ band_specs_map=band_specs_map,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ )
+
+ self.bsrnn = nn.ModuleDict(
+ {
+ src: SingleMaskBandsplitCoreRNN(
+ band_specs=specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
+ for src, specs in self.band_specs_map.items()
+ }
+ )
+
+
+class SingleMaskMultiSourceBandSplitTransformer(SingleMaskMultiSourceBandSplitBase):
+ def __init__(
+ self,
+ in_channel: int,
+ band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ tf_dropout: float = 0.0,
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ ) -> None:
+ super().__init__(
+ band_specs_map=band_specs_map,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ )
+
+ self.bsrnn = nn.ModuleDict(
+ {
+ src: SingleMaskBandsplitCoreTransformer(
+ band_specs=specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ tf_dropout=tf_dropout,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
+ for src, specs in self.band_specs_map.items()
+ }
+ )
+
+
+class MultiMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
+ def __init__(
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ cond_dim: int = 0,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+        n_bands: Optional[int] = None,
+ use_freq_weights: bool = True,
+ normalize_input: bool = False,
+ mult_add_mask: bool = False,
+ freeze_encoder: bool = False,
+ ) -> None:
+ super().__init__(
+ stems=stems,
+ band_specs=band_specs,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ n_bands=n_bands,
+ )
+
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
+ stems=stems,
+ band_specs=self.band_specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=self.overlapping_band,
+ freq_weights=self.freq_weights,
+ n_freq=n_fft // 2 + 1,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask,
+ )
+
+ self.normalize_input = normalize_input
+ self.cond_dim = cond_dim
+
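+        # Optionally freeze the encoder (band-split and TF model) so that only the mask
+        # estimators receive gradient updates.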
+ if freeze_encoder:
+ for param in self.bsrnn.band_split.parameters():
+ param.requires_grad = False
+
+ for param in self.bsrnn.tf_model.parameters():
+ param.requires_grad = False
+
+
+class MultiMaskMultiSourceBandSplitRNNSimple(MultiMaskMultiSourceBandSplitBaseSimple):
+ def __init__(
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ cond_dim: int = 0,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+        n_bands: Optional[int] = None,
+ use_freq_weights: bool = True,
+ normalize_input: bool = False,
+ mult_add_mask: bool = False,
+ freeze_encoder: bool = False,
+ ) -> None:
+ super().__init__(
+ stems=stems,
+ band_specs=band_specs,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ n_bands=n_bands,
+ )
+
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
+ stems=stems,
+ band_specs=self.band_specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=self.overlapping_band,
+ freq_weights=self.freq_weights,
+ n_freq=n_fft // 2 + 1,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask,
+ )
+
+ self.normalize_input = normalize_input
+ self.cond_dim = cond_dim
+
+ if freeze_encoder:
+ for param in self.bsrnn.band_split.parameters():
+ param.requires_grad = False
+
+ for param in self.bsrnn.tf_model.parameters():
+ param.requires_grad = False
+
+
+class MultiMaskMultiSourceBandSplitTransformer(MultiMaskMultiSourceBandSplitBase):
+ def __init__(
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ cond_dim: int = 0,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+        n_bands: Optional[int] = None,
+ use_freq_weights: bool = True,
+ normalize_input: bool = False,
+ mult_add_mask: bool = False,
+ ) -> None:
+ super().__init__(
+ stems=stems,
+ band_specs=band_specs,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ n_bands=n_bands,
+ )
+
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreTransformer(
+ stems=stems,
+ band_specs=self.band_specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=self.overlapping_band,
+ freq_weights=self.freq_weights,
+ n_freq=n_fft // 2 + 1,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask,
+ )
+
+
+class MultiMaskMultiSourceBandSplitConv(MultiMaskMultiSourceBandSplitBase):
+ def __init__(
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ cond_dim: int = 0,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+        n_bands: Optional[int] = None,
+ use_freq_weights: bool = True,
+ normalize_input: bool = False,
+ mult_add_mask: bool = False,
+ ) -> None:
+ super().__init__(
+ stems=stems,
+ band_specs=band_specs,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ n_bands=n_bands,
+ )
+
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreConv(
+ stems=stems,
+ band_specs=self.band_specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=self.overlapping_band,
+ freq_weights=self.freq_weights,
+ n_freq=n_fft // 2 + 1,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask,
+ )
+
+
+class PatchingMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
+ def __init__(
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ kernel_norm_mlp_version: int = 1,
+ mask_kernel_freq: int = 3,
+ mask_kernel_time: int = 3,
+ conv_kernel_freq: int = 1,
+ conv_kernel_time: int = 1,
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+        n_bands: Optional[int] = None,
+ ) -> None:
+ super().__init__(
+ stems=stems,
+ band_specs=band_specs,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ n_bands=n_bands,
+ )
+
+ self.bsrnn = MultiSourceMultiPatchingMaskBandSplitCoreRNN(
+ stems=stems,
+ band_specs=self.band_specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=self.overlapping_band,
+ freq_weights=self.freq_weights,
+ n_freq=n_fft // 2 + 1,
+ mask_kernel_freq=mask_kernel_freq,
+ mask_kernel_time=mask_kernel_time,
+ conv_kernel_freq=conv_kernel_freq,
+ conv_kernel_time=conv_kernel_time,
+ kernel_norm_mlp_version=kernel_norm_mlp_version,
+ )
diff --git a/mvsepless/models/bandit/core/utils/__init__.py b/mvsepless/models/bandit/core/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mvsepless/models/bandit/core/utils/audio.py b/mvsepless/models/bandit/core/utils/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..20fbfa6d1f20e763f380e5cc3c341384b4e3af16
--- /dev/null
+++ b/mvsepless/models/bandit/core/utils/audio.py
@@ -0,0 +1,412 @@
+from collections import defaultdict
+
+from tqdm.auto import tqdm
+from typing import Callable, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+@torch.jit.script
+def merge(
+ combined: torch.Tensor,
+ original_batch_size: int,
+ n_channel: int,
+ n_chunks: int,
+ chunk_size: int,
+):
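+    # Regroup the model output from (batch * n_chunks, n_channel, chunk_size) to
+    # (batch * n_channel, chunk_size, n_chunks) so that F.fold can overlap-add the chunks.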
+ combined = torch.reshape(
+ combined, (original_batch_size, n_chunks, n_channel, chunk_size)
+ )
+ combined = torch.permute(combined, (0, 2, 3, 1)).reshape(
+ original_batch_size * n_channel, chunk_size, n_chunks
+ )
+
+ return combined
+
+
+@torch.jit.script
+def unfold(
+ padded_audio: torch.Tensor,
+ original_batch_size: int,
+ n_channel: int,
+ chunk_size: int,
+ hop_size: int,
+) -> torch.Tensor:
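+    # Slice the padded audio into overlapping, hop-spaced chunks with F.unfold and fold
+    # the chunk axis into the batch dimension so each chunk can be processed independently.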
+
+ unfolded_input = F.unfold(
+ padded_audio[:, :, None, :], kernel_size=(1, chunk_size), stride=(1, hop_size)
+ )
+
+ _, _, n_chunks = unfolded_input.shape
+ unfolded_input = unfolded_input.view(
+ original_batch_size, n_channel, chunk_size, n_chunks
+ )
+ unfolded_input = torch.permute(unfolded_input, (0, 3, 1, 2)).reshape(
+ original_batch_size * n_chunks, n_channel, chunk_size
+ )
+
+ return unfolded_input
+
+
+@torch.jit.script
+# @torch.compile
+def merge_chunks_all(
+ combined: torch.Tensor,
+ original_batch_size: int,
+ n_channel: int,
+ n_samples: int,
+ n_padded_samples: int,
+ n_chunks: int,
+ chunk_size: int,
+ hop_size: int,
+ edge_frame_pad_sizes: Tuple[int, int],
+ standard_window: torch.Tensor,
+ first_window: torch.Tensor,
+ last_window: torch.Tensor,
+):
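+    # Apply the analysis/synthesis window to every chunk, overlap-add them with F.fold,
+    # then strip the reflection padding and the zero-padded tail.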
+ combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size)
+
+ combined = combined * standard_window[:, None].to(combined.device)
+
+ combined = F.fold(
+ combined.to(torch.float32),
+ output_size=(1, n_padded_samples),
+ kernel_size=(1, chunk_size),
+ stride=(1, hop_size),
+ )
+
+ combined = combined.view(original_batch_size, n_channel, n_padded_samples)
+
+ pad_front, pad_back = edge_frame_pad_sizes
+ combined = combined[..., pad_front:-pad_back]
+
+ combined = combined[..., :n_samples]
+
+ return combined
+
+
+# @torch.jit.script
+def merge_chunks_edge(
+ combined: torch.Tensor,
+ original_batch_size: int,
+ n_channel: int,
+ n_samples: int,
+ n_padded_samples: int,
+ n_chunks: int,
+ chunk_size: int,
+ hop_size: int,
+ edge_frame_pad_sizes: Tuple[int, int],
+ standard_window: torch.Tensor,
+ first_window: torch.Tensor,
+ last_window: torch.Tensor,
+):
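+    # Same overlap-add as merge_chunks_all, but the first and last chunks use dedicated
+    # windows so the signal edges are not attenuated (used when fade_edge_frames=False).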
+ combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size)
+
+ combined[..., 0] = combined[..., 0] * first_window
+ combined[..., -1] = combined[..., -1] * last_window
+ combined[..., 1:-1] = combined[..., 1:-1] * standard_window[:, None]
+
+ combined = F.fold(
+ combined,
+ output_size=(1, n_padded_samples),
+ kernel_size=(1, chunk_size),
+ stride=(1, hop_size),
+ )
+
+ combined = combined.view(original_batch_size, n_channel, n_padded_samples)
+
+ combined = combined[..., :n_samples]
+
+ return combined
+
+
+class BaseFader(nn.Module):
+ def __init__(
+ self,
+ chunk_size_second: float,
+ hop_size_second: float,
+ fs: int,
+ fade_edge_frames: bool,
+ batch_size: int,
+ ) -> None:
+ super().__init__()
+
+ self.chunk_size = int(chunk_size_second * fs)
+ self.hop_size = int(hop_size_second * fs)
+ self.overlap_size = self.chunk_size - self.hop_size
+ self.fade_edge_frames = fade_edge_frames
+ self.batch_size = batch_size
+
+ # @torch.jit.script
+ def prepare(self, audio):
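+        # Optionally reflect-pad the edges, then zero-pad the tail so the signal splits
+        # into an integer number of hop-spaced chunks.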
+
+ if self.fade_edge_frames:
+ audio = F.pad(audio, self.edge_frame_pad_sizes, mode="reflect")
+
+ n_samples = audio.shape[-1]
+ n_chunks = int(np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1)
+
+ padded_size = (n_chunks - 1) * self.hop_size + self.chunk_size
+ pad_size = padded_size - n_samples
+
+ padded_audio = F.pad(audio, (0, pad_size))
+
+ return padded_audio, n_chunks
+
+ def forward(
+ self,
+ audio: torch.Tensor,
+ model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
+ ):
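+        # Chunked inference: move the audio to CPU, unfold it into overlapping chunks,
+        # run model_fn batch-by-batch on the original device (skipping all-zero chunks),
+        # then window and overlap-add every returned stem back to the full-length signal.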
+
+ original_dtype = audio.dtype
+ original_device = audio.device
+
+ audio = audio.to("cpu")
+
+ original_batch_size, n_channel, n_samples = audio.shape
+ padded_audio, n_chunks = self.prepare(audio)
+ del audio
+ n_padded_samples = padded_audio.shape[-1]
+
+ if n_channel > 1:
+ padded_audio = padded_audio.view(
+ original_batch_size * n_channel, 1, n_padded_samples
+ )
+
+ unfolded_input = unfold(
+ padded_audio, original_batch_size, n_channel, self.chunk_size, self.hop_size
+ )
+
+ n_total_chunks, n_channel, chunk_size = unfolded_input.shape
+
+ n_batch = np.ceil(n_total_chunks / self.batch_size).astype(int)
+
+ chunks_in = [
+ unfolded_input[b * self.batch_size : (b + 1) * self.batch_size, ...].clone()
+ for b in range(n_batch)
+ ]
+
+ all_chunks_out = defaultdict(
+ lambda: torch.zeros_like(unfolded_input, device="cpu")
+ )
+
+ # for b, cin in enumerate(tqdm(chunks_in)):
+ for b, cin in enumerate(chunks_in):
+ if torch.allclose(cin, torch.tensor(0.0)):
+ del cin
+ continue
+
+ chunks_out = model_fn(cin.to(original_device))
+ del cin
+ for s, c in chunks_out.items():
+ all_chunks_out[s][
+ b * self.batch_size : (b + 1) * self.batch_size, ...
+ ] = c.cpu()
+ del chunks_out
+
+ del unfolded_input
+ del padded_audio
+
+ if self.fade_edge_frames:
+ fn = merge_chunks_all
+ else:
+ fn = merge_chunks_edge
+ outputs = {}
+
+ torch.cuda.empty_cache()
+
+ for s, c in all_chunks_out.items():
+ combined: torch.Tensor = fn(
+ c,
+ original_batch_size,
+ n_channel,
+ n_samples,
+ n_padded_samples,
+ n_chunks,
+ self.chunk_size,
+ self.hop_size,
+ self.edge_frame_pad_sizes,
+ self.standard_window,
+ self.__dict__.get("first_window", self.standard_window),
+ self.__dict__.get("last_window", self.standard_window),
+ )
+
+ outputs[s] = combined.to(dtype=original_dtype, device=original_device)
+
+ return {"audio": outputs}
+
+ #
+ # def old_forward(
+ # self,
+ # audio: torch.Tensor,
+ # model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
+ # ):
+ #
+ # n_samples = audio.shape[-1]
+ # original_batch_size = audio.shape[0]
+ #
+ # padded_audio, n_chunks = self.prepare(audio)
+ #
+ # ndim = padded_audio.ndim
+ # broadcaster = [1 for _ in range(ndim - 1)] + [self.chunk_size]
+ #
+ # outputs = defaultdict(
+ # lambda: torch.zeros_like(
+ # padded_audio, device=audio.device, dtype=torch.float64
+ # )
+ # )
+ #
+ # all_chunks_out = []
+ # len_chunks_in = []
+ #
+ # batch_size_ = int(self.batch_size // original_batch_size)
+ # for b in range(int(np.ceil(n_chunks / batch_size_))):
+ # chunks_in = []
+ # for j in range(batch_size_):
+ # i = b * batch_size_ + j
+ # if i == n_chunks:
+ # break
+ #
+ # start = i * hop_size
+ # end = start + self.chunk_size
+ # chunk_in = padded_audio[..., start:end]
+ # chunks_in.append(chunk_in)
+ #
+ # chunks_in = torch.concat(chunks_in, dim=0)
+ # chunks_out = model_fn(chunks_in)
+ # all_chunks_out.append(chunks_out)
+ # len_chunks_in.append(len(chunks_in))
+ #
+ # for b, (chunks_out, lci) in enumerate(
+ # zip(all_chunks_out, len_chunks_in)
+ # ):
+ # for stem in chunks_out:
+ # for j in range(lci // original_batch_size):
+ # i = b * batch_size_ + j
+ #
+ # if self.fade_edge_frames:
+ # window = self.standard_window
+ # else:
+ # if i == 0:
+ # window = self.first_window
+ # elif i == n_chunks - 1:
+ # window = self.last_window
+ # else:
+ # window = self.standard_window
+ #
+ # start = i * hop_size
+ # end = start + self.chunk_size
+ #
+ # chunk_out = chunks_out[stem][j * original_batch_size: (j + 1) * original_batch_size,
+ # ...]
+ # contrib = window.view(*broadcaster) * chunk_out
+ # outputs[stem][..., start:end] = (
+ # outputs[stem][..., start:end] + contrib
+ # )
+ #
+ # if self.fade_edge_frames:
+ # pad_front, pad_back = self.edge_frame_pad_sizes
+ # outputs = {k: v[..., pad_front:-pad_back] for k, v in
+ # outputs.items()}
+ #
+ # outputs = {k: v[..., :n_samples].to(audio.dtype) for k, v in
+ # outputs.items()}
+ #
+ # return {
+ # "audio": outputs
+ # }
+
+
+class LinearFader(BaseFader):
+ def __init__(
+ self,
+ chunk_size_second: float,
+ hop_size_second: float,
+ fs: int,
+ fade_edge_frames: bool = False,
+ batch_size: int = 1,
+ ) -> None:
+
+ assert hop_size_second >= chunk_size_second / 2
+
+ super().__init__(
+ chunk_size_second=chunk_size_second,
+ hop_size_second=hop_size_second,
+ fs=fs,
+ fade_edge_frames=fade_edge_frames,
+ batch_size=batch_size,
+ )
+
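+        # Build a linear crossfade window: ramp from 0 to 1 over the overlap, hold 1.0 in
+        # the centre, and ramp back down to 0 at the end of the chunk.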
+ in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1]
+ out_fade = torch.linspace(1.0, 0.0, self.overlap_size + 1)[1:]
+ center_ones = torch.ones(self.chunk_size - 2 * self.overlap_size)
+ inout_ones = torch.ones(self.overlap_size)
+
+        # registering the windows as buffers/parameters lets Lightning move them to the right device for us
+ self.register_buffer(
+ "standard_window", torch.concat([in_fade, center_ones, out_fade])
+ )
+
+ self.fade_edge_frames = fade_edge_frames
+        self.edge_frame_pad_sizes = (self.overlap_size, self.overlap_size)
+
+ if not self.fade_edge_frames:
+ self.first_window = nn.Parameter(
+ torch.concat([inout_ones, center_ones, out_fade]), requires_grad=False
+ )
+ self.last_window = nn.Parameter(
+ torch.concat([in_fade, center_ones, inout_ones]), requires_grad=False
+ )
+
+
+class OverlapAddFader(BaseFader):
+ def __init__(
+ self,
+ window_type: str,
+ chunk_size_second: float,
+ hop_size_second: float,
+ fs: int,
+ batch_size: int = 1,
+ ) -> None:
+ assert (chunk_size_second / hop_size_second) % 2 == 0
+ assert int(chunk_size_second * fs) % 2 == 0
+
+ super().__init__(
+ chunk_size_second=chunk_size_second,
+ hop_size_second=hop_size_second,
+ fs=fs,
+ fade_edge_frames=True,
+ batch_size=batch_size,
+ )
+
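+        # Scaling the analysis window by 1/hop_multiplier keeps the overlap-add gain at
+        # (approximately) unit level for Hann-like windows, given the even chunk/hop ratio
+        # asserted above.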
+ self.hop_multiplier = self.chunk_size / (2 * self.hop_size)
+ # print(f"hop multiplier: {self.hop_multiplier}")
+
+ self.edge_frame_pad_sizes = (2 * self.overlap_size, 2 * self.overlap_size)
+
+ self.register_buffer(
+ "standard_window",
+ torch.windows.__dict__[window_type](
+ self.chunk_size,
+ sym=False, # dtype=torch.float64
+ )
+ / self.hop_multiplier,
+ )
+
+
+if __name__ == "__main__":
+ import torchaudio as ta
+
+ fs = 44100
+ ola = OverlapAddFader("hann", 6.0, 1.0, fs, batch_size=16)
+ audio_, _ = ta.load(
+ "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too " "Much/vocals.wav"
+ )
+ audio_ = audio_[None, ...]
+ out = ola(audio_, lambda x: {"stem": x})["audio"]["stem"]
+ print(torch.allclose(out, audio_))
diff --git a/mvsepless/models/bandit/model_from_config.py b/mvsepless/models/bandit/model_from_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..3007853faf7de657be8fceb18d6b5e720ee4a60f
--- /dev/null
+++ b/mvsepless/models/bandit/model_from_config.py
@@ -0,0 +1,26 @@
+import sys
+import os.path
+import torch
+
+import yaml
+from ml_collections import ConfigDict
+
+torch.set_float32_matmul_precision("medium")
+
+
+def get_model(
+ config_path,
+ weights_path,
+ device,
+):
+ from .core.model import MultiMaskMultiSourceBandSplitRNNSimple
+
+    with open(config_path) as f:
+        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
+
+    model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)
+    # Load the checkpoint passed by the caller rather than a hard-coded filename.
+    d = torch.load(weights_path)
+ model.load_state_dict(d)
+ model.to(device)
+ return model, config
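+
+
+# Illustrative call sketch (the paths are placeholders, not files shipped with this package):
+#
+#   model, config = get_model(
+#       "path/to/bandit_config.yaml",
+#       "path/to/checkpoint.chpt",
+#       device="cuda" if torch.cuda.is_available() else "cpu",
+#   )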
diff --git a/mvsepless/models/bandit_v2/bandit.py b/mvsepless/models/bandit_v2/bandit.py
new file mode 100644
index 0000000000000000000000000000000000000000..fba32962f242ce44f64ab7d8ca433290cbe6d7e2
--- /dev/null
+++ b/mvsepless/models/bandit_v2/bandit.py
@@ -0,0 +1,363 @@
+from typing import Dict, List, Optional
+
+import torch
+import torchaudio as ta
+from torch import nn
+import pytorch_lightning as pl
+
+from .bandsplit import BandSplitModule
+from .maskestim import OverlappingMaskEstimationModule
+from .tfmodel import SeqBandModellingModule
+from .utils import MusicalBandsplitSpecification
+
+
+class BaseEndToEndModule(pl.LightningModule):
+ def __init__(
+ self,
+ ) -> None:
+ super().__init__()
+
+
+class BaseBandit(BaseEndToEndModule):
+ def __init__(
+ self,
+ in_channels: int,
+ fs: int,
+ band_type: str = "musical",
+ n_bands: int = 64,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+
+        self.instantiate_spectral(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ normalized=normalized,
+ center=center,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ )
+
+ self.instantiate_bandsplit(
+ in_channels=in_channels,
+ band_type=band_type,
+ n_bands=n_bands,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
+ n_fft=n_fft,
+ fs=fs,
+ )
+
+ self.instantiate_tf_modelling(
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ )
+
+    def instantiate_spectral(
+ self,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ normalized: bool = True,
+ center: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ ):
+ assert power is None
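+        # power=None keeps the spectrogram complex-valued so (complex) masks can act on
+        # magnitude and phase together.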
+
+ window_fn = torch.__dict__[window_fn]
+
+ self.stft = ta.transforms.Spectrogram(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ pad_mode=pad_mode,
+ pad=0,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ normalized=normalized,
+ center=center,
+ onesided=onesided,
+ )
+
+ self.istft = ta.transforms.InverseSpectrogram(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ pad_mode=pad_mode,
+ pad=0,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ normalized=normalized,
+ center=center,
+ onesided=onesided,
+ )
+
+ def instantiate_bandsplit(
+ self,
+ in_channels: int,
+ band_type: str = "musical",
+ n_bands: int = 64,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ emb_dim: int = 128,
+ n_fft: int = 2048,
+ fs: int = 44100,
+ ):
+ assert band_type == "musical"
+
+ self.band_specs = MusicalBandsplitSpecification(
+ nfft=n_fft, fs=fs, n_bands=n_bands
+ )
+
+ self.band_split = BandSplitModule(
+ in_channels=in_channels,
+ band_specs=self.band_specs.get_band_specs(),
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
+ )
+
+ def instantiate_tf_modelling(
+ self,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ ):
+ try:
+ self.tf_model = torch.compile(
+ SeqBandModellingModule(
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ ),
+ disable=True,
+ )
+ except Exception as e:
+ self.tf_model = SeqBandModellingModule(
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ )
+
+ def mask(self, x, m):
+ return x * m
+
+ def forward(self, batch, mode="train"):
+        # The model operates on mono audio; a stereo input is folded into the batch so each channel is processed independently.
+ init_shape = batch.shape
+ if not isinstance(batch, dict):
+ mono = batch.view(-1, 1, batch.shape[-1])
+ batch = {"mixture": {"audio": mono}}
+
+ with torch.no_grad():
+ mixture = batch["mixture"]["audio"]
+
+ x = self.stft(mixture)
+ batch["mixture"]["spectrogram"] = x
+
+ if "sources" in batch.keys():
+ for stem in batch["sources"].keys():
+ s = batch["sources"][stem]["audio"]
+ s = self.stft(s)
+ batch["sources"][stem]["spectrogram"] = s
+
+ batch = self.separate(batch)
+
+        # Restore the original channel layout for every stem and stack the stems into a
+        # single tensor of shape (batch, n_stems, n_channels, n_samples).
+        b = []
+        for s in self.stems:
+            r = batch["estimates"][s]["audio"].view(
+                -1, init_shape[1], init_shape[2]
+            )
+            b.append(r)
+        batch = torch.stack(b, dim=1)
+        return batch
+
+ def encode(self, batch):
+ x = batch["mixture"]["spectrogram"]
+ length = batch["mixture"]["audio"].shape[-1]
+
+        z = self.band_split(x)  # (batch, n_bands, n_time, emb_dim)
+        q = self.tf_model(z)  # (batch, n_bands, n_time, emb_dim)
+
+ return x, q, length
+
+ def separate(self, batch):
+ raise NotImplementedError
+
+
+class Bandit(BaseBandit):
+ def __init__(
+ self,
+ in_channels: int,
+ stems: List[str],
+ band_type: str = "musical",
+ n_bands: int = 64,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+        hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ use_freq_weights: bool = True,
+ n_fft: int = 2048,
+        win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+        wkwargs: Optional[Dict] = None,
+        power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ fs: int = 44100,
+ stft_precisions="32",
+ bandsplit_precisions="bf16",
+ tf_model_precisions="bf16",
+ mask_estim_precisions="bf16",
+ ):
+ super().__init__(
+ in_channels=in_channels,
+ band_type=band_type,
+ n_bands=n_bands,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ fs=fs,
+ )
+
+ self.stems = stems
+
+ self.instantiate_mask_estim(
+ in_channels=in_channels,
+ stems=stems,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ n_freq=n_fft // 2 + 1,
+ use_freq_weights=use_freq_weights,
+ )
+
+ def instantiate_mask_estim(
+ self,
+ in_channels: int,
+ stems: List[str],
+ emb_dim: int,
+ mlp_dim: int,
+ hidden_activation: str,
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_freq: Optional[int] = None,
+ use_freq_weights: bool = False,
+ ):
+ if hidden_activation_kwargs is None:
+ hidden_activation_kwargs = {}
+
+ assert n_freq is not None
+
+ self.mask_estim = nn.ModuleDict(
+ {
+ stem: OverlappingMaskEstimationModule(
+ band_specs=self.band_specs.get_band_specs(),
+ freq_weights=self.band_specs.get_freq_weights(),
+ n_freq=n_freq,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ in_channels=in_channels,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ use_freq_weights=use_freq_weights,
+ )
+ for stem in stems
+ }
+ )
+
+ def separate(self, batch):
+ batch["estimates"] = {}
+
+ x, q, length = self.encode(batch)
+
+ for stem, mem in self.mask_estim.items():
+ m = mem(q)
+
+ s = self.mask(x, m.to(x.dtype))
+ s = torch.reshape(s, x.shape)
+ batch["estimates"][stem] = {
+ "audio": self.istft(s, length),
+ "spectrogram": s,
+ }
+
+ return batch
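+
+
+if __name__ == "__main__":
+    # Minimal smoke-test sketch, not part of the separation pipeline: it only illustrates
+    # how this wrapper expects to be called. The stem names and the reduced n_bands /
+    # n_sqm_modules below are arbitrary illustrative values, not a released configuration.
+    model = Bandit(
+        in_channels=1,
+        stems=["speech", "music", "effects"],
+        n_bands=8,
+        n_sqm_modules=2,
+        fs=44100,
+    )
+    model.eval()
+    x = torch.randn(1, 2, 44100)  # (batch, channels, samples); stereo is folded to mono internally
+    with torch.no_grad():
+        y = model(x)
+    print(y.shape)  # (batch, n_stems, channels, samples)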
diff --git a/mvsepless/models/bandit_v2/bandsplit.py b/mvsepless/models/bandit_v2/bandsplit.py
new file mode 100644
index 0000000000000000000000000000000000000000..a14ea52bfa318264d536c9f934d0e28db63e15dc
--- /dev/null
+++ b/mvsepless/models/bandit_v2/bandsplit.py
@@ -0,0 +1,130 @@
+from typing import List, Tuple
+
+import torch
+from torch import nn
+from torch.utils.checkpoint import checkpoint_sequential
+
+from .utils import (
+ band_widths_from_specs,
+ check_no_gap,
+ check_no_overlap,
+ check_nonzero_bandwidth,
+)
+
+
+class NormFC(nn.Module):
+ def __init__(
+ self,
+ emb_dim: int,
+ bandwidth: int,
+ in_channels: int,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ ) -> None:
+ super().__init__()
+
+ if not treat_channel_as_feature:
+ raise NotImplementedError
+
+ self.treat_channel_as_feature = treat_channel_as_feature
+
+ if normalize_channel_independently:
+ raise NotImplementedError
+
+ reim = 2
+
+ norm = nn.LayerNorm(in_channels * bandwidth * reim)
+
+ fc_in = bandwidth * reim
+
+ if treat_channel_as_feature:
+ fc_in *= in_channels
+ else:
+ assert emb_dim % in_channels == 0
+ emb_dim = emb_dim // in_channels
+
+ fc = nn.Linear(fc_in, emb_dim)
+
+ self.combined = nn.Sequential(norm, fc)
+
+ def forward(self, xb):
+ return checkpoint_sequential(self.combined, 1, xb, use_reentrant=False)
+
+
+class BandSplitModule(nn.Module):
+ def __init__(
+ self,
+ band_specs: List[Tuple[float, float]],
+ emb_dim: int,
+ in_channels: int,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ ) -> None:
+ super().__init__()
+
+ check_nonzero_bandwidth(band_specs)
+
+ if require_no_gap:
+ check_no_gap(band_specs)
+
+ if require_no_overlap:
+ check_no_overlap(band_specs)
+
+ self.band_specs = band_specs
+ # list of [fstart, fend) in index.
+ # Note that fend is exclusive.
+ self.band_widths = band_widths_from_specs(band_specs)
+ self.n_bands = len(band_specs)
+ self.emb_dim = emb_dim
+
+ try:
+ self.norm_fc_modules = nn.ModuleList(
+ [ # type: ignore
+ torch.compile(
+ NormFC(
+ emb_dim=emb_dim,
+ bandwidth=bw,
+ in_channels=in_channels,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ ),
+ disable=True,
+ )
+ for bw in self.band_widths
+ ]
+ )
+ except Exception as e:
+ self.norm_fc_modules = nn.ModuleList(
+ [ # type: ignore
+ NormFC(
+ emb_dim=emb_dim,
+ bandwidth=bw,
+ in_channels=in_channels,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ )
+ for bw in self.band_widths
+ ]
+ )
+
+ def forward(self, x: torch.Tensor):
+ # x = complex spectrogram (batch, in_chan, n_freq, n_time)
+
+        batch, in_chan, n_freq, n_time = x.shape
+
+ z = torch.zeros(
+ size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
+ )
+
+ x = torch.permute(x, (0, 3, 1, 2)).contiguous()
+
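+        # Each sub-band is flattened to (batch, n_time, in_chan * bandwidth * 2), with the
+        # real and imaginary parts interleaved, and projected to emb_dim by its own
+        # LayerNorm + Linear pair.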
+ for i, nfm in enumerate(self.norm_fc_modules):
+ fstart, fend = self.band_specs[i]
+ xb = x[:, :, :, fstart:fend]
+ xb = torch.view_as_real(xb)
+ xb = torch.reshape(xb, (batch, n_time, -1))
+ z[:, i, :, :] = nfm(xb)
+
+ return z
diff --git a/mvsepless/models/bandit_v2/film.py b/mvsepless/models/bandit_v2/film.py
new file mode 100644
index 0000000000000000000000000000000000000000..253594ad0154cee4ef7467036ed71f4d5f836db8
--- /dev/null
+++ b/mvsepless/models/bandit_v2/film.py
@@ -0,0 +1,23 @@
+from torch import nn
+import torch
+
+
+class FiLM(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, gamma, beta):
+ return gamma * x + beta
+
+
+class BTFBroadcastedFiLM(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.film = FiLM()
+
+ def forward(self, x, gamma, beta):
+
+ gamma = gamma[None, None, None, :]
+ beta = beta[None, None, None, :]
+
+ return self.film(x, gamma, beta)
diff --git a/mvsepless/models/bandit_v2/maskestim.py b/mvsepless/models/bandit_v2/maskestim.py
new file mode 100644
index 0000000000000000000000000000000000000000..65215d86a5e94dafdb71744aafadf7aaab93330d
--- /dev/null
+++ b/mvsepless/models/bandit_v2/maskestim.py
@@ -0,0 +1,281 @@
+from typing import Dict, List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch.nn.modules import activation
+from torch.utils.checkpoint import checkpoint_sequential
+
+from .utils import (
+ band_widths_from_specs,
+ check_no_gap,
+ check_no_overlap,
+ check_nonzero_bandwidth,
+)
+
+
+class BaseNormMLP(nn.Module):
+ def __init__(
+ self,
+ emb_dim: int,
+ mlp_dim: int,
+ bandwidth: int,
+ in_channels: Optional[int],
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs=None,
+ complex_mask: bool = True,
+ ):
+ super().__init__()
+ if hidden_activation_kwargs is None:
+ hidden_activation_kwargs = {}
+ self.hidden_activation_kwargs = hidden_activation_kwargs
+ self.norm = nn.LayerNorm(emb_dim)
+ self.hidden = nn.Sequential(
+ nn.Linear(in_features=emb_dim, out_features=mlp_dim),
+ activation.__dict__[hidden_activation](**self.hidden_activation_kwargs),
+ )
+
+ self.bandwidth = bandwidth
+ self.in_channels = in_channels
+
+ self.complex_mask = complex_mask
+ self.reim = 2 if complex_mask else 1
+ self.glu_mult = 2
+
+
+class NormMLP(BaseNormMLP):
+ def __init__(
+ self,
+ emb_dim: int,
+ mlp_dim: int,
+ bandwidth: int,
+ in_channels: Optional[int],
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs=None,
+ complex_mask: bool = True,
+ ) -> None:
+ super().__init__(
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ bandwidth=bandwidth,
+ in_channels=in_channels,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
+
+ self.output = nn.Sequential(
+ nn.Linear(
+ in_features=mlp_dim,
+ out_features=bandwidth * in_channels * self.reim * 2,
+ ),
+ nn.GLU(dim=-1),
+ )
+
+ try:
+ self.combined = torch.compile(
+ nn.Sequential(self.norm, self.hidden, self.output), disable=True
+ )
+ except Exception as e:
+ self.combined = nn.Sequential(self.norm, self.hidden, self.output)
+
+ def reshape_output(self, mb):
+ # print(mb.shape)
+ batch, n_time, _ = mb.shape
+ if self.complex_mask:
+ mb = mb.reshape(
+ batch, n_time, self.in_channels, self.bandwidth, self.reim
+ ).contiguous()
+ # print(mb.shape)
+ mb = torch.view_as_complex(mb) # (batch, n_time, in_channels, bandwidth)
+ else:
+ mb = mb.reshape(batch, n_time, self.in_channels, self.bandwidth)
+
+ mb = torch.permute(mb, (0, 2, 3, 1)) # (batch, in_channels, bandwidth, n_time)
+
+ return mb
+
+ def forward(self, qb):
+ # qb = (batch, n_time, emb_dim)
+ # qb = self.norm(qb) # (batch, n_time, emb_dim)
+ # qb = self.hidden(qb) # (batch, n_time, mlp_dim)
+ # mb = self.output(qb) # (batch, n_time, bandwidth * in_channels * reim)
+
+ mb = checkpoint_sequential(self.combined, 2, qb, use_reentrant=False)
+ mb = self.reshape_output(mb) # (batch, in_channels, bandwidth, n_time)
+
+ return mb
+
+
+class MaskEstimationModuleSuperBase(nn.Module):
+ pass
+
+
+class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
+ def __init__(
+ self,
+ band_specs: List[Tuple[float, float]],
+ emb_dim: int,
+ mlp_dim: int,
+ in_channels: Optional[int],
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Dict = None,
+ complex_mask: bool = True,
+ norm_mlp_cls: Type[nn.Module] = NormMLP,
+ norm_mlp_kwargs: Dict = None,
+ ) -> None:
+ super().__init__()
+
+ self.band_widths = band_widths_from_specs(band_specs)
+ self.n_bands = len(band_specs)
+
+ if hidden_activation_kwargs is None:
+ hidden_activation_kwargs = {}
+
+ if norm_mlp_kwargs is None:
+ norm_mlp_kwargs = {}
+
+ self.norm_mlp = nn.ModuleList(
+ [
+ norm_mlp_cls(
+ bandwidth=self.band_widths[b],
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ in_channels=in_channels,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ **norm_mlp_kwargs,
+ )
+ for b in range(self.n_bands)
+ ]
+ )
+
+ def compute_masks(self, q):
+ batch, n_bands, n_time, emb_dim = q.shape
+
+ masks = []
+
+ for b, nmlp in enumerate(self.norm_mlp):
+ # print(f"maskestim/{b:02d}")
+ qb = q[:, b, :, :]
+ mb = nmlp(qb)
+ masks.append(mb)
+
+ return masks
+
+ def compute_mask(self, q, b):
+ batch, n_bands, n_time, emb_dim = q.shape
+ qb = q[:, b, :, :]
+ mb = self.norm_mlp[b](qb)
+ return mb
+
+
+class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
+ def __init__(
+ self,
+ in_channels: int,
+ band_specs: List[Tuple[float, float]],
+ freq_weights: List[torch.Tensor],
+ n_freq: int,
+ emb_dim: int,
+ mlp_dim: int,
+ cond_dim: int = 0,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Dict = None,
+ complex_mask: bool = True,
+ norm_mlp_cls: Type[nn.Module] = NormMLP,
+ norm_mlp_kwargs: Dict = None,
+ use_freq_weights: bool = False,
+ ) -> None:
+ check_nonzero_bandwidth(band_specs)
+ check_no_gap(band_specs)
+
+ if cond_dim > 0:
+ raise NotImplementedError
+
+ super().__init__(
+ band_specs=band_specs,
+ emb_dim=emb_dim + cond_dim,
+ mlp_dim=mlp_dim,
+ in_channels=in_channels,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ norm_mlp_cls=norm_mlp_cls,
+ norm_mlp_kwargs=norm_mlp_kwargs,
+ )
+
+ self.n_freq = n_freq
+ self.band_specs = band_specs
+ self.in_channels = in_channels
+
+ if freq_weights is not None and use_freq_weights:
+ for i, fw in enumerate(freq_weights):
+ self.register_buffer(f"freq_weights/{i}", fw)
+
+ self.use_freq_weights = use_freq_weights
+ else:
+ self.use_freq_weights = False
+
+ def forward(self, q):
+ # q = (batch, n_bands, n_time, emb_dim)
+
+ batch, n_bands, n_time, emb_dim = q.shape
+
+ masks = torch.zeros(
+ (batch, self.in_channels, self.n_freq, n_time),
+ device=q.device,
+ dtype=torch.complex64,
+ )
+
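+        # Accumulate each band's complex mask into the full-resolution mask; overlapping
+        # bands add up, optionally weighted by the per-bin filterbank weights.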
+ for im in range(n_bands):
+ fstart, fend = self.band_specs[im]
+
+ mask = self.compute_mask(q, im)
+
+ if self.use_freq_weights:
+ fw = self.get_buffer(f"freq_weights/{im}")[:, None]
+ mask = mask * fw
+ masks[:, :, fstart:fend, :] += mask
+
+ return masks
+
+
+class MaskEstimationModule(OverlappingMaskEstimationModule):
+ def __init__(
+ self,
+ band_specs: List[Tuple[float, float]],
+ emb_dim: int,
+ mlp_dim: int,
+ in_channels: Optional[int],
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Dict = None,
+ complex_mask: bool = True,
+ **kwargs,
+ ) -> None:
+ check_nonzero_bandwidth(band_specs)
+ check_no_gap(band_specs)
+ check_no_overlap(band_specs)
+ super().__init__(
+ in_channels=in_channels,
+ band_specs=band_specs,
+ freq_weights=None,
+ n_freq=None,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
+
+ def forward(self, q, cond=None):
+ # q = (batch, n_bands, n_time, emb_dim)
+
+ masks = self.compute_masks(
+ q
+ ) # [n_bands * (batch, in_channels, bandwidth, n_time)]
+
+ # TODO: currently this requires band specs to have no gap and no overlap
+ masks = torch.concat(masks, dim=2) # (batch, in_channels, n_freq, n_time)
+
+ return masks
diff --git a/mvsepless/models/bandit_v2/tfmodel.py b/mvsepless/models/bandit_v2/tfmodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..21aef03d1f0e814c20db05fe7d14f8019f07713b
--- /dev/null
+++ b/mvsepless/models/bandit_v2/tfmodel.py
@@ -0,0 +1,145 @@
+import warnings
+
+import torch
+import torch.backends.cuda
+from torch import nn
+from torch.nn.modules import rnn
+from torch.utils.checkpoint import checkpoint_sequential
+
+
+class TimeFrequencyModellingModule(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+
+
+class ResidualRNN(nn.Module):
+ def __init__(
+ self,
+ emb_dim: int,
+ rnn_dim: int,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ use_batch_trick: bool = True,
+ use_layer_norm: bool = True,
+ ) -> None:
+ # n_group is the size of the 2nd dim
+ super().__init__()
+
+ assert use_layer_norm
+ assert use_batch_trick
+
+ self.use_layer_norm = use_layer_norm
+ self.norm = nn.LayerNorm(emb_dim)
+ self.rnn = rnn.__dict__[rnn_type](
+ input_size=emb_dim,
+ hidden_size=rnn_dim,
+ num_layers=1,
+ batch_first=True,
+ bidirectional=bidirectional,
+ )
+
+ self.fc = nn.Linear(
+ in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim
+ )
+
+ self.use_batch_trick = use_batch_trick
+ if not self.use_batch_trick:
+ warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")
+
+ def forward(self, z):
+ # z = (batch, n_uncrossed, n_across, emb_dim)
+
+ z0 = torch.clone(z)
+ z = self.norm(z)
+
+ batch, n_uncrossed, n_across, emb_dim = z.shape
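+        # "Batch trick": fold the un-crossed axis into the batch so a single RNN call
+        # processes every row along the crossed axis at once.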
+ z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
+ z = self.rnn(z)[0]
+ z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
+
+ z = self.fc(z) # (batch, n_uncrossed, n_across, emb_dim)
+
+ z = z + z0
+
+ return z
+
+
+class Transpose(nn.Module):
+ def __init__(self, dim0: int, dim1: int) -> None:
+ super().__init__()
+ self.dim0 = dim0
+ self.dim1 = dim1
+
+ def forward(self, z):
+ return z.transpose(self.dim0, self.dim1)
+
+
+class SeqBandModellingModule(TimeFrequencyModellingModule):
+ def __init__(
+ self,
+ n_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ parallel_mode=False,
+ ) -> None:
+ super().__init__()
+
+ self.n_modules = n_modules
+
+ if parallel_mode:
+ self.seqband = nn.ModuleList([])
+ for _ in range(n_modules):
+ self.seqband.append(
+ nn.ModuleList(
+ [
+ ResidualRNN(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ ),
+ ResidualRNN(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ ),
+ ]
+ )
+ )
+ else:
+ seqband = []
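+            # Sequential mode: alternate RNNs along time and along bands by transposing
+            # dims 1 and 2 between consecutive modules.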
+ for _ in range(2 * n_modules):
+ seqband += [
+ ResidualRNN(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ ),
+ Transpose(1, 2),
+ ]
+
+ self.seqband = nn.Sequential(*seqband)
+
+ self.parallel_mode = parallel_mode
+
+ def forward(self, z):
+ # z = (batch, n_bands, n_time, emb_dim)
+
+ if self.parallel_mode:
+ for sbm_pair in self.seqband:
+ # z: (batch, n_bands, n_time, emb_dim)
+ sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
+ zt = sbm_t(z) # (batch, n_bands, n_time, emb_dim)
+ zf = sbm_f(z.transpose(1, 2)) # (batch, n_time, n_bands, emb_dim)
+ z = zt + zf.transpose(1, 2)
+ else:
+ z = checkpoint_sequential(
+ self.seqband, self.n_modules, z, use_reentrant=False
+ )
+
+ q = z
+ return q # (batch, n_bands, n_time, emb_dim)
diff --git a/mvsepless/models/bandit_v2/utils.py b/mvsepless/models/bandit_v2/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad4eab5d8c5b5396ed717f5b9c365a6900eddd2f
--- /dev/null
+++ b/mvsepless/models/bandit_v2/utils.py
@@ -0,0 +1,523 @@
+import os
+from abc import abstractmethod
+from typing import Callable
+
+import numpy as np
+import torch
+from librosa import hz_to_midi, midi_to_hz
+from torchaudio import functional as taF
+
+# from spafe.fbanks import bark_fbanks
+# from spafe.utils.converters import erb2hz, hz2bark, hz2erb
+
+
+def band_widths_from_specs(band_specs):
+ return [e - i for i, e in band_specs]
+
+
+def check_nonzero_bandwidth(band_specs):
+ # pprint(band_specs)
+ for fstart, fend in band_specs:
+ if fend - fstart <= 0:
+ raise ValueError("Bands cannot be zero-width")
+
+
+def check_no_overlap(band_specs):
+    fend_prev = -1
+    for fstart_curr, fend_curr in band_specs:
+        # Bands are half-open [fstart, fend), so a band may start exactly where the previous one ends.
+        if fstart_curr < fend_prev:
+            raise ValueError("Bands cannot overlap")
+        fend_prev = fend_curr
+
+
+def check_no_gap(band_specs):
+ fstart, _ = band_specs[0]
+ assert fstart == 0
+
+ fend_prev = -1
+ for fstart_curr, fend_curr in band_specs:
+ if fstart_curr - fend_prev > 1:
+ raise ValueError("Bands cannot leave gap")
+ fend_prev = fend_curr
+
+
+class BandsplitSpecification:
+ def __init__(self, nfft: int, fs: int) -> None:
+ self.fs = fs
+ self.nfft = nfft
+ self.nyquist = fs / 2
+ self.max_index = nfft // 2 + 1
+
+ self.split500 = self.hertz_to_index(500)
+ self.split1k = self.hertz_to_index(1000)
+ self.split2k = self.hertz_to_index(2000)
+ self.split4k = self.hertz_to_index(4000)
+ self.split8k = self.hertz_to_index(8000)
+ self.split16k = self.hertz_to_index(16000)
+ self.split20k = self.hertz_to_index(20000)
+
+ self.above20k = [(self.split20k, self.max_index)]
+ self.above16k = [(self.split16k, self.split20k)] + self.above20k
+
+ def index_to_hertz(self, index: int):
+ return index * self.fs / self.nfft
+
+ def hertz_to_index(self, hz: float, round: bool = True):
+ index = hz * self.nfft / self.fs
+
+ if round:
+ index = int(np.round(index))
+
+ return index
+
+ def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
+ band_specs = []
+ lower = start_index
+
+ while lower < end_index:
+ upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
+ upper = min(upper, end_index)
+
+ band_specs.append((lower, upper))
+ lower = upper
+
+ return band_specs
+
+ @abstractmethod
+ def get_band_specs(self):
+ raise NotImplementedError
+
+
+class VocalBandsplitSpecification(BandsplitSpecification):
+ def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
+ super().__init__(nfft=nfft, fs=fs)
+
+ self.version = version
+
+ def get_band_specs(self):
+ return getattr(self, f"version{self.version}")()
+
+    # Plain method (not a property) so that get_band_specs() can call it like the other versions.
+    def version1(self):
+ return self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.max_index, bandwidth_hz=1000
+ )
+
+ def version2(self):
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split16k, bandwidth_hz=1000
+ )
+ below20k = self.get_band_specs_with_bandwidth(
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+ )
+
+ return below16k + below20k + self.above20k
+
+ def version3(self):
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split8k, bandwidth_hz=1000
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+ )
+
+ return below8k + below16k + self.above16k
+
+ def version4(self):
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
+ )
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+ )
+
+ return below1k + below8k + below16k + self.above16k
+
+ def version5(self):
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
+ )
+ below20k = self.get_band_specs_with_bandwidth(
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+ )
+ return below1k + below16k + below20k + self.above20k
+
+ def version6(self):
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
+ )
+ below4k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
+ )
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+ )
+ return below1k + below4k + below8k + below16k + self.above16k
+
+ def version7(self):
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
+ )
+ below4k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
+ )
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
+ )
+ below20k = self.get_band_specs_with_bandwidth(
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+ )
+ return below1k + below4k + below8k + below16k + below20k + self.above20k
+
+
+class OtherBandsplitSpecification(VocalBandsplitSpecification):
+ def __init__(self, nfft: int, fs: int) -> None:
+ super().__init__(nfft=nfft, fs=fs, version="7")
+
+
+class BassBandsplitSpecification(BandsplitSpecification):
+ def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
+ super().__init__(nfft=nfft, fs=fs)
+
+ def get_band_specs(self):
+ below500 = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split500, bandwidth_hz=50
+ )
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=self.split500, end_index=self.split1k, bandwidth_hz=100
+ )
+ below4k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
+ )
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+ )
+ above16k = [(self.split16k, self.max_index)]
+
+ return below500 + below1k + below4k + below8k + below16k + above16k
+
+
+class DrumBandsplitSpecification(BandsplitSpecification):
+ def __init__(self, nfft: int, fs: int) -> None:
+ super().__init__(nfft=nfft, fs=fs)
+
+ def get_band_specs(self):
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split1k, bandwidth_hz=50
+ )
+ below2k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100
+ )
+ below4k = self.get_band_specs_with_bandwidth(
+ start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250
+ )
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
+ )
+ above16k = [(self.split16k, self.max_index)]
+
+ return below1k + below2k + below4k + below8k + below16k + above16k
+
+
+class PerceptualBandsplitSpecification(BandsplitSpecification):
+ def __init__(
+ self,
+ nfft: int,
+ fs: int,
+ fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
+ n_bands: int,
+ f_min: float = 0.0,
+ f_max: float = None,
+ ) -> None:
+ super().__init__(nfft=nfft, fs=fs)
+ self.n_bands = n_bands
+ if f_max is None:
+ f_max = fs / 2
+
+ self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)
+
+ weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True) # (1, n_freqs)
+ normalized_mel_fb = self.filterbank / weight_per_bin # (n_mels, n_freqs)
+
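+        # For every filter, keep the contiguous run of bins it touches as a band and store
+        # the normalised per-bin weights so overlapping bands can be blended back together.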
+ freq_weights = []
+ band_specs = []
+ for i in range(self.n_bands):
+ active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
+ if isinstance(active_bins, int):
+ active_bins = (active_bins, active_bins)
+ if len(active_bins) == 0:
+ continue
+ start_index = active_bins[0]
+ end_index = active_bins[-1] + 1
+ band_specs.append((start_index, end_index))
+ freq_weights.append(normalized_mel_fb[i, start_index:end_index])
+
+ self.freq_weights = freq_weights
+ self.band_specs = band_specs
+
+ def get_band_specs(self):
+ return self.band_specs
+
+ def get_freq_weights(self):
+ return self.freq_weights
+
+ def save_to_file(self, dir_path: str) -> None:
+ os.makedirs(dir_path, exist_ok=True)
+
+ import pickle
+
+ with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
+ pickle.dump(
+ {
+ "band_specs": self.band_specs,
+ "freq_weights": self.freq_weights,
+ "filterbank": self.filterbank,
+ },
+ f,
+ )
+
+
+def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
+ fb = taF.melscale_fbanks(
+ n_mels=n_bands,
+ sample_rate=fs,
+ f_min=f_min,
+ f_max=f_max,
+ n_freqs=n_freqs,
+ ).T
+
+ fb[0, 0] = 1.0
+
+ return fb
+
+
+class MelBandsplitSpecification(PerceptualBandsplitSpecification):
+ def __init__(
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+ ) -> None:
+ super().__init__(
+ fbank_fn=mel_filterbank,
+ nfft=nfft,
+ fs=fs,
+ n_bands=n_bands,
+ f_min=f_min,
+ f_max=f_max,
+ )
+
+
+def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
+ nfft = 2 * (n_freqs - 1)
+ df = fs / nfft
+ # init freqs
+ f_max = f_max or fs / 2
+ f_min = f_min or 0
+ f_min = fs / nfft
+
+ n_octaves = np.log2(f_max / f_min)
+ n_octaves_per_band = n_octaves / n_bands
+ bandwidth_mult = np.power(2.0, n_octaves_per_band)
+
+ low_midi = max(0, hz_to_midi(f_min))
+ high_midi = hz_to_midi(f_max)
+ midi_points = np.linspace(low_midi, high_midi, n_bands)
+ hz_pts = midi_to_hz(midi_points)
+
+ low_pts = hz_pts / bandwidth_mult
+ high_pts = hz_pts * bandwidth_mult
+
+ low_bins = np.floor(low_pts / df).astype(int)
+ high_bins = np.ceil(high_pts / df).astype(int)
+
+ fb = np.zeros((n_bands, n_freqs))
+
+ for i in range(n_bands):
+ fb[i, low_bins[i] : high_bins[i] + 1] = 1.0
+
+ fb[0, : low_bins[0]] = 1.0
+ fb[-1, high_bins[-1] + 1 :] = 1.0
+
+ return torch.as_tensor(fb)
+
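+# Each "musical" band above is a rectangular window: band centers are spaced uniformly
+# in MIDI (log-frequency), each band extends one bandwidth factor (2 ** (n_octaves / n_bands))
+# below and above its center, and the first / last bands are widened to cover DC and Nyquist.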
+
+class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
+ def __init__(
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+ ) -> None:
+ super().__init__(
+ fbank_fn=musical_filterbank,
+ nfft=nfft,
+ fs=fs,
+ n_bands=n_bands,
+ f_min=f_min,
+ f_max=f_max,
+ )
+
+
+# def bark_filterbank(
+# n_bands, fs, f_min, f_max, n_freqs
+# ):
+# nfft = 2 * (n_freqs -1)
+# fb, _ = bark_fbanks.bark_filter_banks(
+# nfilts=n_bands,
+# nfft=nfft,
+# fs=fs,
+# low_freq=f_min,
+# high_freq=f_max,
+# scale="constant"
+# )
+
+# return torch.as_tensor(fb)
+
+# class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
+# def __init__(
+# self,
+# nfft: int,
+# fs: int,
+# n_bands: int,
+# f_min: float = 0.0,
+# f_max: float = None
+# ) -> None:
+# super().__init__(fbank_fn=bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+
+
+# def triangular_bark_filterbank(
+# n_bands, fs, f_min, f_max, n_freqs
+# ):
+
+# all_freqs = torch.linspace(0, fs // 2, n_freqs)
+
+# # calculate mel freq bins
+# m_min = hz2bark(f_min)
+# m_max = hz2bark(f_max)
+
+# m_pts = torch.linspace(m_min, m_max, n_bands + 2)
+# f_pts = 600 * torch.sinh(m_pts / 6)
+
+# # create filterbank
+# fb = _create_triangular_filterbank(all_freqs, f_pts)
+
+# fb = fb.T
+
+# first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
+# first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
+
+# fb[first_active_band, :first_active_bin] = 1.0
+
+# return fb
+
+# class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
+# def __init__(
+# self,
+# nfft: int,
+# fs: int,
+# n_bands: int,
+# f_min: float = 0.0,
+# f_max: float = None
+# ) -> None:
+# super().__init__(fbank_fn=triangular_bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+
+
+# def minibark_filterbank(
+# n_bands, fs, f_min, f_max, n_freqs
+# ):
+# fb = bark_filterbank(
+# n_bands,
+# fs,
+# f_min,
+# f_max,
+# n_freqs
+# )
+
+# fb[fb < np.sqrt(0.5)] = 0.0
+
+# return fb
+
+# class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
+# def __init__(
+# self,
+# nfft: int,
+# fs: int,
+# n_bands: int,
+# f_min: float = 0.0,
+# f_max: float = None
+# ) -> None:
+# super().__init__(fbank_fn=minibark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+
+
+# def erb_filterbank(
+# n_bands: int,
+# fs: int,
+# f_min: float,
+# f_max: float,
+# n_freqs: int,
+# ) -> Tensor:
+# # freq bins
+# A = (1000 * np.log(10)) / (24.7 * 4.37)
+# all_freqs = torch.linspace(0, fs // 2, n_freqs)
+
+# # calculate mel freq bins
+# m_min = hz2erb(f_min)
+# m_max = hz2erb(f_max)
+
+# m_pts = torch.linspace(m_min, m_max, n_bands + 2)
+# f_pts = (torch.pow(10, (m_pts / A)) - 1)/ 0.00437
+
+# # create filterbank
+# fb = _create_triangular_filterbank(all_freqs, f_pts)
+
+# fb = fb.T
+
+
+# first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
+# first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
+
+# fb[first_active_band, :first_active_bin] = 1.0
+
+# return fb
+
+
+# class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
+# def __init__(
+# self,
+# nfft: int,
+# fs: int,
+# n_bands: int,
+# f_min: float = 0.0,
+# f_max: float = None
+# ) -> None:
+# super().__init__(fbank_fn=erb_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+
+
+
+if __name__ == "__main__":
+ import pandas as pd
+
+ band_defs = []
+
+ for bands in [VocalBandsplitSpecification]:
+ band_name = bands.__name__.replace("BandsplitSpecification", "")
+
+ mbs = bands(nfft=2048, fs=44100).get_band_specs()
+
+ for i, (f_min, f_max) in enumerate(mbs):
+ band_defs.append(
+ {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
+ )
+
+ df = pd.DataFrame(band_defs)
+ df.to_csv("vox7bands.csv", index=False)
diff --git a/mvsepless/models/bs_roformer/__init__.py b/mvsepless/models/bs_roformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..378a20b247a018ac062afd8db0bc2997841c22e5
--- /dev/null
+++ b/mvsepless/models/bs_roformer/__init__.py
@@ -0,0 +1,4 @@
+from .bs_roformer import BSRoformer
+from .bs_roformer_sw import BSRoformer_SW
+from .bs_roformer_fno import BSRoformer_FNO
+from .mel_band_roformer import MelBandRoformer
diff --git a/mvsepless/models/bs_roformer/attend.py b/mvsepless/models/bs_roformer/attend.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ebb7c937268ff4568ef6dbdcbc90abbfc8f0887
--- /dev/null
+++ b/mvsepless/models/bs_roformer/attend.py
@@ -0,0 +1,144 @@
+from functools import wraps
+from packaging import version
+from collections import namedtuple
+
+import os
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+from einops import rearrange, reduce
+
+# constants
+
+FlashAttentionConfig = namedtuple(
+ "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]
+)
+
+# helpers
+
+
+def exists(val):
+ return val is not None
+
+
+def default(v, d):
+ return v if exists(v) else d
+
+
+def once(fn):
+ called = False
+
+ @wraps(fn)
+ def inner(x):
+ nonlocal called
+ if called:
+ return
+ called = True
+ return fn(x)
+
+ return inner
+
+
+print_once = once(print)
+
+# main class
+
+
+class Attend(nn.Module):
+ def __init__(self, dropout=0.0, flash=False, scale=None):
+ super().__init__()
+ self.scale = scale
+ self.dropout = dropout
+ self.attn_dropout = nn.Dropout(dropout)
+
+ self.flash = flash
+ assert not (
+ flash and version.parse(torch.__version__) < version.parse("2.0.0")
+ ), "in order to use flash attention, you must be using pytorch 2.0 or above"
+
+ # determine efficient attention configs for cuda and cpu
+
+ self.cpu_config = FlashAttentionConfig(True, True, True)
+ self.cuda_config = None
+
+ if not torch.cuda.is_available() or not flash:
+ return
+
+ device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
+ device_version = version.parse(
+ f"{device_properties.major}.{device_properties.minor}"
+ )
+
+ if device_version >= version.parse("8.0"):
+ if os.name == "nt":
+ print_once(
+ "Windows OS detected, using math or mem efficient attention if input tensor is on cuda"
+ )
+ self.cuda_config = FlashAttentionConfig(False, True, True)
+ else:
+ print_once(
+ "GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda"
+ )
+ self.cuda_config = FlashAttentionConfig(True, False, False)
+ else:
+ print_once(
+ "GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda"
+ )
+ self.cuda_config = FlashAttentionConfig(False, True, True)
+
+ def flash_attn(self, q, k, v):
+ _, heads, q_len, _, k_len, is_cuda, device = (
+ *q.shape,
+ k.shape[-2],
+ q.is_cuda,
+ q.device,
+ )
+
+ if exists(self.scale):
+ default_scale = q.shape[-1] ** -0.5
+ q = q * (self.scale / default_scale)
+
+ # Check if there is a compatible device for flash attention
+
+ config = self.cuda_config if is_cuda else self.cpu_config
+
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
+
+ with torch.backends.cuda.sdp_kernel(**config._asdict()):
+ out = F.scaled_dot_product_attention(
+ q, k, v, dropout_p=self.dropout if self.training else 0.0
+ )
+
+ return out
+
+ def forward(self, q, k, v):
+ """
+ einstein notation
+ b - batch
+ h - heads
+ n, i, j - sequence length (base sequence length, source, target)
+ d - feature dimension
+ """
+
+ q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
+
+ scale = default(self.scale, q.shape[-1] ** -0.5)
+
+ if self.flash:
+ return self.flash_attn(q, k, v)
+
+ # similarity
+
+        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
+
+ # attention
+
+ attn = sim.softmax(dim=-1)
+ attn = self.attn_dropout(attn)
+
+ # aggregate values
+
+        out = einsum("b h i j, b h j d -> b h i d", attn, v)
+
+ return out
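+
+
+# Minimal usage sketch (illustrative only): q, k, v are (batch, heads, seq, dim_head)
+# tensors. With flash=True the module routes through F.scaled_dot_product_attention
+# using the backend config chosen at init; otherwise it falls back to explicit einsum:
+#     attend = Attend(dropout=0.0, flash=False)
+#     out = attend(torch.randn(2, 8, 128, 64),
+#                  torch.randn(2, 8, 128, 64),
+#                  torch.randn(2, 8, 128, 64))   # -> (2, 8, 128, 64)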
diff --git a/mvsepless/models/bs_roformer/attend_sw.py b/mvsepless/models/bs_roformer/attend_sw.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc73e3ea8875ba351901fda08f7b4353248e8ae8
--- /dev/null
+++ b/mvsepless/models/bs_roformer/attend_sw.py
@@ -0,0 +1,100 @@
+import logging
+import os
+
+import torch
+import torch.nn.functional as F
+from packaging import version
+from torch import Tensor, einsum, nn
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
+logger = logging.getLogger(__name__)
+
+
+class Attend(nn.Module):
+ def __init__(self, dropout: float = 0.0, flash: bool = False, scale=None):
+ super().__init__()
+ self.scale = scale
+ self.dropout = dropout
+ self.attn_dropout = nn.Dropout(dropout)
+
+ self.flash = flash
+ assert not (
+ flash and version.parse(torch.__version__) < version.parse("2.0.0")
+ ), "expected pytorch >= 2.0.0 to use flash attention"
+
+ # determine efficient attention configs for cuda and cpu
+ self.cpu_backends = [
+ SDPBackend.FLASH_ATTENTION,
+ SDPBackend.EFFICIENT_ATTENTION,
+ SDPBackend.MATH,
+ ]
+ self.cuda_backends: list | None = None
+
+ if not torch.cuda.is_available() or not flash:
+ return
+
+ device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
+ device_version = version.parse(
+ f"{device_properties.major}.{device_properties.minor}"
+ )
+
+ if device_version >= version.parse("8.0"):
+ if os.name == "nt":
+ cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
+ logger.info(f"windows detected, {cuda_backends=}")
+ else:
+ cuda_backends = [SDPBackend.FLASH_ATTENTION]
+ logger.info(f"gpu compute capability >= 8.0, {cuda_backends=}")
+ else:
+ cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
+ logger.info(f"gpu compute capability < 8.0, {cuda_backends=}")
+
+ self.cuda_backends = cuda_backends
+
+ def flash_attn(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+ _, _heads, _q_len, _, _k_len, is_cuda, _device = (
+ *q.shape,
+ k.shape[-2],
+ q.is_cuda,
+ q.device,
+ ) # type: ignore
+
+ if self.scale is not None:
+ default_scale = q.shape[-1] ** -0.5
+ q = q * (self.scale / default_scale)
+
+ backends = self.cuda_backends if is_cuda else self.cpu_backends
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
+ with sdpa_kernel(backends=backends): # type: ignore
+ out = F.scaled_dot_product_attention(
+ q, k, v, dropout_p=self.dropout if self.training else 0.0
+ )
+
+ return out
+
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+ """
+ einstein notation
+ b - batch
+ h - heads
+ n, i, j - sequence length (base sequence length, source, target)
+ d - feature dimension
+ """
+ _q_len, _k_len, _device = q.shape[-2], k.shape[-2], q.device
+
+ scale = self.scale or q.shape[-1] ** -0.5
+
+ if self.flash:
+ return self.flash_attn(q, k, v)
+
+ # similarity
+ sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
+
+ # attention
+ attn = sim.softmax(dim=-1)
+ attn = self.attn_dropout(attn)
+
+ # aggregate values
+ out = einsum("b h i j, b h j d -> b h i d", attn, v)
+
+ return out
diff --git a/mvsepless/models/bs_roformer/bs_roformer.py b/mvsepless/models/bs_roformer/bs_roformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..adf2faed4c2a8a4db7606cb820207834eab256c5
--- /dev/null
+++ b/mvsepless/models/bs_roformer/bs_roformer.py
@@ -0,0 +1,708 @@
+from functools import partial
+
+import torch
+from torch import nn, einsum, Tensor
+from torch.nn import Module, ModuleList
+import torch.nn.functional as F
+
+from .attend import Attend
+from torch.utils.checkpoint import checkpoint
+
+from beartype.typing import Tuple, Optional, List, Callable
+from beartype import beartype
+
+from rotary_embedding_torch import RotaryEmbedding
+
+from einops import rearrange, pack, unpack
+from einops.layers.torch import Rearrange
+
+# helper functions
+
+
+def exists(val):
+ return val is not None
+
+
+def default(v, d):
+ return v if exists(v) else d
+
+
+def pack_one(t, pattern):
+ return pack([t], pattern)
+
+
+def unpack_one(t, ps, pattern):
+ return unpack(t, ps, pattern)[0]
+
+
+# norm
+
+
+def l2norm(t):
+ return F.normalize(t, dim=-1, p=2)
+
+
+class RMSNorm(Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
+
+
+# attention
+
+
+class FeedForward(Module):
+ def __init__(self, dim, mult=4, dropout=0.0):
+ super().__init__()
+ dim_inner = int(dim * mult)
+ self.net = nn.Sequential(
+ RMSNorm(dim),
+ nn.Linear(dim, dim_inner),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(dim_inner, dim),
+ nn.Dropout(dropout),
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class Attention(Module):
+ def __init__(
+ self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True
+ ):
+ super().__init__()
+ self.heads = heads
+ self.scale = dim_head**-0.5
+ dim_inner = heads * dim_head
+
+ self.rotary_embed = rotary_embed
+
+ self.attend = Attend(flash=flash, dropout=dropout)
+
+ self.norm = RMSNorm(dim)
+ self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
+
+ self.to_gates = nn.Linear(dim, heads)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)
+ )
+
+ def forward(self, x):
+ x = self.norm(x)
+
+ q, k, v = rearrange(
+ self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
+ )
+
+ if exists(self.rotary_embed):
+ q = self.rotary_embed.rotate_queries_or_keys(q)
+ k = self.rotary_embed.rotate_queries_or_keys(k)
+
+ out = self.attend(q, k, v)
+
+ gates = self.to_gates(x)
+ out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
+
+ out = rearrange(out, "b h n d -> b n (h d)")
+ return self.to_out(out)
+
+
+class LinearAttention(Module):
+ """
+ this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
+ """
+
+ @beartype
+ def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0):
+ super().__init__()
+ dim_inner = dim_head * heads
+ self.norm = RMSNorm(dim)
+
+ self.to_qkv = nn.Sequential(
+ nn.Linear(dim, dim_inner * 3, bias=False),
+ Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
+ )
+
+ self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
+
+ self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
+
+ self.to_out = nn.Sequential(
+ Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
+ )
+
+ def forward(self, x):
+ x = self.norm(x)
+
+ q, k, v = self.to_qkv(x)
+
+ q, k = map(l2norm, (q, k))
+ q = q * self.temperature.exp()
+
+ out = self.attend(q, k, v)
+
+ return self.to_out(out)
+
+
+class Transformer(Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ depth,
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.0,
+ ff_dropout=0.0,
+ ff_mult=4,
+ norm_output=True,
+ rotary_embed=None,
+ flash_attn=True,
+ linear_attn=False,
+ ):
+ super().__init__()
+ self.layers = ModuleList([])
+
+ for _ in range(depth):
+ if linear_attn:
+ attn = LinearAttention(
+ dim=dim,
+ dim_head=dim_head,
+ heads=heads,
+ dropout=attn_dropout,
+ flash=flash_attn,
+ )
+ else:
+ attn = Attention(
+ dim=dim,
+ dim_head=dim_head,
+ heads=heads,
+ dropout=attn_dropout,
+ rotary_embed=rotary_embed,
+ flash=flash_attn,
+ )
+
+ self.layers.append(
+ ModuleList(
+ [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
+ )
+ )
+
+ self.norm = RMSNorm(dim) if norm_output else nn.Identity()
+
+ def forward(self, x):
+
+ for attn, ff in self.layers:
+ x = attn(x) + x
+ x = ff(x) + x
+
+ return self.norm(x)
+
+
+# bandsplit module
+
+
+class BandSplit(Module):
+ @beartype
+ def __init__(self, dim, dim_inputs: Tuple[int, ...]):
+ super().__init__()
+ self.dim_inputs = dim_inputs
+ self.to_features = ModuleList([])
+
+ for dim_in in dim_inputs:
+ net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
+
+ self.to_features.append(net)
+
+ def forward(self, x):
+ x = x.split(self.dim_inputs, dim=-1)
+
+ outs = []
+ for split_input, to_feature in zip(x, self.to_features):
+ split_output = to_feature(split_input)
+ outs.append(split_output)
+
+ return torch.stack(outs, dim=-2)
+
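+# Shape note for BandSplit: the input is (b, t, sum(dim_inputs)); each band slice is
+# RMS-normalized and projected to `dim`, and the stacked result is (b, t, n_bands, dim).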
+
+def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
+ dim_hidden = default(dim_hidden, dim_in)
+
+ net = []
+ dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
+
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
+ is_last = ind == (len(dims) - 2)
+
+ net.append(nn.Linear(layer_dim_in, layer_dim_out))
+
+ if is_last:
+ continue
+
+ net.append(activation())
+
+ return nn.Sequential(*net)
+
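+# e.g. MLP(dim_in=64, dim_out=128, depth=2) builds Linear(64, 64) -> Tanh -> Linear(64, 128);
+# the final layer is always left without an activation.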
+
+class MaskEstimator(Module):
+ @beartype
+ def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
+ super().__init__()
+ self.dim_inputs = dim_inputs
+ self.to_freqs = ModuleList([])
+ dim_hidden = dim * mlp_expansion_factor
+
+ for dim_in in dim_inputs:
+ mlp = nn.Sequential(
+ MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1)
+ )
+
+ self.to_freqs.append(mlp)
+
+ def forward(self, x):
+ x = x.unbind(dim=-2)
+
+ outs = []
+
+ for band_features, mlp in zip(x, self.to_freqs):
+ freq_out = mlp(band_features)
+ outs.append(freq_out)
+
+ return torch.cat(outs, dim=-1)
+
+
+# main class
+
+# fmt: off
+DEFAULT_FREQS_PER_BANDS = (
+    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+    4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
+    12, 12, 12, 12, 12, 12, 12, 12,
+    24, 24, 24, 24, 24, 24, 24, 24,
+    48, 48, 48, 48, 48, 48, 48, 48,
+    128, 129,
+)
+# fmt: on
+
+
+class BSRoformer(Module):
+
+ @beartype
+ def __init__(
+ self,
+ dim,
+ *,
+ depth,
+ stereo=False,
+ num_stems=1,
+ time_transformer_depth=2,
+ freq_transformer_depth=2,
+ linear_transformer_depth=0,
+ freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
+ # in the paper, they divide into ~60 bands, test with 1 for starters
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.0,
+ ff_dropout=0.0,
+ flash_attn=True,
+ dim_freqs_in=1025,
+ stft_n_fft=2048,
+ stft_hop_length=512,
+ # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
+ stft_win_length=2048,
+ stft_normalized=False,
+ stft_window_fn: Optional[Callable] = None,
+ mask_estimator_depth=2,
+ multi_stft_resolution_loss_weight=1.0,
+ multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
+ 4096,
+ 2048,
+ 1024,
+ 512,
+ 256,
+ ),
+ multi_stft_hop_size=147,
+ multi_stft_normalized=False,
+ multi_stft_window_fn: Callable = torch.hann_window,
+ mlp_expansion_factor=4,
+ use_torch_checkpoint=False,
+ skip_connection=False,
+ ):
+ super().__init__()
+
+ self.stereo = stereo
+ self.audio_channels = 2 if stereo else 1
+ self.num_stems = num_stems
+ self.use_torch_checkpoint = use_torch_checkpoint
+ self.skip_connection = skip_connection
+
+ self.layers = ModuleList([])
+
+ transformer_kwargs = dict(
+ dim=dim,
+ heads=heads,
+ dim_head=dim_head,
+ attn_dropout=attn_dropout,
+ ff_dropout=ff_dropout,
+ flash_attn=flash_attn,
+ norm_output=False,
+ )
+
+ time_rotary_embed = RotaryEmbedding(dim=dim_head)
+ freq_rotary_embed = RotaryEmbedding(dim=dim_head)
+
+ for _ in range(depth):
+ tran_modules = []
+ if linear_transformer_depth > 0:
+ tran_modules.append(
+ Transformer(
+ depth=linear_transformer_depth,
+ linear_attn=True,
+ **transformer_kwargs,
+ )
+ )
+ tran_modules.append(
+ Transformer(
+ depth=time_transformer_depth,
+ rotary_embed=time_rotary_embed,
+ **transformer_kwargs,
+ )
+ )
+ tran_modules.append(
+ Transformer(
+ depth=freq_transformer_depth,
+ rotary_embed=freq_rotary_embed,
+ **transformer_kwargs,
+ )
+ )
+ self.layers.append(nn.ModuleList(tran_modules))
+
+ self.final_norm = RMSNorm(dim)
+
+ self.stft_kwargs = dict(
+ n_fft=stft_n_fft,
+ hop_length=stft_hop_length,
+ win_length=stft_win_length,
+ normalized=stft_normalized,
+ )
+
+ self.stft_window_fn = partial(
+ default(stft_window_fn, torch.hann_window), stft_win_length
+ )
+
+ freqs = torch.stft(
+ torch.randn(1, 4096),
+ **self.stft_kwargs,
+ window=torch.ones(stft_win_length),
+ return_complex=True,
+ ).shape[1]
+
+ assert len(freqs_per_bands) > 1
+ assert (
+ sum(freqs_per_bands) == freqs
+ ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
+
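+        # each band's input width counts real+imag (x2) for every audio channel, matching
+        # the later "b s f t c -> b (f s) t c" merge and "b f t c -> b t (f c)" flatten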
+ freqs_per_bands_with_complex = tuple(
+ 2 * f * self.audio_channels for f in freqs_per_bands
+ )
+
+ self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
+
+ self.mask_estimators = nn.ModuleList([])
+
+ for _ in range(num_stems):
+ mask_estimator = MaskEstimator(
+ dim=dim,
+ dim_inputs=freqs_per_bands_with_complex,
+ depth=mask_estimator_depth,
+ mlp_expansion_factor=mlp_expansion_factor,
+ )
+
+ self.mask_estimators.append(mask_estimator)
+
+ # for the multi-resolution stft loss
+
+ self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
+ self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
+ self.multi_stft_n_fft = stft_n_fft
+ self.multi_stft_window_fn = multi_stft_window_fn
+
+ self.multi_stft_kwargs = dict(
+ hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
+ )
+
+ def forward(self, raw_audio, target=None, return_loss_breakdown=False):
+ """
+ einops
+
+ b - batch
+ f - freq
+ t - time
+ s - audio channel (1 for mono, 2 for stereo)
+ n - number of 'stems'
+ c - complex (2)
+ d - feature dimension
+ """
+
+ device = raw_audio.device
+
+ # defining whether model is loaded on MPS (MacOS GPU accelerator)
+ x_is_mps = True if device.type == "mps" else False
+
+ if raw_audio.ndim == 2:
+ raw_audio = rearrange(raw_audio, "b t -> b 1 t")
+
+ channels = raw_audio.shape[1]
+ assert (not self.stereo and channels == 1) or (
+ self.stereo and channels == 2
+ ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
+
+ # to stft
+
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
+
+ stft_window = self.stft_window_fn(device=device)
+
+ # RuntimeError: FFT operations are only supported on MacOS 14+
+ # Since it's tedious to define whether we're on correct MacOS version - simple try-catch is used
+ try:
+ stft_repr = torch.stft(
+ raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
+ )
+        except Exception:
+ stft_repr = torch.stft(
+ raw_audio.cpu() if x_is_mps else raw_audio,
+ **self.stft_kwargs,
+ window=stft_window.cpu() if x_is_mps else stft_window,
+ return_complex=True,
+ ).to(device)
+
+ stft_repr = torch.view_as_real(stft_repr)
+
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
+ stft_repr = rearrange(
+ stft_repr, "b s f t c -> b (f s) t c"
+ ) # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
+
+ x = rearrange(stft_repr, "b f t c -> b t (f c)")
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(self.band_split, x, use_reentrant=False)
+ else:
+ x = self.band_split(x)
+
+ # axial / hierarchical attention
+
+ store = [None] * len(self.layers)
+ for i, transformer_block in enumerate(self.layers):
+
+ if len(transformer_block) == 3:
+ linear_transformer, time_transformer, freq_transformer = (
+ transformer_block
+ )
+
+ x, ft_ps = pack([x], "b * d")
+ if self.use_torch_checkpoint:
+ x = checkpoint(linear_transformer, x, use_reentrant=False)
+ else:
+ x = linear_transformer(x)
+ (x,) = unpack(x, ft_ps, "b * d")
+ else:
+ time_transformer, freq_transformer = transformer_block
+
+ if self.skip_connection:
+ # Sum all previous
+ for j in range(i):
+ x = x + store[j]
+
+ x = rearrange(x, "b t f d -> b f t d")
+ x, ps = pack([x], "* t d")
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(time_transformer, x, use_reentrant=False)
+ else:
+ x = time_transformer(x)
+
+ (x,) = unpack(x, ps, "* t d")
+ x = rearrange(x, "b f t d -> b t f d")
+ x, ps = pack([x], "* f d")
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(freq_transformer, x, use_reentrant=False)
+ else:
+ x = freq_transformer(x)
+
+ (x,) = unpack(x, ps, "* f d")
+
+ if self.skip_connection:
+ store[i] = x
+
+ x = self.final_norm(x)
+
+ num_stems = len(self.mask_estimators)
+
+ if self.use_torch_checkpoint:
+ mask = torch.stack(
+ [checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators],
+ dim=1,
+ )
+ else:
+ mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
+ mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
+
+ # modulate frequency representation
+
+ stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
+
+ # complex number multiplication
+
+ stft_repr = torch.view_as_complex(stft_repr)
+ mask = torch.view_as_complex(mask)
+
+ stft_repr = stft_repr * mask
+
+ # istft
+
+ stft_repr = rearrange(
+ stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
+ )
+
+ # same as torch.stft() fix for MacOS MPS above
+ try:
+ recon_audio = torch.istft(
+ stft_repr,
+ **self.stft_kwargs,
+ window=stft_window,
+ return_complex=False,
+ length=raw_audio.shape[-1],
+ )
+        except Exception:
+ recon_audio = torch.istft(
+ stft_repr.cpu() if x_is_mps else stft_repr,
+ **self.stft_kwargs,
+ window=stft_window.cpu() if x_is_mps else stft_window,
+ return_complex=False,
+ length=raw_audio.shape[-1],
+ ).to(device)
+
+ recon_audio = rearrange(
+ recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
+ )
+
+ if num_stems == 1:
+ recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
+
+ # if a target is passed in, calculate loss for learning
+
+ if not exists(target):
+ return recon_audio
+
+ if self.num_stems > 1:
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
+
+ if target.ndim == 2:
+ target = rearrange(target, "... t -> ... 1 t")
+
+ target = target[
+ ..., : recon_audio.shape[-1]
+ ] # protect against lost length on istft
+
+ loss = F.l1_loss(recon_audio, target)
+
+ multi_stft_resolution_loss = 0.0
+
+ for window_size in self.multi_stft_resolutions_window_sizes:
+ res_stft_kwargs = dict(
+ n_fft=max(
+ window_size, self.multi_stft_n_fft
+ ), # not sure what n_fft is across multi resolution stft
+ win_length=window_size,
+ return_complex=True,
+ window=self.multi_stft_window_fn(window_size, device=device),
+ **self.multi_stft_kwargs,
+ )
+
+ recon_Y = torch.stft(
+ rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
+ )
+ target_Y = torch.stft(
+ rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
+ )
+
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
+ recon_Y, target_Y
+ )
+
+ weighted_multi_resolution_loss = (
+ multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
+ )
+
+ total_loss = loss + weighted_multi_resolution_loss
+
+ if not return_loss_breakdown:
+ return total_loss
+
+ return total_loss, (loss, multi_stft_resolution_loss)
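+
+
+# Minimal inference sketch (illustrative; dim/depth/stereo/num_stems are hypothetical here
+# and normally come from the checkpoint's config):
+#     model = BSRoformer(dim=384, depth=6, stereo=True, num_stems=1).eval()
+#     with torch.no_grad():
+#         stems = model(torch.randn(1, 2, 44100))   # (batch, channels, samples) in -> same shape out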
diff --git a/mvsepless/models/bs_roformer/bs_roformer_fno.py b/mvsepless/models/bs_roformer/bs_roformer_fno.py
new file mode 100644
index 0000000000000000000000000000000000000000..858a8aa52f74263797ce22c66f2dcf92180f635a
--- /dev/null
+++ b/mvsepless/models/bs_roformer/bs_roformer_fno.py
@@ -0,0 +1,758 @@
+from functools import partial
+
+import torch
+from torch import nn, einsum, Tensor
+from torch.nn import Module, ModuleList
+import torch.nn.functional as F
+from neuralop.models import FNO1d
+
+from .attend_sw import Attend
+
+try:
+ from .attend_sage import Attend as AttendSage
+except ImportError:
+ pass
+from torch.utils.checkpoint import checkpoint
+
+from beartype.typing import Tuple, Optional, List, Callable
+from beartype import beartype
+
+from rotary_embedding_torch import RotaryEmbedding
+
+from einops import rearrange, pack, unpack
+from einops.layers.torch import Rearrange
+
+# helper functions
+
+
+def exists(val):
+ return val is not None
+
+
+def default(v, d):
+ return v if exists(v) else d
+
+
+def pack_one(t, pattern):
+ return pack([t], pattern)
+
+
+def unpack_one(t, ps, pattern):
+ return unpack(t, ps, pattern)[0]
+
+
+# norm
+
+
+def l2norm(t):
+ return F.normalize(t, dim=-1, p=2)
+
+
+class RMSNorm(Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
+
+
+# attention
+
+
+class FeedForward(Module):
+ def __init__(self, dim, mult=4, dropout=0.0):
+ super().__init__()
+ dim_inner = int(dim * mult)
+ self.net = nn.Sequential(
+ RMSNorm(dim),
+ nn.Linear(dim, dim_inner),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(dim_inner, dim),
+ nn.Dropout(dropout),
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class Attention(Module):
+ def __init__(
+ self,
+ dim,
+ heads=8,
+ dim_head=64,
+ dropout=0.0,
+ rotary_embed=None,
+ flash=True,
+ sage_attention=False,
+ ):
+ super().__init__()
+ self.heads = heads
+ self.scale = dim_head**-0.5
+ dim_inner = heads * dim_head
+
+ self.rotary_embed = rotary_embed
+
+ if sage_attention:
+ self.attend = AttendSage(flash=flash, dropout=dropout)
+ else:
+ self.attend = Attend(flash=flash, dropout=dropout)
+
+ self.norm = RMSNorm(dim)
+ self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
+
+ self.to_gates = nn.Linear(dim, heads)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)
+ )
+
+ def forward(self, x):
+ x = self.norm(x)
+
+ q, k, v = rearrange(
+ self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
+ )
+
+ if exists(self.rotary_embed):
+ q = self.rotary_embed.rotate_queries_or_keys(q)
+ k = self.rotary_embed.rotate_queries_or_keys(k)
+
+ out = self.attend(q, k, v)
+
+ gates = self.to_gates(x)
+ out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
+
+ out = rearrange(out, "b h n d -> b n (h d)")
+ return self.to_out(out)
+
+
+class LinearAttention(Module):
+ """
+ this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
+ """
+
+ @beartype
+ def __init__(
+ self,
+ *,
+ dim,
+ dim_head=32,
+ heads=8,
+ scale=8,
+ flash=False,
+ dropout=0.0,
+ sage_attention=False,
+ ):
+ super().__init__()
+ dim_inner = dim_head * heads
+ self.norm = RMSNorm(dim)
+
+ self.to_qkv = nn.Sequential(
+ nn.Linear(dim, dim_inner * 3, bias=False),
+ Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
+ )
+
+ self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
+
+ if sage_attention:
+ self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash)
+ else:
+ self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
+
+ self.to_out = nn.Sequential(
+ Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
+ )
+
+ def forward(self, x):
+ x = self.norm(x)
+
+ q, k, v = self.to_qkv(x)
+
+ q, k = map(l2norm, (q, k))
+ q = q * self.temperature.exp()
+
+ out = self.attend(q, k, v)
+
+ return self.to_out(out)
+
+
+class Transformer(Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ depth,
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.0,
+ ff_dropout=0.0,
+ ff_mult=4,
+ norm_output=True,
+ rotary_embed=None,
+ flash_attn=True,
+ linear_attn=False,
+ sage_attention=False,
+ ):
+ super().__init__()
+ self.layers = ModuleList([])
+
+ for _ in range(depth):
+ if linear_attn:
+ attn = LinearAttention(
+ dim=dim,
+ dim_head=dim_head,
+ heads=heads,
+ dropout=attn_dropout,
+ flash=flash_attn,
+ sage_attention=sage_attention,
+ )
+ else:
+ attn = Attention(
+ dim=dim,
+ dim_head=dim_head,
+ heads=heads,
+ dropout=attn_dropout,
+ rotary_embed=rotary_embed,
+ flash=flash_attn,
+ sage_attention=sage_attention,
+ )
+
+ self.layers.append(
+ ModuleList(
+ [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
+ )
+ )
+
+ self.norm = RMSNorm(dim) if norm_output else nn.Identity()
+
+ def forward(self, x):
+
+ for attn, ff in self.layers:
+ x = attn(x) + x
+ x = ff(x) + x
+
+ return self.norm(x)
+
+
+# bandsplit module
+
+
+class BandSplit(Module):
+ @beartype
+ def __init__(self, dim, dim_inputs: Tuple[int, ...]):
+ super().__init__()
+ self.dim_inputs = dim_inputs
+ self.to_features = ModuleList([])
+
+ for dim_in in dim_inputs:
+ net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
+
+ self.to_features.append(net)
+
+ def forward(self, x):
+ x = x.split(self.dim_inputs, dim=-1)
+
+ outs = []
+ for split_input, to_feature in zip(x, self.to_features):
+ split_output = to_feature(split_input)
+ outs.append(split_output)
+
+ return torch.stack(outs, dim=-2)
+
+
+def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
+ dim_hidden = default(dim_hidden, dim_in)
+
+ net = []
+ dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
+
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
+ is_last = ind == (len(dims) - 2)
+
+ net.append(nn.Linear(layer_dim_in, layer_dim_out))
+
+ if is_last:
+ continue
+
+ net.append(activation())
+
+ return nn.Sequential(*net)
+
+
+class MaskEstimator(Module):
+ @beartype
+ def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
+ super().__init__()
+ self.dim_inputs = dim_inputs
+ self.to_freqs = ModuleList([])
+ dim_hidden = dim * mlp_expansion_factor
+
+ for dim_in in dim_inputs:
+            # the FNO head works channel-first, mapping (b, dim, t) -> (b, 2 * dim_in, t);
+            # the GLU over dim=-2 then halves the channel dim back to dim_in
+ mlp = nn.Sequential(
+ FNO1d(
+ n_modes_height=64,
+ hidden_channels=dim,
+ in_channels=dim,
+ out_channels=dim_in * 2,
+ lifting_channels=dim,
+ projection_channels=dim,
+ n_layers=3,
+ separable=True,
+ ),
+ nn.GLU(dim=-2),
+ )
+
+ self.to_freqs.append(mlp)
+
+ def forward(self, x):
+ x = x.unbind(dim=-2)
+
+ outs = []
+
+ for band_features, mlp in zip(x, self.to_freqs):
+ band_features = rearrange(band_features, "b t c -> b c t")
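+            # run the FNO head in full fp32 even when AMP is active (autocast disabled below),
+            # presumably to keep the spectral convolutions numerically stable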
+ with torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32):
+ freq_out = mlp(band_features).float()
+ freq_out = rearrange(freq_out, "b c t -> b t c")
+ outs.append(freq_out)
+
+ return torch.cat(outs, dim=-1)
+
+
+# main class
+
+# fmt: off
+DEFAULT_FREQS_PER_BANDS = (
+    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+    4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
+    12, 12, 12, 12, 12, 12, 12, 12,
+    24, 24, 24, 24, 24, 24, 24, 24,
+    48, 48, 48, 48, 48, 48, 48, 48,
+    128, 129,
+)
+# fmt: on
+
+
+class BSRoformer_FNO(Module):
+
+ @beartype
+ def __init__(
+ self,
+ dim,
+ *,
+ depth,
+ stereo=False,
+ num_stems=1,
+ time_transformer_depth=2,
+ freq_transformer_depth=2,
+ linear_transformer_depth=0,
+ freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
+ # in the paper, they divide into ~60 bands, test with 1 for starters
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.0,
+ ff_dropout=0.0,
+ flash_attn=True,
+ dim_freqs_in=1025,
+ stft_n_fft=2048,
+ stft_hop_length=512,
+ # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
+ stft_win_length=2048,
+ stft_normalized=False,
+ stft_window_fn: Optional[Callable] = None,
+ mask_estimator_depth=2,
+ multi_stft_resolution_loss_weight=1.0,
+ multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
+ 4096,
+ 2048,
+ 1024,
+ 512,
+ 256,
+ ),
+ multi_stft_hop_size=147,
+ multi_stft_normalized=False,
+ multi_stft_window_fn: Callable = torch.hann_window,
+ mlp_expansion_factor=4,
+ use_torch_checkpoint=False,
+ skip_connection=False,
+ sage_attention=False,
+ fno=True,
+ ):
+ super().__init__()
+
+ self.stereo = stereo
+ self.audio_channels = 2 if stereo else 1
+ self.num_stems = num_stems
+ self.use_torch_checkpoint = use_torch_checkpoint
+ self.skip_connection = skip_connection
+
+ self.layers = ModuleList([])
+
+ if sage_attention:
+ print("Use Sage Attention")
+
+ transformer_kwargs = dict(
+ dim=dim,
+ heads=heads,
+ dim_head=dim_head,
+ attn_dropout=attn_dropout,
+ ff_dropout=ff_dropout,
+ flash_attn=flash_attn,
+ norm_output=False,
+ sage_attention=sage_attention,
+ )
+
+ time_rotary_embed = RotaryEmbedding(dim=dim_head)
+ freq_rotary_embed = RotaryEmbedding(dim=dim_head)
+
+ for _ in range(depth):
+ tran_modules = []
+ if linear_transformer_depth > 0:
+ tran_modules.append(
+ Transformer(
+ depth=linear_transformer_depth,
+ linear_attn=True,
+ **transformer_kwargs,
+ )
+ )
+ tran_modules.append(
+ Transformer(
+ depth=time_transformer_depth,
+ rotary_embed=time_rotary_embed,
+ **transformer_kwargs,
+ )
+ )
+ tran_modules.append(
+ Transformer(
+ depth=freq_transformer_depth,
+ rotary_embed=freq_rotary_embed,
+ **transformer_kwargs,
+ )
+ )
+ self.layers.append(nn.ModuleList(tran_modules))
+
+ self.final_norm = RMSNorm(dim)
+
+ self.stft_kwargs = dict(
+ n_fft=stft_n_fft,
+ hop_length=stft_hop_length,
+ win_length=stft_win_length,
+ normalized=stft_normalized,
+ )
+
+ self.stft_window_fn = partial(
+ default(stft_window_fn, torch.hann_window), stft_win_length
+ )
+
+ freqs = torch.stft(
+ torch.randn(1, 4096),
+ **self.stft_kwargs,
+ window=torch.ones(stft_win_length),
+ return_complex=True,
+ ).shape[1]
+
+ assert len(freqs_per_bands) > 1
+ assert (
+ sum(freqs_per_bands) == freqs
+ ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
+
+ freqs_per_bands_with_complex = tuple(
+ 2 * f * self.audio_channels for f in freqs_per_bands
+ )
+
+ self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
+
+ self.mask_estimators = nn.ModuleList([])
+
+ for _ in range(num_stems):
+ mask_estimator = MaskEstimator(
+ dim=dim,
+ dim_inputs=freqs_per_bands_with_complex,
+ depth=mask_estimator_depth,
+ mlp_expansion_factor=mlp_expansion_factor,
+ )
+
+ self.mask_estimators.append(mask_estimator)
+
+ # for the multi-resolution stft loss
+
+ self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
+ self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
+ self.multi_stft_n_fft = stft_n_fft
+ self.multi_stft_window_fn = multi_stft_window_fn
+
+ self.multi_stft_kwargs = dict(
+ hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
+ )
+
+ def forward(self, raw_audio, target=None, return_loss_breakdown=False):
+ """
+ einops
+
+ b - batch
+ f - freq
+ t - time
+ s - audio channel (1 for mono, 2 for stereo)
+ n - number of 'stems'
+ c - complex (2)
+ d - feature dimension
+ """
+
+ device = raw_audio.device
+
+ # defining whether model is loaded on MPS (MacOS GPU accelerator)
+ x_is_mps = True if device.type == "mps" else False
+
+ if raw_audio.ndim == 2:
+ raw_audio = rearrange(raw_audio, "b t -> b 1 t")
+
+ channels = raw_audio.shape[1]
+ assert (not self.stereo and channels == 1) or (
+ self.stereo and channels == 2
+ ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
+
+ # to stft
+
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
+
+ stft_window = self.stft_window_fn(device=device)
+
+ # RuntimeError: FFT operations are only supported on MacOS 14+
+ # Since it's tedious to define whether we're on correct MacOS version - simple try-catch is used
+ try:
+ stft_repr = torch.stft(
+ raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
+ )
+        except Exception:
+ stft_repr = torch.stft(
+ raw_audio.cpu() if x_is_mps else raw_audio,
+ **self.stft_kwargs,
+ window=stft_window.cpu() if x_is_mps else stft_window,
+ return_complex=True,
+ ).to(device)
+ stft_repr = torch.view_as_real(stft_repr)
+
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
+
+ # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
+ stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
+
+ x = rearrange(stft_repr, "b f t c -> b t (f c)")
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(self.band_split, x, use_reentrant=False)
+ else:
+ x = self.band_split(x)
+
+ # axial / hierarchical attention
+
+ store = [None] * len(self.layers)
+ for i, transformer_block in enumerate(self.layers):
+
+ if len(transformer_block) == 3:
+ linear_transformer, time_transformer, freq_transformer = (
+ transformer_block
+ )
+
+ x, ft_ps = pack([x], "b * d")
+ if self.use_torch_checkpoint:
+ x = checkpoint(linear_transformer, x, use_reentrant=False)
+ else:
+ x = linear_transformer(x)
+ (x,) = unpack(x, ft_ps, "b * d")
+ else:
+ time_transformer, freq_transformer = transformer_block
+
+ if self.skip_connection:
+ # Sum all previous
+ for j in range(i):
+ x = x + store[j]
+
+ x = rearrange(x, "b t f d -> b f t d")
+ x, ps = pack([x], "* t d")
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(time_transformer, x, use_reentrant=False)
+ else:
+ x = time_transformer(x)
+
+ (x,) = unpack(x, ps, "* t d")
+ x = rearrange(x, "b f t d -> b t f d")
+ x, ps = pack([x], "* f d")
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(freq_transformer, x, use_reentrant=False)
+ else:
+ x = freq_transformer(x)
+
+ (x,) = unpack(x, ps, "* f d")
+
+ if self.skip_connection:
+ store[i] = x
+
+ x = self.final_norm(x)
+
+ num_stems = len(self.mask_estimators)
+
+ if self.use_torch_checkpoint:
+ mask = torch.stack(
+ [checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators],
+ dim=1,
+ )
+ else:
+ mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
+ mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
+
+ # modulate frequency representation
+
+ stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
+
+ # complex number multiplication
+
+ stft_repr = torch.view_as_complex(stft_repr)
+ mask = torch.view_as_complex(mask)
+
+ stft_repr = stft_repr * mask
+
+ # istft
+
+ stft_repr = rearrange(
+ stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
+ )
+
+ # same as torch.stft() fix for MacOS MPS above
+ try:
+ recon_audio = torch.istft(
+ stft_repr,
+ **self.stft_kwargs,
+ window=stft_window,
+ return_complex=False,
+ length=raw_audio.shape[-1],
+ )
+        except Exception:
+ recon_audio = torch.istft(
+ stft_repr.cpu() if x_is_mps else stft_repr,
+ **self.stft_kwargs,
+ window=stft_window.cpu() if x_is_mps else stft_window,
+ return_complex=False,
+ length=raw_audio.shape[-1],
+ ).to(device)
+
+ recon_audio = rearrange(
+ recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
+ )
+
+ if num_stems == 1:
+ recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
+
+ # if a target is passed in, calculate loss for learning
+
+ if not exists(target):
+ return recon_audio
+
+ if self.num_stems > 1:
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
+
+ if target.ndim == 2:
+ target = rearrange(target, "... t -> ... 1 t")
+
+ target = target[
+ ..., : recon_audio.shape[-1]
+ ] # protect against lost length on istft
+
+ loss = F.l1_loss(recon_audio, target)
+
+ multi_stft_resolution_loss = 0.0
+
+ for window_size in self.multi_stft_resolutions_window_sizes:
+ res_stft_kwargs = dict(
+ n_fft=max(
+ window_size, self.multi_stft_n_fft
+ ), # not sure what n_fft is across multi resolution stft
+ win_length=window_size,
+ return_complex=True,
+ window=self.multi_stft_window_fn(window_size, device=device),
+ **self.multi_stft_kwargs,
+ )
+
+ recon_Y = torch.stft(
+ rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
+ )
+ target_Y = torch.stft(
+ rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
+ )
+
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
+ recon_Y, target_Y
+ )
+
+ weighted_multi_resolution_loss = (
+ multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
+ )
+
+ total_loss = loss + weighted_multi_resolution_loss
+
+ if not return_loss_breakdown:
+ return total_loss
+
+ return total_loss, (loss, multi_stft_resolution_loss)
diff --git a/mvsepless/models/bs_roformer/bs_roformer_sw.py b/mvsepless/models/bs_roformer/bs_roformer_sw.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b04ce518df91bf28126b76c96bc04910fbeb1de
--- /dev/null
+++ b/mvsepless/models/bs_roformer/bs_roformer_sw.py
@@ -0,0 +1,724 @@
+from __future__ import annotations
+
+from functools import partial
+
+import torch
+import torch.nn.functional as F
+from beartype import beartype
+from beartype.typing import Callable
+from einops import pack, rearrange, unpack
+from einops.layers.torch import Rearrange
+from torch import nn
+from torch.nn import Module, ModuleList
+from torch.utils.checkpoint import checkpoint
+
+from .attend_sw import Attend
+
+try:
+ from .attend_sage_sw import AttendSage
+except ImportError:
+ pass
+
+
+def l2norm(t):
+ return F.normalize(t, dim=-1, p=2)
+
+
+class CustomNorm(Module):
+ def __init__(self, dim, eps: float = 5.960464477539063e-08): # 0x1p-24
+ super().__init__()
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(dim))
+ self.eps = eps
+
+ def forward(self, x):
+ l2_norm = torch.linalg.norm(x, dim=-1, keepdim=True)
+ denom = torch.maximum(l2_norm, torch.full_like(l2_norm, self.eps))
+ normalized_x = x / denom
+ return normalized_x * self.scale * self.gamma
+
+
+# attention
+
+
+class RotaryEmbedding(nn.Module):
+ def __init__(self, cos_emb, sin_emb):
+ super().__init__()
+ # both (seq_len_for_rotation, dim_head)
+ self.cos_emb = cos_emb
+ self.sin_emb = sin_emb
+
+ def rotate_half(self, x):
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
+ x1, x2 = x.unbind(dim=-1)
+ x = torch.stack((-x2, x1), dim=-1)
+ return rearrange(x, "... d r -> ... (d r)")
+
+ def forward(self, x):
+ # x is (batch_eff, heads, seq_len_for_rotation, dim_head)
+ cos_b = self.cos_emb.unsqueeze(0).unsqueeze(0).to(x.device, x.dtype)
+ sin_b = self.sin_emb.unsqueeze(0).unsqueeze(0).to(x.device, x.dtype)
+
+ term1 = x * cos_b
+ term2 = self.rotate_half(x) * sin_b
+
+        out = term1.to(torch.float32) + term2.to(torch.float32)
+        return out.to(x.dtype)
+
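+# Unlike the rotary_embedding_torch module used by the other Roformer variants, this
+# RotaryEmbedding applies fixed-length cos/sin tables that BSRoformer_SW registers as
+# nn.Parameters (one row per time frame for the time axis, one per band for the freq axis).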
+
+class FeedForward(Module):
+ def __init__(self, dim, mult=4, dropout=0.0):
+ super().__init__()
+ dim_inner = int(dim * mult)
+ self.net = nn.Sequential(
+ CustomNorm(dim),
+ nn.Linear(dim, dim_inner),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(dim_inner, dim),
+ nn.Dropout(dropout),
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class Attention(Module):
+ def __init__(
+ self,
+ dim,
+ heads=8,
+ dim_head=64,
+ dropout=0.0,
+ shared_qkv_bias=None,
+ shared_out_bias=None,
+ rotary_embed: RotaryEmbedding | None = None,
+ flash=True,
+ sage_attention=False,
+ ):
+ super().__init__()
+ self.heads = heads
+ self.scale = dim_head**-0.5
+ dim_inner = heads * dim_head
+
+ self.rotary_embed = rotary_embed
+
+ if sage_attention:
+ self.attend = AttendSage(flash=flash, dropout=dropout) # type: ignore
+ else:
+ self.attend = Attend(flash=flash, dropout=dropout) # type: ignore
+
+ self.norm = CustomNorm(dim)
+ self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=(shared_qkv_bias is not None))
+ if shared_qkv_bias is not None:
+ self.to_qkv.bias = shared_qkv_bias
+
+ self.to_gates = nn.Linear(dim, heads)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(dim_inner, dim, bias=(shared_out_bias is not None)),
+ nn.Dropout(dropout),
+ )
+ if shared_out_bias is not None:
+ self.to_out[0].bias = shared_out_bias
+
+ def forward(self, x):
+ x = self.norm(x)
+
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)
+
+ if self.rotary_embed is not None:
+ q = self.rotary_embed(q)
+ k = self.rotary_embed(k)
+
+ out = self.attend(q, k, v)
+
+ gates = self.to_gates(x)
+ gate_act = gates.sigmoid()
+
+ out = out * rearrange(gate_act, "b n h -> b h n 1")
+
+ out = rearrange(out, "b h n d -> b n (h d)")
+ out = self.to_out(out)
+ return out
+
+
+class LinearAttention(Module):
+ """
+ this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
+ """
+
+ @beartype
+ def __init__(
+ self,
+ *,
+ dim,
+ dim_head=32,
+ heads=8,
+ scale=8,
+ flash=False,
+ dropout=0.0,
+ sage_attention=False,
+ ):
+ super().__init__()
+ dim_inner = dim_head * heads
+ self.norm = CustomNorm(dim)
+
+ self.to_qkv = nn.Sequential(
+ nn.Linear(dim, dim_inner * 3, bias=False),
+ Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
+ )
+
+ self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
+
+ if sage_attention:
+ self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash) # type: ignore
+ else:
+ self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
+
+ self.to_out = nn.Sequential(
+ Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
+ )
+
+ def forward(self, x):
+ x = self.norm(x)
+
+ q, k, v = self.to_qkv(x)
+
+ q, k = map(l2norm, (q, k))
+ q = q * self.temperature.exp()
+
+ out = self.attend(q, k, v)
+
+ return self.to_out(out)
+
+
+class Transformer(Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ depth,
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.0,
+ ff_dropout=0.0,
+ ff_mult=4,
+ norm_output=True,
+ rotary_embed: RotaryEmbedding | None = None,
+ flash_attn=True,
+ linear_attn=False,
+ sage_attention=False,
+ shared_qkv_bias=None,
+ shared_out_bias=None,
+ ):
+ super().__init__()
+ self.layers = ModuleList([])
+
+ for _ in range(depth):
+ attn: LinearAttention | Attention
+ if linear_attn:
+ attn = LinearAttention(
+ dim=dim,
+ dim_head=dim_head,
+ heads=heads,
+ dropout=attn_dropout,
+ flash=flash_attn,
+ sage_attention=sage_attention,
+ )
+ else:
+ attn = Attention(
+ dim=dim,
+ dim_head=dim_head,
+ heads=heads,
+ dropout=attn_dropout,
+ shared_qkv_bias=shared_qkv_bias,
+ shared_out_bias=shared_out_bias,
+ rotary_embed=rotary_embed,
+ flash=flash_attn,
+ sage_attention=sage_attention,
+ )
+
+ self.layers.append(
+ ModuleList(
+ [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
+ )
+ )
+
+ self.norm = CustomNorm(dim) if norm_output else nn.Identity()
+
+ def forward(self, x):
+ for attn, ff in self.layers: # type: ignore
+ x = attn(x) + x
+ x = ff(x) + x
+ return self.norm(x)
+
+
+# bandsplit module
+
+
+class BandSplit(Module):
+ @beartype
+ def __init__(self, dim, dim_inputs: tuple[int, ...]):
+ super().__init__()
+ self.dim_inputs = dim_inputs
+ self.to_features = ModuleList([])
+
+ for dim_in in dim_inputs:
+ net = nn.Sequential(CustomNorm(dim_in), nn.Linear(dim_in, dim))
+
+ self.to_features.append(net)
+
+ def forward(self, x):
+ x = x.split(self.dim_inputs, dim=-1)
+
+ outs = []
+ for split_input, to_feature in zip(x, self.to_features):
+ split_output = to_feature(split_input)
+ outs.append(split_output)
+
+ return torch.stack(outs, dim=-2)
+
+
+def MLP(
+ dim_in: int,
+ dim_out: int,
+ dim_hidden: int | None = None,
+ depth: int = 1,
+ activation=nn.Tanh,
+):
+ dim_hidden = dim_hidden or dim_in
+
+ net = []
+ dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
+
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
+ is_last = ind == (len(dims) - 2)
+
+ net.append(nn.Linear(layer_dim_in, layer_dim_out))
+
+ if is_last:
+ continue
+
+ net.append(activation())
+
+ return nn.Sequential(*net)
+
+
+class MaskEstimator(Module):
+ @beartype
+ def __init__(self, dim, dim_inputs: tuple[int, ...], depth, mlp_expansion_factor=4):
+ super().__init__()
+ self.dim_inputs = dim_inputs
+ self.to_freqs = ModuleList([])
+ dim_hidden = dim * mlp_expansion_factor
+
+ for dim_in in dim_inputs:
+ mlp = nn.Sequential(
+ MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1)
+ )
+
+ self.to_freqs.append(mlp)
+
+ def forward(self, x):
+ x = x.unbind(dim=-2)
+
+ outs = []
+
+ for band_features, mlp in zip(x, self.to_freqs):
+ freq_out = mlp(band_features)
+ outs.append(freq_out)
+
+ return torch.cat(outs, dim=-1)
+
+
+# fmt: off
+DEFAULT_FREQS_PER_BANDS = (
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
+ 12, 12, 12, 12, 12, 12, 12, 12,
+ 24, 24, 24, 24, 24, 24, 24, 24,
+ 48, 48, 48, 48, 48, 48, 48, 48,
+ 128, 129
+)
+# fmt: on
+
+
+class BSRoformer_SW(Module):
+ @beartype
+ def __init__(
+ self,
+ dim,
+ *,
+ depth,
+ stereo=False,
+ num_stems=1,
+ time_transformer_depth=2,
+ freq_transformer_depth=2,
+ linear_transformer_depth=0,
+ freqs_per_bands: tuple[
+ int, ...
+ ] = DEFAULT_FREQS_PER_BANDS, # in the paper, they divide into ~60 bands, test with 1 for starters
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.0,
+ ff_dropout=0.0,
+ flash_attn=True,
+ stft_n_fft=2048,
+ stft_hop_length=512, # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
+ stft_win_length=2048,
+ stft_normalized=False,
+ stft_window_fn: Callable | None = None,
+ mask_estimator_depth=2,
+ multi_stft_resolution_loss_weight=1.0,
+ multi_stft_resolutions_window_sizes: tuple[int, ...] = (
+ 4096,
+ 2048,
+ 1024,
+ 512,
+ 256,
+ ),
+ multi_stft_hop_size=147,
+ multi_stft_normalized=False,
+ multi_stft_window_fn: Callable = torch.hann_window,
+ mlp_expansion_factor=4,
+ use_torch_checkpoint=False,
+ skip_connection=False,
+ sage_attention=False,
+ use_shared_bias=False,
+ chunk_size: int = 588800,
+ ):
+ super().__init__()
+
+ self.stereo = stereo
+ self.audio_channels = 2 if stereo else 1
+ self.num_stems = num_stems
+ self.use_torch_checkpoint = use_torch_checkpoint
+ self.skip_connection = skip_connection
+
+ self.layers = ModuleList([])
+
+ if sage_attention:
+ print("Use Sage Attention")
+
+        if use_shared_bias:
+            dim_inner = heads * dim_head
+            self.shared_qkv_bias = nn.Parameter(torch.ones(dim_inner * 3))
+            self.shared_out_bias = nn.Parameter(torch.ones(dim))
+        else:
+            # without shared biases, the attention layers fall back to bias-free projections
+            self.shared_qkv_bias = None
+            self.shared_out_bias = None
+
+ transformer_kwargs = dict(
+ dim=dim,
+ heads=heads,
+ dim_head=dim_head,
+ attn_dropout=attn_dropout,
+ ff_dropout=ff_dropout,
+ flash_attn=flash_attn,
+ norm_output=False,
+ sage_attention=sage_attention,
+ shared_qkv_bias=self.shared_qkv_bias,
+ shared_out_bias=self.shared_out_bias,
+ )
+
+ t_frames = chunk_size // stft_hop_length + 1 # e.g. 588800 // 512 + 1 = 1151
+ self.cos_emb_time = nn.Parameter(torch.zeros(t_frames, dim_head))
+ self.sin_emb_time = nn.Parameter(torch.zeros(t_frames, dim_head))
+ time_rotary_embed = RotaryEmbedding(
+ cos_emb=self.cos_emb_time, sin_emb=self.sin_emb_time
+ )
+
+ num_bands = len(freqs_per_bands) # e.g. 62
+ self.cos_emb_freq = nn.Parameter(torch.zeros(num_bands, dim_head))
+ self.sin_emb_freq = nn.Parameter(torch.zeros(num_bands, dim_head))
+ freq_rotary_embed = RotaryEmbedding(
+ cos_emb=self.cos_emb_freq, sin_emb=self.sin_emb_freq
+ )
+
+ for _ in range(depth):
+ tran_modules = []
+ if linear_transformer_depth > 0:
+ tran_modules.append(
+ Transformer(
+ depth=linear_transformer_depth,
+ linear_attn=True,
+ **transformer_kwargs,
+ )
+ )
+ tran_modules.append(
+ Transformer(
+ depth=time_transformer_depth,
+ rotary_embed=time_rotary_embed,
+ **transformer_kwargs,
+ )
+ )
+ tran_modules.append(
+ Transformer(
+ depth=freq_transformer_depth,
+ rotary_embed=freq_rotary_embed,
+ **transformer_kwargs,
+ )
+ )
+ self.layers.append(nn.ModuleList(tran_modules))
+
+ self.final_norm = CustomNorm(dim)
+
+ self.stft_kwargs = dict(
+ n_fft=stft_n_fft,
+ hop_length=stft_hop_length,
+ win_length=stft_win_length,
+ normalized=stft_normalized,
+ )
+
+ self.stft_window_fn = partial(
+ stft_window_fn or torch.hann_window, stft_win_length
+ )
+
+ freqs_per_bands_with_complex = tuple(
+ 2 * f * self.audio_channels for f in freqs_per_bands
+ )
+
+ self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
+
+ self.mask_estimators = nn.ModuleList([])
+
+ for _ in range(num_stems):
+ mask_estimator = MaskEstimator(
+ dim=dim,
+ dim_inputs=freqs_per_bands_with_complex,
+ depth=mask_estimator_depth,
+ mlp_expansion_factor=mlp_expansion_factor,
+ )
+
+ self.mask_estimators.append(mask_estimator)
+
+ # for the multi-resolution stft loss
+
+ self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
+ self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
+ self.multi_stft_n_fft = stft_n_fft
+ self.multi_stft_window_fn = multi_stft_window_fn
+
+ self.multi_stft_kwargs = dict(
+ hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
+ )
+
+ def forward(self, raw_audio, target=None, return_loss_breakdown=False):
+ """
+ einops
+
+ b - batch
+ f - freq
+ t - time
+ s - audio channel (1 for mono, 2 for stereo)
+ n - number of 'stems'
+ c - complex (2)
+ d - feature dimension
+ """
+
+ device = raw_audio.device
+
+ # defining whether model is loaded on MPS (MacOS GPU accelerator)
+ x_is_mps = True if device.type == "mps" else False
+
+ if raw_audio.ndim == 2:
+ raw_audio = rearrange(raw_audio, "b t -> b 1 t")
+
+ channels = raw_audio.shape[1]
+ assert (not self.stereo and channels == 1) or (
+ self.stereo and channels == 2
+ ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
+
+ # to stft
+
+ raw_audio, batch_audio_channel_packed_shape = pack([raw_audio], "* t")
+
+ stft_window = self.stft_window_fn(device=device)
+
+ # RuntimeError: FFT operations are only supported on MacOS 14+
+ # Since it's tedious to define whether we're on correct MacOS version - simple try-catch is used
+ try:
+ stft_repr = torch.stft(
+ raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
+ )
+ except Exception:
+ stft_repr = torch.stft(
+ raw_audio.cpu() if x_is_mps else raw_audio,
+ **self.stft_kwargs,
+ window=stft_window.cpu() if x_is_mps else stft_window,
+ return_complex=True,
+ ).to(device)
+ stft_repr = torch.view_as_real(stft_repr)
+
+ stft_repr = unpack(stft_repr, batch_audio_channel_packed_shape, "* f t c")[0]
+
+ # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
+ stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
+
+ x = rearrange(stft_repr, "b f t c -> b t (f c)")
+
+ if torch.isnan(x).any() or torch.isinf(x).any():
+ raise RuntimeError(
+ f"NaN/Inf in x after stft: {x.isnan().sum()} NaNs, {x.isinf().sum()} Infs"
+ )
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(self.band_split, x, use_reentrant=False)
+ else:
+ x = self.band_split(x)
+
+ if torch.isnan(x).any() or torch.isinf(x).any():
+ raise RuntimeError(
+ f"NaN/Inf in x after band_split: {x.isnan().sum()} NaNs, {x.isinf().sum()} Infs"
+ )
+
+ # axial / hierarchical attention
+
+ store = [None] * len(self.layers)
+ for i, transformer_block in enumerate(self.layers):
+ if len(transformer_block) == 3:
+ linear_transformer, time_transformer, freq_transformer = (
+ transformer_block
+ )
+
+ x, ft_ps = pack([x], "b * d")
+ if self.use_torch_checkpoint:
+ x = checkpoint(linear_transformer, x, use_reentrant=False)
+ else:
+ x = linear_transformer(x)
+ (x,) = unpack(x, ft_ps, "b * d")
+ else:
+ time_transformer, freq_transformer = transformer_block
+
+ if self.skip_connection:
+ # Sum all previous
+ for j in range(i):
+ x = x + store[j]
+
+ x = rearrange(x, "b t f d -> b f t d")
+ x, ps = pack([x], "* t d")
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(time_transformer, x, use_reentrant=False)
+ else:
+ x = time_transformer(x)
+
+ (x,) = unpack(x, ps, "* t d")
+ x = rearrange(x, "b f t d -> b t f d")
+ x, ps = pack([x], "* f d")
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(freq_transformer, x, use_reentrant=False)
+ else:
+ x = freq_transformer(x)
+
+ (x,) = unpack(x, ps, "* f d")
+
+ if self.skip_connection:
+ store[i] = x
+
+ x = self.final_norm(x)
+
+ num_stems = len(self.mask_estimators)
+
+ if self.use_torch_checkpoint:
+ mask = torch.stack(
+ [checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators],
+ dim=1,
+ )
+ else:
+ mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
+ mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
+
+ # modulate frequency representation
+
+ stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
+
+ # complex number multiplication
+
+ stft_repr = torch.view_as_complex(stft_repr)
+ mask = torch.view_as_complex(mask)
+
+ stft_repr = stft_repr * mask
+
+ # istft
+
+ stft_repr = rearrange(
+ stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
+ )
+
+ # same as torch.stft() fix for MacOS MPS above
+ try:
+ recon_audio = torch.istft(
+ stft_repr,
+ **self.stft_kwargs,
+ window=stft_window,
+ return_complex=False,
+ length=raw_audio.shape[-1],
+ )
+ except Exception:
+ recon_audio = torch.istft(
+ stft_repr.cpu() if x_is_mps else stft_repr,
+ **self.stft_kwargs,
+ window=stft_window.cpu() if x_is_mps else stft_window,
+ return_complex=False,
+ length=raw_audio.shape[-1],
+ ).to(device)
+
+ recon_audio = rearrange(
+ recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
+ )
+
+ if num_stems == 1:
+ recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
+
+ # if a target is passed in, calculate loss for learning
+
+ if target is None:
+ return recon_audio
+
+ if self.num_stems > 1:
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
+
+ if target.ndim == 2:
+ target = rearrange(target, "... t -> ... 1 t")
+
+ target = target[
+ ..., : recon_audio.shape[-1]
+ ] # protect against lost length on istft
+
+ loss = F.l1_loss(recon_audio, target)
+
+ multi_stft_resolution_loss = 0.0
+
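+        # auxiliary loss: L1 between complex STFTs of estimate and target at several window sizes,
+        # summed and then weighted by multi_stft_resolution_loss_weight below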
+ for window_size in self.multi_stft_resolutions_window_sizes:
+ res_stft_kwargs = dict(
+ n_fft=max(
+ window_size, self.multi_stft_n_fft
+ ), # not sure what n_fft is across multi resolution stft
+ win_length=window_size,
+ return_complex=True,
+ window=self.multi_stft_window_fn(window_size, device=device),
+ **self.multi_stft_kwargs,
+ )
+
+ recon_Y = torch.stft(
+ rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
+ )
+ target_Y = torch.stft(
+ rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
+ )
+
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
+ recon_Y, target_Y
+ )
+
+ weighted_multi_resolution_loss = (
+ multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
+ )
+
+ total_loss = loss + weighted_multi_resolution_loss
+
+ if not return_loss_breakdown:
+ return total_loss
+
+ return total_loss, (loss, multi_stft_resolution_loss)
diff --git a/mvsepless/models/bs_roformer/mel_band_roformer.py b/mvsepless/models/bs_roformer/mel_band_roformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..833aa39c44c5f7b7e8cf9dd60a7cccaead36ef24
--- /dev/null
+++ b/mvsepless/models/bs_roformer/mel_band_roformer.py
@@ -0,0 +1,704 @@
+from functools import partial
+
+import torch
+from torch import nn, einsum, Tensor
+from torch.nn import Module, ModuleList
+import torch.nn.functional as F
+
+from .attend import Attend
+from torch.utils.checkpoint import checkpoint
+
+from beartype.typing import Tuple, Optional, List, Callable
+from beartype import beartype
+
+from rotary_embedding_torch import RotaryEmbedding
+
+from einops import rearrange, pack, unpack, reduce, repeat
+from einops.layers.torch import Rearrange
+
+from librosa import filters
+
+
+# helper functions
+
+
+def exists(val):
+ return val is not None
+
+
+def default(v, d):
+ return v if exists(v) else d
+
+
+def pack_one(t, pattern):
+ return pack([t], pattern)
+
+
+def unpack_one(t, ps, pattern):
+ return unpack(t, ps, pattern)[0]
+
+
+def pad_at_dim(t, pad, dim=-1, value=0.0):
+ dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
+ zeros = (0, 0) * dims_from_right
+ return F.pad(t, (*zeros, *pad), value=value)
+
+
+def l2norm(t):
+ return F.normalize(t, dim=-1, p=2)
+
+
+# norm
+
+
+class RMSNorm(Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
+
+
+# attention
+
+
+class FeedForward(Module):
+ def __init__(self, dim, mult=4, dropout=0.0):
+ super().__init__()
+ dim_inner = int(dim * mult)
+ self.net = nn.Sequential(
+ RMSNorm(dim),
+ nn.Linear(dim, dim_inner),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(dim_inner, dim),
+ nn.Dropout(dropout),
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class Attention(Module):
+ def __init__(
+ self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True
+ ):
+ super().__init__()
+ self.heads = heads
+ self.scale = dim_head**-0.5
+ dim_inner = heads * dim_head
+
+ self.rotary_embed = rotary_embed
+
+ self.attend = Attend(flash=flash, dropout=dropout)
+
+ self.norm = RMSNorm(dim)
+ self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
+
+ self.to_gates = nn.Linear(dim, heads)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)
+ )
+
+ def forward(self, x):
+ x = self.norm(x)
+
+ q, k, v = rearrange(
+ self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
+ )
+
+ if exists(self.rotary_embed):
+ q = self.rotary_embed.rotate_queries_or_keys(q)
+ k = self.rotary_embed.rotate_queries_or_keys(k)
+
+ out = self.attend(q, k, v)
+
+ gates = self.to_gates(x)
+ out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
+
+ out = rearrange(out, "b h n d -> b n (h d)")
+ return self.to_out(out)
+
+
+class LinearAttention(Module):
+ """
+ this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
+ """
+
+ @beartype
+ def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0):
+ super().__init__()
+ dim_inner = dim_head * heads
+ self.norm = RMSNorm(dim)
+
+ self.to_qkv = nn.Sequential(
+ nn.Linear(dim, dim_inner * 3, bias=False),
+ Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
+ )
+
+ self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
+
+ self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
+
+ self.to_out = nn.Sequential(
+ Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
+ )
+
+ def forward(self, x):
+ x = self.norm(x)
+
+ q, k, v = self.to_qkv(x)
+
+ q, k = map(l2norm, (q, k))
+ q = q * self.temperature.exp()
+
+ out = self.attend(q, k, v)
+
+ return self.to_out(out)
+
+
+class Transformer(Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ depth,
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.0,
+ ff_dropout=0.0,
+ ff_mult=4,
+ norm_output=True,
+ rotary_embed=None,
+ flash_attn=True,
+ linear_attn=False,
+ ):
+ super().__init__()
+ self.layers = ModuleList([])
+
+ for _ in range(depth):
+ if linear_attn:
+ attn = LinearAttention(
+ dim=dim,
+ dim_head=dim_head,
+ heads=heads,
+ dropout=attn_dropout,
+ flash=flash_attn,
+ )
+ else:
+ attn = Attention(
+ dim=dim,
+ dim_head=dim_head,
+ heads=heads,
+ dropout=attn_dropout,
+ rotary_embed=rotary_embed,
+ flash=flash_attn,
+ )
+
+ self.layers.append(
+ ModuleList(
+ [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
+ )
+ )
+
+ self.norm = RMSNorm(dim) if norm_output else nn.Identity()
+
+ def forward(self, x):
+
+ for attn, ff in self.layers:
+ x = attn(x) + x
+ x = ff(x) + x
+
+ return self.norm(x)
+
+
+# bandsplit module
+
+
+class BandSplit(Module):
+ @beartype
+ def __init__(self, dim, dim_inputs: Tuple[int, ...]):
+ super().__init__()
+ self.dim_inputs = dim_inputs
+ self.to_features = ModuleList([])
+
+ for dim_in in dim_inputs:
+ net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
+
+ self.to_features.append(net)
+
+ def forward(self, x):
+ x = x.split(self.dim_inputs, dim=-1)
+
+ outs = []
+ for split_input, to_feature in zip(x, self.to_features):
+ split_output = to_feature(split_input)
+ outs.append(split_output)
+
+ return torch.stack(outs, dim=-2)
+
+
+def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
+ dim_hidden = default(dim_hidden, dim_in)
+
+ net = []
+ dims = (dim_in, *((dim_hidden,) * depth), dim_out)
+
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
+ is_last = ind == (len(dims) - 2)
+
+ net.append(nn.Linear(layer_dim_in, layer_dim_out))
+
+ if is_last:
+ continue
+
+ net.append(activation())
+
+ return nn.Sequential(*net)
+
+
+class MaskEstimator(Module):
+ @beartype
+ def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
+ super().__init__()
+ self.dim_inputs = dim_inputs
+ self.to_freqs = ModuleList([])
+ dim_hidden = dim * mlp_expansion_factor
+
+ for dim_in in dim_inputs:
+ net = []
+
+ mlp = nn.Sequential(
+ MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1)
+ )
+
+ self.to_freqs.append(mlp)
+
+ def forward(self, x):
+ x = x.unbind(dim=-2)
+
+ outs = []
+
+ for band_features, mlp in zip(x, self.to_freqs):
+ freq_out = mlp(band_features)
+ outs.append(freq_out)
+
+ return torch.cat(outs, dim=-1)
+
+
+# main class
+
+
+class MelBandRoformer(Module):
+
+ @beartype
+ def __init__(
+ self,
+ dim,
+ *,
+ depth,
+ stereo=False,
+ num_stems=1,
+ time_transformer_depth=2,
+ freq_transformer_depth=2,
+ linear_transformer_depth=0,
+ num_bands=60,
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.1,
+ ff_dropout=0.1,
+ flash_attn=True,
+ dim_freqs_in=1025,
+ sample_rate=44100, # needed for mel filter bank from librosa
+ stft_n_fft=2048,
+ stft_hop_length=512,
+ # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
+ stft_win_length=2048,
+ stft_normalized=False,
+ stft_window_fn: Optional[Callable] = None,
+ mask_estimator_depth=1,
+ multi_stft_resolution_loss_weight=1.0,
+ multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
+ 4096,
+ 2048,
+ 1024,
+ 512,
+ 256,
+ ),
+ multi_stft_hop_size=147,
+ multi_stft_normalized=False,
+ multi_stft_window_fn: Callable = torch.hann_window,
+ match_input_audio_length=False, # if True, pad output tensor to match length of input tensor
+ mlp_expansion_factor=4,
+ use_torch_checkpoint=False,
+ skip_connection=False,
+ ):
+ super().__init__()
+
+ self.stereo = stereo
+ self.audio_channels = 2 if stereo else 1
+ self.num_stems = num_stems
+ self.use_torch_checkpoint = use_torch_checkpoint
+ self.skip_connection = skip_connection
+
+ self.layers = ModuleList([])
+
+ transformer_kwargs = dict(
+ dim=dim,
+ heads=heads,
+ dim_head=dim_head,
+ attn_dropout=attn_dropout,
+ ff_dropout=ff_dropout,
+ flash_attn=flash_attn,
+ )
+
+ time_rotary_embed = RotaryEmbedding(dim=dim_head)
+ freq_rotary_embed = RotaryEmbedding(dim=dim_head)
+
+ for _ in range(depth):
+ tran_modules = []
+ if linear_transformer_depth > 0:
+ tran_modules.append(
+ Transformer(
+ depth=linear_transformer_depth,
+ linear_attn=True,
+ **transformer_kwargs,
+ )
+ )
+ tran_modules.append(
+ Transformer(
+ depth=time_transformer_depth,
+ rotary_embed=time_rotary_embed,
+ **transformer_kwargs,
+ )
+ )
+ tran_modules.append(
+ Transformer(
+ depth=freq_transformer_depth,
+ rotary_embed=freq_rotary_embed,
+ **transformer_kwargs,
+ )
+ )
+ self.layers.append(nn.ModuleList(tran_modules))
+
+ self.stft_window_fn = partial(
+ default(stft_window_fn, torch.hann_window), stft_win_length
+ )
+
+ self.stft_kwargs = dict(
+ n_fft=stft_n_fft,
+ hop_length=stft_hop_length,
+ win_length=stft_win_length,
+ normalized=stft_normalized,
+ )
+
+ freqs = torch.stft(
+ torch.randn(1, 4096),
+ **self.stft_kwargs,
+ window=torch.ones(stft_n_fft),
+ return_complex=True,
+ ).shape[1]
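+        # freqs = number of STFT frequency bins (n_fft // 2 + 1), obtained from a dummy forward STFT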
+
+ # create mel filter bank
+ # with librosa.filters.mel as in section 2 of paper
+
+ mel_filter_bank_numpy = filters.mel(
+ sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands
+ )
+
+ mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
+
+ # for some reason, it doesn't include the first freq? just force a value for now
+
+ mel_filter_bank[0][0] = 1.0
+
+ # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
+ # so let's force a positive value
+
+ mel_filter_bank[-1, -1] = 1.0
+
+ # binary as in paper (then estimated masks are averaged for overlapping regions)
+
+ freqs_per_band = mel_filter_bank > 0
+ assert freqs_per_band.any(
+ dim=0
+ ).all(), "all frequencies need to be covered by all bands for now"
+
+ repeated_freq_indices = repeat(torch.arange(freqs), "f -> b f", b=num_bands)
+ freq_indices = repeated_freq_indices[freqs_per_band]
+
+ if stereo:
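+            # interleave left/right bins (index -> 2*index + channel) so the indices match the
+            # "b s f t c -> b (f s) t c" merge done in forward()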
+ freq_indices = repeat(freq_indices, "f -> f s", s=2)
+ freq_indices = freq_indices * 2 + torch.arange(2)
+ freq_indices = rearrange(freq_indices, "f s -> (f s)")
+
+ self.register_buffer("freq_indices", freq_indices, persistent=False)
+ self.register_buffer("freqs_per_band", freqs_per_band, persistent=False)
+
+ num_freqs_per_band = reduce(freqs_per_band, "b f -> b", "sum")
+ num_bands_per_freq = reduce(freqs_per_band, "b f -> f", "sum")
+
+ self.register_buffer("num_freqs_per_band", num_freqs_per_band, persistent=False)
+ self.register_buffer("num_bands_per_freq", num_bands_per_freq, persistent=False)
+
+ # band split and mask estimator
+
+ freqs_per_bands_with_complex = tuple(
+ 2 * f * self.audio_channels for f in num_freqs_per_band.tolist()
+ )
+
+ self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
+
+ self.mask_estimators = nn.ModuleList([])
+
+ for _ in range(num_stems):
+ mask_estimator = MaskEstimator(
+ dim=dim,
+ dim_inputs=freqs_per_bands_with_complex,
+ depth=mask_estimator_depth,
+ mlp_expansion_factor=mlp_expansion_factor,
+ )
+
+ self.mask_estimators.append(mask_estimator)
+
+ # for the multi-resolution stft loss
+
+ self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
+ self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
+ self.multi_stft_n_fft = stft_n_fft
+ self.multi_stft_window_fn = multi_stft_window_fn
+
+ self.multi_stft_kwargs = dict(
+ hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
+ )
+
+ self.match_input_audio_length = match_input_audio_length
+
+ def forward(self, raw_audio, target=None, return_loss_breakdown=False):
+ """
+ einops
+
+ b - batch
+ f - freq
+ t - time
+ s - audio channel (1 for mono, 2 for stereo)
+ n - number of 'stems'
+ c - complex (2)
+ d - feature dimension
+ """
+
+ device = raw_audio.device
+
+ if raw_audio.ndim == 2:
+ raw_audio = rearrange(raw_audio, "b t -> b 1 t")
+
+ batch, channels, raw_audio_length = raw_audio.shape
+
+ istft_length = raw_audio_length if self.match_input_audio_length else None
+
+ assert (not self.stereo and channels == 1) or (
+ self.stereo and channels == 2
+ ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
+
+ # to stft
+
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
+
+ stft_window = self.stft_window_fn(device=device)
+
+ stft_repr = torch.stft(
+ raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
+ )
+ stft_repr = torch.view_as_real(stft_repr)
+
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
+ stft_repr = rearrange(
+ stft_repr, "b s f t c -> b (f s) t c"
+ ) # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
+
+ # index out all frequencies for all frequency ranges across bands ascending in one go
+
+ batch_arange = torch.arange(batch, device=device)[..., None]
+
+ # account for stereo
+
+ x = stft_repr[batch_arange, self.freq_indices]
+
+ # fold the complex (real and imag) into the frequencies dimension
+
+ x = rearrange(x, "b f t c -> b t (f c)")
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(self.band_split, x, use_reentrant=False)
+ else:
+ x = self.band_split(x)
+
+ # axial / hierarchical attention
+
+ store = [None] * len(self.layers)
+ for i, transformer_block in enumerate(self.layers):
+
+ if len(transformer_block) == 3:
+ linear_transformer, time_transformer, freq_transformer = (
+ transformer_block
+ )
+
+ x, ft_ps = pack([x], "b * d")
+ if self.use_torch_checkpoint:
+ x = checkpoint(linear_transformer, x, use_reentrant=False)
+ else:
+ x = linear_transformer(x)
+ (x,) = unpack(x, ft_ps, "b * d")
+ else:
+ time_transformer, freq_transformer = transformer_block
+
+ if self.skip_connection:
+ # Sum all previous
+ for j in range(i):
+ x = x + store[j]
+
+ x = rearrange(x, "b t f d -> b f t d")
+ x, ps = pack([x], "* t d")
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(time_transformer, x, use_reentrant=False)
+ else:
+ x = time_transformer(x)
+
+ (x,) = unpack(x, ps, "* t d")
+ x = rearrange(x, "b f t d -> b t f d")
+ x, ps = pack([x], "* f d")
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(freq_transformer, x, use_reentrant=False)
+ else:
+ x = freq_transformer(x)
+
+ (x,) = unpack(x, ps, "* f d")
+
+ if self.skip_connection:
+ store[i] = x
+
+ num_stems = len(self.mask_estimators)
+ if self.use_torch_checkpoint:
+ masks = torch.stack(
+ [checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators],
+ dim=1,
+ )
+ else:
+ masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
+ masks = rearrange(masks, "b n t (f c) -> b n f t c", c=2)
+
+ # modulate frequency representation
+
+ stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
+
+ # complex number multiplication
+
+ stft_repr = torch.view_as_complex(stft_repr)
+ masks = torch.view_as_complex(masks)
+
+ masks = masks.type(stft_repr.dtype)
+
+ # need to average the estimated mask for the overlapped frequencies
+
+ scatter_indices = repeat(
+ self.freq_indices,
+ "f -> b n f t",
+ b=batch,
+ n=num_stems,
+ t=stft_repr.shape[-1],
+ )
+
+ stft_repr_expanded_stems = repeat(stft_repr, "b 1 ... -> b n ...", n=num_stems)
+ masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(
+ 2, scatter_indices, masks
+ )
+
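+        # num_bands_per_freq counts how many mel bands cover each frequency bin; repeat it across
+        # the interleaved channel axis so it lines up with the (f s) frequency dimension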
+ denom = repeat(self.num_bands_per_freq, "f -> (f r) 1", r=channels)
+
+ masks_averaged = masks_summed / denom.clamp(min=1e-8)
+
+ # modulate stft repr with estimated mask
+
+ stft_repr = stft_repr * masks_averaged
+
+ # istft
+
+ stft_repr = rearrange(
+ stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
+ )
+
+ recon_audio = torch.istft(
+ stft_repr,
+ **self.stft_kwargs,
+ window=stft_window,
+ return_complex=False,
+ length=istft_length,
+ )
+
+ recon_audio = rearrange(
+ recon_audio,
+ "(b n s) t -> b n s t",
+ b=batch,
+ s=self.audio_channels,
+ n=num_stems,
+ )
+
+ if num_stems == 1:
+ recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
+
+ # if a target is passed in, calculate loss for learning
+
+ if not exists(target):
+ return recon_audio
+
+ if self.num_stems > 1:
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
+
+ if target.ndim == 2:
+ target = rearrange(target, "... t -> ... 1 t")
+
+ target = target[
+ ..., : recon_audio.shape[-1]
+ ] # protect against lost length on istft
+
+ loss = F.l1_loss(recon_audio, target)
+
+ multi_stft_resolution_loss = 0.0
+
+ for window_size in self.multi_stft_resolutions_window_sizes:
+ res_stft_kwargs = dict(
+ n_fft=max(
+ window_size, self.multi_stft_n_fft
+ ), # not sure what n_fft is across multi resolution stft
+ win_length=window_size,
+ return_complex=True,
+ window=self.multi_stft_window_fn(window_size, device=device),
+ **self.multi_stft_kwargs,
+ )
+
+ recon_Y = torch.stft(
+ rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
+ )
+ target_Y = torch.stft(
+ rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
+ )
+
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
+ recon_Y, target_Y
+ )
+
+ weighted_multi_resolution_loss = (
+ multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
+ )
+
+ total_loss = loss + weighted_multi_resolution_loss
+
+ if not return_loss_breakdown:
+ return total_loss
+
+ return total_loss, (loss, multi_stft_resolution_loss)
diff --git a/mvsepless/models/demucs4ht.py b/mvsepless/models/demucs4ht.py
new file mode 100644
index 0000000000000000000000000000000000000000..888ee82a19cf612b62504dfd8c0d9204f7a906f6
--- /dev/null
+++ b/mvsepless/models/demucs4ht.py
@@ -0,0 +1,712 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+
+import numpy as np
+import torch
+import json
+from omegaconf import OmegaConf
+from demucs.demucs import Demucs
+from demucs.hdemucs import HDemucs
+
+import math
+from openunmix.filtering import wiener
+from torch import nn
+from torch.nn import functional as F
+from fractions import Fraction
+from einops import rearrange
+
+from demucs.transformer import CrossTransformerEncoder
+
+from demucs.demucs import rescale_module
+from demucs.states import capture_init
+from demucs.spec import spectro, ispectro
+from demucs.hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer
+
+
+class HTDemucs(nn.Module):
+ """
+ Spectrogram and hybrid Demucs model.
+ The spectrogram model has the same structure as Demucs, except the first few layers are over the
+ frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
+ Frequency layers can still access information across time steps thanks to the DConv residual.
+
+    Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
+ as the frequency branch and then the two are combined. The opposite happens in the decoder.
+
+ Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]),
+ or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
+ Open Unmix implementation [Stoter et al. 2019].
+
+    The loss is always computed in the temporal domain, by backpropagating through the above
+    output methods and iSTFT. This makes it possible to define hybrid models nicely. However, it
+    slightly breaks Wiener filtering, as running more iterations at test time changes the
+    spectrogram contribution without changing the waveform one, which leads to worse performance.
+    I tried using the residual option in the OpenUnmix Wiener implementation, but it didn't help.
+    CaC, on the other hand, provides similar performance for hybrid models and works naturally
+    with them.
+
+    This model also uses frequency embeddings to improve the efficiency of convolutions
+ over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
+
+ Unlike classic Demucs, there is no resampling here, and normalization is always applied.
+ """
+
+ @capture_init
+ def __init__(
+ self,
+ sources,
+ # Channels
+ audio_channels=2,
+ channels=48,
+ channels_time=None,
+ growth=2,
+ # STFT
+ nfft=4096,
+ num_subbands=1,
+ wiener_iters=0,
+ end_iters=0,
+ wiener_residual=False,
+ cac=True,
+ # Main structure
+ depth=4,
+ rewrite=True,
+ # Frequency branch
+ multi_freqs=None,
+ multi_freqs_depth=3,
+ freq_emb=0.2,
+ emb_scale=10,
+ emb_smooth=True,
+ # Convolutions
+ kernel_size=8,
+ time_stride=2,
+ stride=4,
+ context=1,
+ context_enc=0,
+ # Normalization
+ norm_starts=4,
+ norm_groups=4,
+ # DConv residual branch
+ dconv_mode=1,
+ dconv_depth=2,
+ dconv_comp=8,
+ dconv_init=1e-3,
+ # Before the Transformer
+ bottom_channels=0,
+ # Transformer
+ t_layers=5,
+ t_emb="sin",
+ t_hidden_scale=4.0,
+ t_heads=8,
+ t_dropout=0.0,
+ t_max_positions=10000,
+ t_norm_in=True,
+ t_norm_in_group=False,
+ t_group_norm=False,
+ t_norm_first=True,
+ t_norm_out=True,
+ t_max_period=10000.0,
+ t_weight_decay=0.0,
+ t_lr=None,
+ t_layer_scale=True,
+ t_gelu=True,
+ t_weight_pos_embed=1.0,
+ t_sin_random_shift=0,
+ t_cape_mean_normalize=True,
+ t_cape_augment=True,
+ t_cape_glob_loc_scale=[5000.0, 1.0, 1.4],
+ t_sparse_self_attn=False,
+ t_sparse_cross_attn=False,
+ t_mask_type="diag",
+ t_mask_random_seed=42,
+ t_sparse_attn_window=500,
+ t_global_window=100,
+ t_sparsity=0.95,
+ t_auto_sparsity=False,
+ # ------ Particuliar parameters
+ t_cross_first=False,
+ # Weight init
+ rescale=0.1,
+ # Metadata
+ samplerate=44100,
+ segment=10,
+ use_train_segment=False,
+ ):
+ """
+ Args:
+ sources (list[str]): list of source names.
+ audio_channels (int): input/output audio channels.
+ channels (int): initial number of hidden channels.
+ channels_time: if not None, use a different `channels` value for the time branch.
+ growth: increase the number of hidden channels by this factor at each layer.
+ nfft: number of fft bins. Note that changing this require careful computation of
+ various shape parameters and will not work out of the box for hybrid models.
+ wiener_iters: when using Wiener filtering, number of iterations at test time.
+ end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
+ wiener_residual: add residual source before wiener filtering.
+ cac: uses complex as channels, i.e. complex numbers are 2 channels each
+ in input and output. no further processing is done before ISTFT.
+ depth (int): number of layers in the encoder and in the decoder.
+ rewrite (bool): add 1x1 convolution to each layer.
+ multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
+ multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
+ layers will be wrapped.
+ freq_emb: add frequency embedding after the first frequency layer if > 0,
+ the actual value controls the weight of the embedding.
+ emb_scale: equivalent to scaling the embedding learning rate
+ emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
+ kernel_size: kernel_size for encoder and decoder layers.
+ stride: stride for encoder and decoder layers.
+ time_stride: stride for the final time layer, after the merge.
+ context: context for 1x1 conv in the decoder.
+ context_enc: context for 1x1 conv in the encoder.
+ norm_starts: layer at which group norm starts being used.
+ decoder layers are numbered in reverse order.
+ norm_groups: number of groups for group norm.
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
+ dconv_depth: depth of residual DConv branch.
+ dconv_comp: compression of DConv branch.
+ dconv_init: initial scale for the DConv branch LayerScale.
+ bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
+ transformer in order to change the number of channels
+ t_layers: number of layers in each branch (waveform and spec) of the transformer
+ t_emb: "sin", "cape" or "scaled"
+ t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
+ for instance if C = 384 (the number of channels in the transformer) and
+ t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
+ 384 * 4 = 1536
+ t_heads: number of heads for the transformer
+ t_dropout: dropout in the transformer
+ t_max_positions: max_positions for the "scaled" positional embedding, only
+ useful if t_emb="scaled"
+            t_norm_in: (bool) norm before adding the positional embedding and entering the
+                transformer layers
+ t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
+ timesteps (GroupNorm with group=1)
+ t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
+ timesteps (GroupNorm with group=1)
+ t_norm_first: (bool) if True the norm is before the attention and before the FFN
+ t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
+ t_max_period: (float) denominator in the sinusoidal embedding expression
+ t_weight_decay: (float) weight decay for the transformer
+ t_lr: (float) specific learning rate for the transformer
+ t_layer_scale: (bool) Layer Scale for the transformer
+ t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else
+ t_weight_pos_embed: (float) weighting of the positional embedding
+ t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
+ see: https://arxiv.org/abs/2106.03143
+ t_cape_augment: (bool) if t_emb="cape", must be True during training and False
+ during the inference, see: https://arxiv.org/abs/2106.03143
+ t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
+ see: https://arxiv.org/abs/2106.03143
+ t_sparse_self_attn: (bool) if True, the self attentions are sparse
+ t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
+ unless you designed really specific masks)
+ t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
+ with '_' between: i.e. "diag_jmask_random" (note that this is permutation
+ invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
+ t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
+ that generated the random part of the mask
+ t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
+                a key (j), the mask is True if |i-j|<=t_sparse_attn_window
+ t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
+ and mask[:, :t_global_window] will be True
+ t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
+ level of the random part of the mask.
+ t_cross_first: (bool) if True cross attention is the first layer of the
+ transformer (False seems to be better)
+ rescale: weight rescaling trick
+ use_train_segment: (bool) if True, the actual size that is used during the
+ training is used during inference.
+ """
+ super().__init__()
+ self.num_subbands = num_subbands
+ self.cac = cac
+ self.wiener_residual = wiener_residual
+ self.audio_channels = audio_channels
+ self.sources = sources
+ self.kernel_size = kernel_size
+ self.context = context
+ self.stride = stride
+ self.depth = depth
+ self.bottom_channels = bottom_channels
+ self.channels = channels
+ self.samplerate = samplerate
+ self.segment = segment
+ self.use_train_segment = use_train_segment
+ self.nfft = nfft
+ self.hop_length = nfft // 4
+ self.wiener_iters = wiener_iters
+ self.end_iters = end_iters
+ self.freq_emb = None
+ assert wiener_iters == end_iters
+
+ self.encoder = nn.ModuleList()
+ self.decoder = nn.ModuleList()
+
+ self.tencoder = nn.ModuleList()
+ self.tdecoder = nn.ModuleList()
+
+ chin = audio_channels
+ chin_z = chin # number of channels for the freq branch
+ if self.cac:
+ chin_z *= 2
+ if self.num_subbands > 1:
+ chin_z *= self.num_subbands
+ chout = channels_time or channels
+ chout_z = channels
+ freqs = nfft // 2
+
+ for index in range(depth):
+ norm = index >= norm_starts
+ freq = freqs > 1
+ stri = stride
+ ker = kernel_size
+ if not freq:
+ assert freqs == 1
+ ker = time_stride * 2
+ stri = time_stride
+
+ pad = True
+ last_freq = False
+ if freq and freqs <= kernel_size:
+ ker = freqs
+ pad = False
+ last_freq = True
+
+ kw = {
+ "kernel_size": ker,
+ "stride": stri,
+ "freq": freq,
+ "pad": pad,
+ "norm": norm,
+ "rewrite": rewrite,
+ "norm_groups": norm_groups,
+ "dconv_kw": {
+ "depth": dconv_depth,
+ "compress": dconv_comp,
+ "init": dconv_init,
+ "gelu": True,
+ },
+ }
+ kwt = dict(kw)
+ kwt["freq"] = 0
+ kwt["kernel_size"] = kernel_size
+ kwt["stride"] = stride
+ kwt["pad"] = True
+ kw_dec = dict(kw)
+ multi = False
+ if multi_freqs and index < multi_freqs_depth:
+ multi = True
+ kw_dec["context_freq"] = False
+
+ if last_freq:
+ chout_z = max(chout, chout_z)
+ chout = chout_z
+
+ enc = HEncLayer(
+ chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw
+ )
+ if freq:
+ tenc = HEncLayer(
+ chin,
+ chout,
+ dconv=dconv_mode & 1,
+ context=context_enc,
+ empty=last_freq,
+ **kwt,
+ )
+ self.tencoder.append(tenc)
+
+ if multi:
+ enc = MultiWrap(enc, multi_freqs)
+ self.encoder.append(enc)
+ if index == 0:
+ chin = self.audio_channels * len(self.sources)
+ chin_z = chin
+ if self.cac:
+ chin_z *= 2
+ if self.num_subbands > 1:
+ chin_z *= self.num_subbands
+ dec = HDecLayer(
+ chout_z,
+ chin_z,
+ dconv=dconv_mode & 2,
+ last=index == 0,
+ context=context,
+ **kw_dec,
+ )
+ if multi:
+ dec = MultiWrap(dec, multi_freqs)
+ if freq:
+ tdec = HDecLayer(
+ chout,
+ chin,
+ dconv=dconv_mode & 2,
+ empty=last_freq,
+ last=index == 0,
+ context=context,
+ **kwt,
+ )
+ self.tdecoder.insert(0, tdec)
+ self.decoder.insert(0, dec)
+
+ chin = chout
+ chin_z = chout_z
+ chout = int(growth * chout)
+ chout_z = int(growth * chout_z)
+ if freq:
+ if freqs <= kernel_size:
+ freqs = 1
+ else:
+ freqs //= stride
+ if index == 0 and freq_emb:
+ self.freq_emb = ScaledEmbedding(
+ freqs, chin_z, smooth=emb_smooth, scale=emb_scale
+ )
+ self.freq_emb_scale = freq_emb
+
+ if rescale:
+ rescale_module(self, reference=rescale)
+
+ transformer_channels = channels * growth ** (depth - 1)
+ if bottom_channels:
+ self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
+ self.channel_downsampler = nn.Conv1d(
+ bottom_channels, transformer_channels, 1
+ )
+ self.channel_upsampler_t = nn.Conv1d(
+ transformer_channels, bottom_channels, 1
+ )
+ self.channel_downsampler_t = nn.Conv1d(
+ bottom_channels, transformer_channels, 1
+ )
+
+ transformer_channels = bottom_channels
+
+ if t_layers > 0:
+ self.crosstransformer = CrossTransformerEncoder(
+ dim=transformer_channels,
+ emb=t_emb,
+ hidden_scale=t_hidden_scale,
+ num_heads=t_heads,
+ num_layers=t_layers,
+ cross_first=t_cross_first,
+ dropout=t_dropout,
+ max_positions=t_max_positions,
+ norm_in=t_norm_in,
+ norm_in_group=t_norm_in_group,
+ group_norm=t_group_norm,
+ norm_first=t_norm_first,
+ norm_out=t_norm_out,
+ max_period=t_max_period,
+ weight_decay=t_weight_decay,
+ lr=t_lr,
+ layer_scale=t_layer_scale,
+ gelu=t_gelu,
+ sin_random_shift=t_sin_random_shift,
+ weight_pos_embed=t_weight_pos_embed,
+ cape_mean_normalize=t_cape_mean_normalize,
+ cape_augment=t_cape_augment,
+ cape_glob_loc_scale=t_cape_glob_loc_scale,
+ sparse_self_attn=t_sparse_self_attn,
+ sparse_cross_attn=t_sparse_cross_attn,
+ mask_type=t_mask_type,
+ mask_random_seed=t_mask_random_seed,
+ sparse_attn_window=t_sparse_attn_window,
+ global_window=t_global_window,
+ sparsity=t_sparsity,
+ auto_sparsity=t_auto_sparsity,
+ )
+ else:
+ self.crosstransformer = None
+
+ def _spec(self, x):
+ hl = self.hop_length
+ nfft = self.nfft
+ x0 = x # noqa
+
+ # We re-pad the signal in order to keep the property
+ # that the size of the output is exactly the size of the input
+ # divided by the stride (here hop_length), when divisible.
+        # This is achieved by padding by 1/4th of the kernel size (here nfft),
+        # which torch.stft does not support directly.
+        # Having all convolution operations follow this convention allows us to easily
+ # align the time and frequency branches later on.
+ assert hl == nfft // 4
+ le = int(math.ceil(x.shape[-1] / hl))
+ pad = hl // 2 * 3
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
+
+ z = spectro(x, nfft, hl)[..., :-1, :]
+ assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
+ z = z[..., 2 : 2 + le]
+ return z
+
+ def _ispec(self, z, length=None, scale=0):
+ hl = self.hop_length // (4**scale)
+ z = F.pad(z, (0, 0, 0, 1))
+ z = F.pad(z, (2, 2))
+ pad = hl // 2 * 3
+ le = hl * int(math.ceil(length / hl)) + 2 * pad
+ x = ispectro(z, hl, length=le)
+ x = x[..., pad : pad + length]
+ return x
+
+ def _magnitude(self, z):
+ # return the magnitude of the spectrogram, except when cac is True,
+ # in which case we just move the complex dimension to the channel one.
+ if self.cac:
+ B, C, Fr, T = z.shape
+ m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
+ m = m.reshape(B, C * 2, Fr, T)
+ else:
+ m = z.abs()
+ return m
+
+ def _mask(self, z, m):
+ # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
+ # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
+ niters = self.wiener_iters
+ if self.cac:
+ B, S, C, Fr, T = m.shape
+ out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
+ out = torch.view_as_complex(out.contiguous())
+ return out
+ if self.training:
+ niters = self.end_iters
+ if niters < 0:
+ z = z[:, None]
+ return z / (1e-8 + z.abs()) * m
+ else:
+ return self._wiener(m, z, niters)
+
+ def _wiener(self, mag_out, mix_stft, niters):
+ # apply wiener filtering from OpenUnmix.
+ init = mix_stft.dtype
+ wiener_win_len = 300
+ residual = self.wiener_residual
+
+ B, S, C, Fq, T = mag_out.shape
+ mag_out = mag_out.permute(0, 4, 3, 2, 1)
+ mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
+
+ outs = []
+ for sample in range(B):
+ pos = 0
+ out = []
+ for pos in range(0, T, wiener_win_len):
+ frame = slice(pos, pos + wiener_win_len)
+ z_out = wiener(
+ mag_out[sample, frame],
+ mix_stft[sample, frame],
+ niters,
+ residual=residual,
+ )
+ out.append(z_out.transpose(-1, -2))
+ outs.append(torch.cat(out, dim=0))
+ out = torch.view_as_complex(torch.stack(outs, 0))
+ out = out.permute(0, 4, 3, 2, 1).contiguous()
+ if residual:
+ out = out[:, :-1]
+ assert list(out.shape) == [B, S, C, Fq, T]
+ return out.to(init)
+
+ def valid_length(self, length: int):
+ """
+ Return a length that is appropriate for evaluation.
+ In our case, always return the training length, unless
+ it is smaller than the given length, in which case this
+ raises an error.
+ """
+ if not self.use_train_segment:
+ return length
+ training_length = int(self.segment * self.samplerate)
+ if training_length < length:
+ raise ValueError(
+ f"Given length {length} is longer than "
+ f"training length {training_length}"
+ )
+ return training_length
+
+ def cac2cws(self, x):
+ k = self.num_subbands
+ b, c, f, t = x.shape
+ x = x.reshape(b, c, k, f // k, t)
+ x = x.reshape(b, c * k, f // k, t)
+ return x
+
+ def cws2cac(self, x):
+ k = self.num_subbands
+ b, c, f, t = x.shape
+ x = x.reshape(b, c // k, k, f, t)
+ x = x.reshape(b, c // k, f * k, t)
+ return x
+
+ def forward(self, mix):
+ length = mix.shape[-1]
+ length_pre_pad = None
+ if self.use_train_segment:
+ if self.training:
+ self.segment = Fraction(mix.shape[-1], self.samplerate)
+ else:
+ training_length = int(self.segment * self.samplerate)
+ # print('Training length: {} Segment: {} Sample rate: {}'.format(training_length, self.segment, self.samplerate))
+ if mix.shape[-1] < training_length:
+ length_pre_pad = mix.shape[-1]
+ mix = F.pad(mix, (0, training_length - length_pre_pad))
+ # print("Mix: {}".format(mix.shape))
+ # print("Length: {}".format(length))
+ z = self._spec(mix)
+ # print("Z: {} Type: {}".format(z.shape, z.dtype))
+ mag = self._magnitude(z)
+ x = mag
+ # print("MAG: {} Type: {}".format(x.shape, x.dtype))
+
+ if self.num_subbands > 1:
+ x = self.cac2cws(x)
+ # print("After SUBBANDS: {} Type: {}".format(x.shape, x.dtype))
+
+ B, C, Fq, T = x.shape
+
+ # unlike previous Demucs, we always normalize because it is easier.
+ mean = x.mean(dim=(1, 2, 3), keepdim=True)
+ std = x.std(dim=(1, 2, 3), keepdim=True)
+ x = (x - mean) / (1e-5 + std)
+ # x will be the freq. branch input.
+
+ # Prepare the time branch input.
+ xt = mix
+ meant = xt.mean(dim=(1, 2), keepdim=True)
+ stdt = xt.std(dim=(1, 2), keepdim=True)
+ xt = (xt - meant) / (1e-5 + stdt)
+
+ # print("XT: {}".format(xt.shape))
+
+ # okay, this is a giant mess I know...
+ saved = [] # skip connections, freq.
+ saved_t = [] # skip connections, time.
+ lengths = [] # saved lengths to properly remove padding, freq branch.
+ lengths_t = [] # saved lengths for time branch.
+ for idx, encode in enumerate(self.encoder):
+ lengths.append(x.shape[-1])
+ inject = None
+ if idx < len(self.tencoder):
+ # we have not yet merged branches.
+ lengths_t.append(xt.shape[-1])
+ tenc = self.tencoder[idx]
+ xt = tenc(xt)
+ # print("Encode XT {}: {}".format(idx, xt.shape))
+ if not tenc.empty:
+ # save for skip connection
+ saved_t.append(xt)
+ else:
+ # tenc contains just the first conv., so that now time and freq.
+ # branches have the same shape and can be merged.
+ inject = xt
+ x = encode(x, inject)
+ # print("Encode X {}: {}".format(idx, x.shape))
+ if idx == 0 and self.freq_emb is not None:
+ # add frequency embedding to allow for non equivariant convolutions
+ # over the frequency axis.
+ frs = torch.arange(x.shape[-2], device=x.device)
+ emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
+ x = x + self.freq_emb_scale * emb
+
+ saved.append(x)
+ if self.crosstransformer:
+ if self.bottom_channels:
+ b, c, f, t = x.shape
+ x = rearrange(x, "b c f t-> b c (f t)")
+ x = self.channel_upsampler(x)
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
+ xt = self.channel_upsampler_t(xt)
+
+ x, xt = self.crosstransformer(x, xt)
+ # print("Cross Tran X {}, XT: {}".format(x.shape, xt.shape))
+
+ if self.bottom_channels:
+ x = rearrange(x, "b c f t-> b c (f t)")
+ x = self.channel_downsampler(x)
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
+ xt = self.channel_downsampler_t(xt)
+
+ for idx, decode in enumerate(self.decoder):
+ skip = saved.pop(-1)
+ x, pre = decode(x, skip, lengths.pop(-1))
+ # print('Decode {} X: {}'.format(idx, x.shape))
+ # `pre` contains the output just before final transposed convolution,
+ # which is used when the freq. and time branch separate.
+
+ offset = self.depth - len(self.tdecoder)
+ if idx >= offset:
+ tdec = self.tdecoder[idx - offset]
+ length_t = lengths_t.pop(-1)
+ if tdec.empty:
+ assert pre.shape[2] == 1, pre.shape
+ pre = pre[:, :, 0]
+ xt, _ = tdec(pre, None, length_t)
+ else:
+ skip = saved_t.pop(-1)
+ xt, _ = tdec(xt, skip, length_t)
+ # print('Decode {} XT: {}'.format(idx, xt.shape))
+
+ # Let's make sure we used all stored skip connections.
+ assert len(saved) == 0
+ assert len(lengths_t) == 0
+ assert len(saved_t) == 0
+
+ S = len(self.sources)
+
+ if self.num_subbands > 1:
+ x = x.view(B, -1, Fq, T)
+ # print("X view 1: {}".format(x.shape))
+ x = self.cws2cac(x)
+ # print("X view 2: {}".format(x.shape))
+
+ x = x.view(B, S, -1, Fq * self.num_subbands, T)
+ x = x * std[:, None] + mean[:, None]
+ # print("X returned: {}".format(x.shape))
+
+ zout = self._mask(z, x)
+ if self.use_train_segment:
+ if self.training:
+ x = self._ispec(zout, length)
+ else:
+ x = self._ispec(zout, training_length)
+ else:
+ x = self._ispec(zout, length)
+
+ if self.use_train_segment:
+ if self.training:
+ xt = xt.view(B, S, -1, length)
+ else:
+ xt = xt.view(B, S, -1, training_length)
+ else:
+ xt = xt.view(B, S, -1, length)
+ xt = xt * stdt[:, None] + meant[:, None]
+ x = xt + x
+ if length_pre_pad:
+ x = x[..., :length_pre_pad]
+ return x
+
+
+def get_model(args):
+ extra = {
+ "sources": list(args.training.instruments),
+ "audio_channels": args.training.channels,
+ "samplerate": args.training.samplerate,
+ # 'segment': args.model_segment or 4 * args.dset.segment,
+ "segment": args.training.segment,
+ }
+ klass = {
+ "demucs": Demucs,
+ "hdemucs": HDemucs,
+ "htdemucs": HTDemucs,
+ }[args.model]
+ kw = OmegaConf.to_container(getattr(args, args.model), resolve=True)
+ model = klass(**extra, **kw)
+ return model
diff --git a/mvsepless/models/mdx23c_tfc_tdf_v3.py b/mvsepless/models/mdx23c_tfc_tdf_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd4637cfee387d1734ee3e33fe2e11ce923ef432
--- /dev/null
+++ b/mvsepless/models/mdx23c_tfc_tdf_v3.py
@@ -0,0 +1,259 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+from infer_utils import prefer_target_instrument
+
+
+class STFT:
+ def __init__(self, config):
+ self.n_fft = config.n_fft
+ self.hop_length = config.hop_length
+ self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
+ self.dim_f = config.dim_f
+
+ def __call__(self, x):
+ window = self.window.to(x.device)
+ batch_dims = x.shape[:-2]
+ c, t = x.shape[-2:]
+ x = x.reshape([-1, t])
+ x = torch.stft(
+ x,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ window=window,
+ center=True,
+ return_complex=True,
+ )
+ x = torch.view_as_real(x)
+ x = x.permute([0, 3, 1, 2])
+ x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape(
+ [*batch_dims, c * 2, -1, x.shape[-1]]
+ )
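+        # real/imag are now folded into the channel axis (c -> 2c); keep only the first dim_f bins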
+ return x[..., : self.dim_f, :]
+
+ def inverse(self, x):
+ window = self.window.to(x.device)
+ batch_dims = x.shape[:-3]
+ c, f, t = x.shape[-3:]
+ n = self.n_fft // 2 + 1
+ f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device)
+ x = torch.cat([x, f_pad], -2)
+ x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
+ x = x.permute([0, 2, 3, 1])
+ x = x[..., 0] + x[..., 1] * 1.0j
+ x = torch.istft(
+ x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True
+ )
+ x = x.reshape([*batch_dims, 2, -1])
+ return x
+
+
+def get_norm(norm_type):
+ def norm(c, norm_type):
+ if norm_type == "BatchNorm":
+ return nn.BatchNorm2d(c)
+ elif norm_type == "InstanceNorm":
+ return nn.InstanceNorm2d(c, affine=True)
+ elif "GroupNorm" in norm_type:
+ g = int(norm_type.replace("GroupNorm", ""))
+ return nn.GroupNorm(num_groups=g, num_channels=c)
+ else:
+ return nn.Identity()
+
+ return partial(norm, norm_type=norm_type)
+
+
+def get_act(act_type):
+ if act_type == "gelu":
+ return nn.GELU()
+ elif act_type == "relu":
+ return nn.ReLU()
+ elif act_type[:3] == "elu":
+ alpha = float(act_type.replace("elu", ""))
+ return nn.ELU(alpha)
+    else:
+        raise ValueError(f"Unknown activation type: {act_type}")
+
+
+class Upscale(nn.Module):
+ def __init__(self, in_c, out_c, scale, norm, act):
+ super().__init__()
+ self.conv = nn.Sequential(
+ norm(in_c),
+ act,
+ nn.ConvTranspose2d(
+ in_channels=in_c,
+ out_channels=out_c,
+ kernel_size=scale,
+ stride=scale,
+ bias=False,
+ ),
+ )
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class Downscale(nn.Module):
+ def __init__(self, in_c, out_c, scale, norm, act):
+ super().__init__()
+ self.conv = nn.Sequential(
+ norm(in_c),
+ act,
+ nn.Conv2d(
+ in_channels=in_c,
+ out_channels=out_c,
+ kernel_size=scale,
+ stride=scale,
+ bias=False,
+ ),
+ )
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class TFC_TDF(nn.Module):
+ def __init__(self, in_c, c, l, f, bn, norm, act):
+ super().__init__()
+
+ self.blocks = nn.ModuleList()
+ for i in range(l):
+ block = nn.Module()
+
+ block.tfc1 = nn.Sequential(
+ norm(in_c),
+ act,
+ nn.Conv2d(in_c, c, 3, 1, 1, bias=False),
+ )
+ block.tdf = nn.Sequential(
+ norm(c),
+ act,
+ nn.Linear(f, f // bn, bias=False),
+ norm(c),
+ act,
+ nn.Linear(f // bn, f, bias=False),
+ )
+ block.tfc2 = nn.Sequential(
+ norm(c),
+ act,
+ nn.Conv2d(c, c, 3, 1, 1, bias=False),
+ )
+ block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False)
+
+ self.blocks.append(block)
+ in_c = c
+
+ def forward(self, x):
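+        # each block applies a 1x1 shortcut, a TFC conv, a TDF (frequency-bottleneck MLP) residual,
+        # a second TFC conv, and then adds the shortcut back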
+ for block in self.blocks:
+ s = block.shortcut(x)
+ x = block.tfc1(x)
+ x = x + block.tdf(x)
+ x = block.tfc2(x)
+ x = x + s
+ return x
+
+
+class TFC_TDF_net(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ norm = get_norm(norm_type=config.model.norm)
+ act = get_act(act_type=config.model.act)
+
+ self.num_target_instruments = len(prefer_target_instrument(config))
+ self.num_subbands = config.model.num_subbands
+
+ dim_c = self.num_subbands * config.audio.num_channels * 2
+ n = config.model.num_scales
+ scale = config.model.scale
+ l = config.model.num_blocks_per_scale
+ c = config.model.num_channels
+ g = config.model.growth
+ bn = config.model.bottleneck_factor
+ f = config.audio.dim_f // self.num_subbands
+
+ self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)
+
+ self.encoder_blocks = nn.ModuleList()
+ for i in range(n):
+ block = nn.Module()
+ block.tfc_tdf = TFC_TDF(c, c, l, f, bn, norm, act)
+ block.downscale = Downscale(c, c + g, scale, norm, act)
+ f = f // scale[1]
+ c += g
+ self.encoder_blocks.append(block)
+
+ self.bottleneck_block = TFC_TDF(c, c, l, f, bn, norm, act)
+
+ self.decoder_blocks = nn.ModuleList()
+ for i in range(n):
+ block = nn.Module()
+ block.upscale = Upscale(c, c - g, scale, norm, act)
+ f = f * scale[1]
+ c -= g
+ block.tfc_tdf = TFC_TDF(2 * c, c, l, f, bn, norm, act)
+ self.decoder_blocks.append(block)
+
+ self.final_conv = nn.Sequential(
+ nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
+ act,
+ nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False),
+ )
+
+ self.stft = STFT(config.audio)
+
+ def cac2cws(self, x):
+ k = self.num_subbands
+ b, c, f, t = x.shape
+ x = x.reshape(b, c, k, f // k, t)
+ x = x.reshape(b, c * k, f // k, t)
+ return x
+
+ def cws2cac(self, x):
+ k = self.num_subbands
+ b, c, f, t = x.shape
+ x = x.reshape(b, c // k, k, f, t)
+ x = x.reshape(b, c // k, f * k, t)
+ return x
+
+ def forward(self, x):
+
+ x = self.stft(x)
+
+ mix = x = self.cac2cws(x)
+
+ first_conv_out = x = self.first_conv(x)
+
+ x = x.transpose(-1, -2)
+
+ encoder_outputs = []
+ for block in self.encoder_blocks:
+ x = block.tfc_tdf(x)
+ encoder_outputs.append(x)
+ x = block.downscale(x)
+
+ x = self.bottleneck_block(x)
+
+ for block in self.decoder_blocks:
+ x = block.upscale(x)
+ x = torch.cat([x, encoder_outputs.pop()], 1)
+ x = block.tfc_tdf(x)
+
+ x = x.transpose(-1, -2)
+
+ x = x * first_conv_out # reduce artifacts
+
+ x = self.final_conv(torch.cat([mix, x], 1))
+
+ x = self.cws2cac(x)
+
+ if self.num_target_instruments > 1:
+ b, c, f, t = x.shape
+ x = x.reshape(b, self.num_target_instruments, -1, f, t)
+
+ x = self.stft.inverse(x)
+
+ return x
diff --git a/mvsepless/models/mdx_net.py b/mvsepless/models/mdx_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb7a6bc355444ad09f2de15407ca6bee76f25b83
--- /dev/null
+++ b/mvsepless/models/mdx_net.py
@@ -0,0 +1,208 @@
+import torch
+import torch.nn as nn
+import onnxruntime as ort
+import numpy as np
+from typing import Dict, Any, List
+import torch.nn.functional as F
+import sys
+import json
+
+class MDXNet(nn.Module):
+ def __init__(self, dim_f: int, dim_t: int, n_fft: int, hop_length: int, primary_stem: str, compensation: float = 1.0):
+ super().__init__()
+ self.dim_f = dim_f
+ self.dim_t = dim_t
+ self.n_fft = n_fft
+ self.dim_c = 4
+ self.hop_length = hop_length
+ self.primary_stem = primary_stem
+ self.compensation = compensation
+
+        # MDXNet's internal chunk_size (fixed by hop_length and dim_t)
+ self.internal_chunk_size = self.hop_length * (self.dim_t - 1)
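+        # e.g. hop_length=1024 and dim_t=256 would give 1024 * 255 = 261120 samples per chunk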
+ self.n_bins = self.n_fft // 2 + 1
+
+        # the ONNX session is initialized later via init_onnx_session()
+ self.ort_session = None
+
+ def init_onnx_session(self, onnx_model_path: str, device: str):
+ """Инициализирует ONNX runtime session"""
+ providers = ["CUDAExecutionProvider"] if "cuda" in device else ["CPUExecutionProvider"]
+ self.ort_session = ort.InferenceSession(onnx_model_path, providers=providers)
+
+ # Preload model
+ self.ort_session.run(
+ None,
+ {"input": torch.rand(1, 4, self.dim_f, self.dim_t).numpy()},
+ )
+
+        # Initialize the window function (moved to the matching device inside stft/istft)
+ self.window = torch.hann_window(
+ window_length=self.n_fft, periodic=True
+ )
+
+ out_c = self.dim_c
+
+ self.freq_pad = torch.zeros(
+ [1, out_c, self.n_bins - self.dim_f, self.dim_t]
+ )
+
+ def stft(self, x: torch.Tensor) -> torch.Tensor:
+ """STFT преобразование для MDX-Net"""
+ # Убедимся, что window на том же устройстве, что и x
+ window = self.window.to(x.device)
+
+ x = x.reshape([-1, self.internal_chunk_size])
+ x = torch.stft(
+ x,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ window=window,
+ center=True,
+ return_complex=True,
+ )
+ x = torch.view_as_real(x)
+ x = x.permute([0, 3, 1, 2])
+ x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
+ [-1, 4, self.n_bins, self.dim_t]
+ )
+ return x[:, :, :self.dim_f]
+
+ def istft(self, x: torch.Tensor) -> torch.Tensor:
+ """Обратное STFT преобразование для MDX-Net"""
+ # Убедимся, что window и freq_pad на том же устройстве, что и x
+ window = self.window.to(x.device)
+ freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]).to(x.device)
+
+ x = torch.cat([x, freq_pad], -2)
+ x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
+ [-1, 2, self.n_bins, self.dim_t]
+ )
+ x = x.permute([0, 2, 3, 1])
+ x = x.contiguous()
+ x = torch.view_as_complex(x)
+ x = torch.istft(
+ x,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ window=window,
+ center=True,
+ )
+ return x.reshape([-1, 2, self.internal_chunk_size])
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Прямой проход через MDX-Net"""
+ if self.ort_session is None:
+ raise ValueError("ONNX session not initialized. Call init_onnx_session first.")
+
+        # convert to numpy for ONNX inference
+ x_np = x.cpu().numpy()
+ output = self.ort_session.run(None, {"input": x_np})[0]
+ return torch.from_numpy(output).to(x.device)
+
+ def process_wave(self, wave: torch.Tensor, device: torch.device, num_overlap: int, pbar: bool = False) -> torch.Tensor:
+ """Обрабатывает аудио волну через MDX-Net с overlap и chunking"""
+ # Перемещаем wave на нужное устройство
+ wave = wave.to(device)
+
+        # use MDXNet's internal chunk_size
+ chunk_size = self.internal_chunk_size
+ fade_size = chunk_size // 10
+ step = chunk_size // num_overlap
+ border = chunk_size - step
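+        # overlap-add: chunks start every `step` samples, so each sample is covered by up to
+        # num_overlap windows; `border` reflection padding lets the faded edges be cropped off later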
+
+ length_init = wave.shape[-1]
+
+        # add padding to handle the edges
+ if length_init > 2 * border and border > 0:
+ wave = nn.functional.pad(wave, (border, border), mode="reflect")
+
+        # create the fade window on the target device
+ window = self._get_windowing_array(chunk_size, fade_size).to(device)
+
+        batch_size = 1  # MDX-Net typically runs with batch_size=1
+
+ with torch.no_grad():
+            # initialize the accumulation buffers on the target device
+ result = torch.zeros_like(wave, device=device)
+ counter = torch.zeros_like(wave, device=device)
+
+ i = 0
+ batch_data = []
+ batch_locations = []
+
+            # count the total number of chunks for progress reporting
+ total_chunks = 0
+ temp_i = 0
+ while temp_i < wave.shape[1]:
+ total_chunks += 1
+ temp_i += step
+
+ processed_chunks = 0
+
+ while i < wave.shape[1]:
+                # extract a chunk
+ part = wave[:, i : i + chunk_size]
+ chunk_len = part.shape[-1]
+
+                # pad if the chunk is shorter than chunk_size
+ if chunk_len < chunk_size:
+ pad_mode = "reflect" if chunk_len > chunk_size // 2 else "constant"
+ part = nn.functional.pad(
+ part, (0, chunk_size - chunk_len), mode=pad_mode, value=0
+ )
+
+ batch_data.append(part)
+ batch_locations.append((i, chunk_len))
+ i += step
+
+ # Process the accumulated batch
+ if len(batch_data) >= batch_size or i >= wave.shape[1]:
+ arr = torch.stack(batch_data, dim=0)
+
+ # Process each chunk in the batch
+ for j, (start, seg_len) in enumerate(batch_locations):
+ # STFT -> ONNX inference -> iSTFT
+ spec = self.stft(arr[j:j+1])
+ processed_spec = self(spec)
+ processed_wav = self.istft(processed_spec)
+
+ # Apply the windowing function (overlap-add)
+ window_segment = window[..., :seg_len]
+ result[:, start : start + seg_len] += processed_wav[0, :, :seg_len] * window_segment
+ counter[:, start : start + seg_len] += window_segment
+
+ # Update progress
+ processed_chunks += len(batch_data)
+ if pbar:
+ progress_data = {
+ "processing": {
+ "processed": min(i, wave.shape[1]),
+ "total": wave.shape[1]
+ }
+ }
+ sys.stdout.write(json.dumps(progress_data, ensure_ascii=False) + '\n')
+ sys.stdout.flush()
+
+ # Clear the batch
+ batch_data.clear()
+ batch_locations.clear()
+
+ # Compute the final result (normalize by the accumulated window weights)
+ estimated_sources = result / counter
+
+ # Remove the padding
+ if length_init > 2 * border and border > 0:
+ estimated_sources = estimated_sources[..., border:-border]
+
+ return estimated_sources
+
+ def _get_windowing_array(self, window_size: int, fade_size: int) -> torch.Tensor:
+ """Генерирует оконную функцию с fade-in и fade-out"""
+ fadein = torch.linspace(0, 1, fade_size)
+ fadeout = torch.linspace(1, 0, fade_size)
+
+ window = torch.ones(window_size)
+ window[-fade_size:] = fadeout
+ window[:fade_size] = fadein
+ return window
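+
+# Illustrative note (not part of the model): _get_windowing_array tapers the
+# first and last fade_size samples of each chunk so that chunk seams blend
+# smoothly during overlap-add; the accumulated window weights in `counter`
+# are divided out in process_wave, so the overall level is preserved
+# regardless of the overlap amount. For example, with a 20-sample window and
+# a 5-sample fade:
+#
+#   w = torch.ones(20)
+#   w[:5] = torch.linspace(0, 1, 5)
+#   w[-5:] = torch.linspace(1, 0, 5)
+#   # w == [0.0, 0.25, 0.5, 0.75, 1.0, 1.0, ..., 1.0, 1.0, 0.75, 0.5, 0.25, 0.0]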
\ No newline at end of file
diff --git a/mvsepless/models/scnet/__init__.py b/mvsepless/models/scnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f6ecefede9345237623066dd21ebd8253af1c60
--- /dev/null
+++ b/mvsepless/models/scnet/__init__.py
@@ -0,0 +1 @@
+from .scnet import SCNet
diff --git a/mvsepless/models/scnet/scnet.py b/mvsepless/models/scnet/scnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..71807479eb10bc803f2fb1f7d7f193bd30416167
--- /dev/null
+++ b/mvsepless/models/scnet/scnet.py
@@ -0,0 +1,419 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from collections import deque
+from .separation import SeparationNet
+import typing as tp
+import math
+
+
+class Swish(nn.Module):
+ def forward(self, x):
+ return x * x.sigmoid()
+
+
+class ConvolutionModule(nn.Module):
+ """
+ Convolution Module in SD block.
+
+ Args:
+ channels (int): input/output channels.
+ depth (int): number of layers in the residual branch. Each layer has its own stack of convolutions.
+ compress (float): amount of channel compression.
+ kernel (int): kernel size for the convolutions.
+ """
+
+ def __init__(self, channels, depth=2, compress=4, kernel=3):
+ super().__init__()
+ assert kernel % 2 == 1
+ self.depth = abs(depth)
+ hidden_size = int(channels / compress)
+ norm = lambda d: nn.GroupNorm(1, d)
+ self.layers = nn.ModuleList([])
+ for _ in range(self.depth):
+ padding = kernel // 2
+ mods = [
+ norm(channels),
+ nn.Conv1d(channels, hidden_size * 2, kernel, padding=padding),
+ nn.GLU(1),
+ nn.Conv1d(
+ hidden_size,
+ hidden_size,
+ kernel,
+ padding=padding,
+ groups=hidden_size,
+ ),
+ norm(hidden_size),
+ Swish(),
+ nn.Conv1d(hidden_size, channels, 1),
+ ]
+ layer = nn.Sequential(*mods)
+ self.layers.append(layer)
+
+ def forward(self, x):
+ for layer in self.layers:
+ x = x + layer(x)
+ return x
+
+
+class FusionLayer(nn.Module):
+ """
+ A FusionLayer within the decoder.
+
+ Args:
+ - channels (int): Number of input channels.
+ - kernel_size (int, optional): Kernel size for the convolutional layer, defaults to 3.
+ - stride (int, optional): Stride for the convolutional layer, defaults to 1.
+ - padding (int, optional): Padding for the convolutional layer, defaults to 1.
+ """
+
+ def __init__(self, channels, kernel_size=3, stride=1, padding=1):
+ super(FusionLayer, self).__init__()
+ self.conv = nn.Conv2d(
+ channels * 2, channels * 2, kernel_size, stride=stride, padding=padding
+ )
+
+ def forward(self, x, skip=None):
+ if skip is not None:
+ x += skip
+ x = x.repeat(1, 2, 1, 1)
+ x = self.conv(x)
+ x = F.glu(x, dim=1)
+ return x
+
+
+class SDlayer(nn.Module):
+ """
+ Implements a Sparse Down-sample Layer for processing different frequency bands separately.
+
+ Args:
+ - channels_in (int): Input channel count.
+ - channels_out (int): Output channel count.
+ - band_configs (dict): A dictionary containing configuration for each frequency band.
+ Keys are 'low', 'mid', 'high' for each band, and values are
+ dictionaries with keys 'SR', 'stride', and 'kernel' for proportion,
+ stride, and kernel size, respectively.
+ """
+
+ def __init__(self, channels_in, channels_out, band_configs):
+ super(SDlayer, self).__init__()
+
+ # Initializing convolutional layers for each band
+ self.convs = nn.ModuleList()
+ self.strides = []
+ self.kernels = []
+ for config in band_configs.values():
+ self.convs.append(
+ nn.Conv2d(
+ channels_in,
+ channels_out,
+ (config["kernel"], 1),
+ (config["stride"], 1),
+ (0, 0),
+ )
+ )
+ self.strides.append(config["stride"])
+ self.kernels.append(config["kernel"])
+
+ # Saving rate proportions for determining splits
+ self.SR_low = band_configs["low"]["SR"]
+ self.SR_mid = band_configs["mid"]["SR"]
+
+ def forward(self, x):
+ B, C, Fr, T = x.shape
+ # Define splitting points based on sampling rates
+ splits = [
+ (0, math.ceil(Fr * self.SR_low)),
+ (math.ceil(Fr * self.SR_low), math.ceil(Fr * (self.SR_low + self.SR_mid))),
+ (math.ceil(Fr * (self.SR_low + self.SR_mid)), Fr),
+ ]
+
+ # Processing each band with the corresponding convolution
+ outputs = []
+ original_lengths = []
+ for conv, stride, kernel, (start, end) in zip(
+ self.convs, self.strides, self.kernels, splits
+ ):
+ extracted = x[:, :, start:end, :]
+ original_lengths.append(end - start)
+ current_length = extracted.shape[2]
+
+ # padding
+ if stride == 1:
+ total_padding = kernel - stride
+ else:
+ total_padding = (stride - current_length % stride) % stride
+ pad_left = total_padding // 2
+ pad_right = total_padding - pad_left
+
+ padded = F.pad(extracted, (0, 0, pad_left, pad_right))
+
+ output = conv(padded)
+ outputs.append(output)
+
+ return outputs, original_lengths
+
+
+class SUlayer(nn.Module):
+ """
+ Implements a Sparse Up-sample Layer in decoder.
+
+ Args:
+ - channels_in: The number of input channels.
+ - channels_out: The number of output channels.
+ - convtr_configs: Dictionary containing the configurations for transposed convolutions.
+ """
+
+ def __init__(self, channels_in, channels_out, band_configs):
+ super(SUlayer, self).__init__()
+
+ # Initializing convolutional layers for each band
+ self.convtrs = nn.ModuleList(
+ [
+ nn.ConvTranspose2d(
+ channels_in,
+ channels_out,
+ [config["kernel"], 1],
+ [config["stride"], 1],
+ )
+ for _, config in band_configs.items()
+ ]
+ )
+
+ def forward(self, x, lengths, origin_lengths):
+ B, C, Fr, T = x.shape
+ # Define splitting points based on input lengths
+ splits = [
+ (0, lengths[0]),
+ (lengths[0], lengths[0] + lengths[1]),
+ (lengths[0] + lengths[1], None),
+ ]
+ # Processing each band with the corresponding convolution
+ outputs = []
+ for idx, (convtr, (start, end)) in enumerate(zip(self.convtrs, splits)):
+ out = convtr(x[:, :, start:end, :])
+ # Calculate the distance to trim the output symmetrically to original length
+ current_Fr_length = out.shape[2]
+ dist = abs(origin_lengths[idx] - current_Fr_length) // 2
+
+ # Trim the output to the original length symmetrically
+ trimmed_out = out[:, :, dist : dist + origin_lengths[idx], :]
+
+ outputs.append(trimmed_out)
+
+ # Concatenate trimmed outputs along the frequency dimension to return the final tensor
+ x = torch.cat(outputs, dim=2)
+
+ return x
+
+
+class SDblock(nn.Module):
+ """
+ Implements a simplified Sparse Down-sample block in encoder.
+
+ Args:
+ - channels_in (int): Number of input channels.
+ - channels_out (int): Number of output channels.
+ - band_config (dict): Configuration for the SDlayer specifying band splits and convolutions.
+ - conv_config (dict): Configuration for convolution modules applied to each band.
+ - depths (list of int): List specifying the convolution depths for low, mid, and high frequency bands.
+ """
+
+ def __init__(
+ self,
+ channels_in,
+ channels_out,
+ band_configs={},
+ conv_config={},
+ depths=[3, 2, 1],
+ kernel_size=3,
+ ):
+ super(SDblock, self).__init__()
+ self.SDlayer = SDlayer(channels_in, channels_out, band_configs)
+
+ # Dynamically create convolution modules for each band based on depths
+ self.conv_modules = nn.ModuleList(
+ [ConvolutionModule(channels_out, depth, **conv_config) for depth in depths]
+ )
+ # Set the kernel_size to an odd number.
+ self.globalconv = nn.Conv2d(
+ channels_out, channels_out, kernel_size, 1, (kernel_size - 1) // 2
+ )
+
+ def forward(self, x):
+ bands, original_lengths = self.SDlayer(x)
+ # B, C, f, T = band.shape
+ bands = [
+ F.gelu(
+ conv(band.permute(0, 2, 1, 3).reshape(-1, band.shape[1], band.shape[3]))
+ .view(band.shape[0], band.shape[2], band.shape[1], band.shape[3])
+ .permute(0, 2, 1, 3)
+ )
+ for conv, band in zip(self.conv_modules, bands)
+ ]
+ lengths = [band.size(-2) for band in bands]
+ full_band = torch.cat(bands, dim=2)
+ skip = full_band
+
+ output = self.globalconv(full_band)
+
+ return output, skip, lengths, original_lengths
+
+
+class SCNet(nn.Module):
+ """
+ The implementation of SCNet: Sparse Compression Network for Music Source Separation. Paper: https://arxiv.org/abs/2401.13276.pdf
+
+ Args:
+ - sources (List[str]): List of sources to be separated.
+ - audio_channels (int): Number of audio channels.
+ - nfft (int): FFT size, which determines the frequency dimension of the input.
+ - hop_size (int): Hop size for the STFT.
+ - win_size (int): Window size for STFT.
+ - normalized (bool): Whether to normalize the STFT.
+ - dims (List[int]): List of channel dimensions for each block.
+ - band_SR (List[float]): The proportion of each frequency band.
+ - band_stride (List[int]): The down-sampling ratio of each frequency band.
+ - band_kernel (List[int]): The kernel sizes for down-sampling convolution in each frequency band
+ - conv_depths (List[int]): List specifying the number of convolution modules in each SD block.
+ - compress (int): Compression factor for convolution module.
+ - conv_kernel (int): Kernel size for convolution layer in convolution module.
+ - num_dplayer (int): Number of dual-path layers.
+ - expand (int): Expansion factor in the dual-path RNN, default is 1.
+
+ """
+
+ def __init__(
+ self,
+ sources=["drums", "bass", "other", "vocals"],
+ audio_channels=2,
+ # Main structure
+ dims=[4, 32, 64, 128], # dims = [4, 64, 128, 256] in SCNet-large
+ # STFT
+ nfft=4096,
+ hop_size=1024,
+ win_size=4096,
+ normalized=True,
+ # SD/SU layer
+ band_SR=[0.175, 0.392, 0.433],
+ band_stride=[1, 4, 16],
+ band_kernel=[3, 4, 16],
+ # Convolution Module
+ conv_depths=[3, 2, 1],
+ compress=4,
+ conv_kernel=3,
+ # Dual-path RNN
+ num_dplayer=6,
+ expand=1,
+ ):
+ super().__init__()
+ self.sources = sources
+ self.audio_channels = audio_channels
+ self.dims = dims
+ band_keys = ["low", "mid", "high"]
+ self.band_configs = {
+ band_keys[i]: {
+ "SR": band_SR[i],
+ "stride": band_stride[i],
+ "kernel": band_kernel[i],
+ }
+ for i in range(len(band_keys))
+ }
+ self.hop_length = hop_size
+ self.conv_config = {
+ "compress": compress,
+ "kernel": conv_kernel,
+ }
+
+ self.stft_config = {
+ "n_fft": nfft,
+ "hop_length": hop_size,
+ "win_length": win_size,
+ "center": True,
+ "normalized": normalized,
+ }
+
+ self.encoder = nn.ModuleList()
+ self.decoder = nn.ModuleList()
+
+ for index in range(len(dims) - 1):
+ enc = SDblock(
+ channels_in=dims[index],
+ channels_out=dims[index + 1],
+ band_configs=self.band_configs,
+ conv_config=self.conv_config,
+ depths=conv_depths,
+ )
+ self.encoder.append(enc)
+
+ dec = nn.Sequential(
+ FusionLayer(channels=dims[index + 1]),
+ SUlayer(
+ channels_in=dims[index + 1],
+ channels_out=(
+ dims[index] if index != 0 else dims[index] * len(sources)
+ ),
+ band_configs=self.band_configs,
+ ),
+ )
+ self.decoder.insert(0, dec)
+
+ self.separation_net = SeparationNet(
+ channels=dims[-1],
+ expand=expand,
+ num_layers=num_dplayer,
+ )
+
+ def forward(self, x):
+ # B, C, L = x.shape
+ B = x.shape[0]
+ # In the initial padding, ensure that the number of frames after the STFT (the length of the T dimension) is even,
+ # so that the RFFT operation can be used in the separation network.
+ padding = self.hop_length - x.shape[-1] % self.hop_length
+ if (x.shape[-1] + padding) // self.hop_length % 2 == 0:
+ padding += self.hop_length
+ x = F.pad(x, (0, padding))
+
+ # STFT
+ L = x.shape[-1]
+ x = x.reshape(-1, L)
+ x = torch.stft(x, **self.stft_config, return_complex=True)
+ x = torch.view_as_real(x)
+ x = x.permute(0, 3, 1, 2).reshape(
+ x.shape[0] // self.audio_channels,
+ x.shape[3] * self.audio_channels,
+ x.shape[1],
+ x.shape[2],
+ )
+
+ B, C, Fr, T = x.shape
+
+ save_skip = deque()
+ save_lengths = deque()
+ save_original_lengths = deque()
+ # encoder
+ for sd_layer in self.encoder:
+ x, skip, lengths, original_lengths = sd_layer(x)
+ save_skip.append(skip)
+ save_lengths.append(lengths)
+ save_original_lengths.append(original_lengths)
+
+ # separation
+ x = self.separation_net(x)
+
+ # decoder
+ for fusion_layer, su_layer in self.decoder:
+ x = fusion_layer(x, save_skip.pop())
+ x = su_layer(x, save_lengths.pop(), save_original_lengths.pop())
+
+ # output
+ n = self.dims[0]
+ x = x.view(B, n, -1, Fr, T)
+ x = x.reshape(-1, 2, Fr, T).permute(0, 2, 3, 1)
+ x = torch.view_as_complex(x.contiguous())
+ x = torch.istft(x, **self.stft_config)
+ x = x.reshape(B, len(self.sources), self.audio_channels, -1)
+
+ x = x[:, :, :, :-padding]
+
+ return x
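+
+# Usage sketch (illustrative; relies only on the defaults defined above):
+#
+#   model = SCNet()                          # 4 stems, stereo, SCNet-small dims
+#   mix = torch.randn(1, 2, 44100 * 10)      # (batch, channels, samples)
+#   with torch.no_grad():
+#       stems = model(mix)                   # (batch, len(sources), channels, samples)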
diff --git a/mvsepless/models/scnet/separation.py b/mvsepless/models/scnet/separation.py
new file mode 100644
index 0000000000000000000000000000000000000000..8965e2c8b14fa2c1fb6a2766e840c45128e1303d
--- /dev/null
+++ b/mvsepless/models/scnet/separation.py
@@ -0,0 +1,129 @@
+import torch
+import torch.nn as nn
+from torch.nn.modules.rnn import LSTM
+
+
+class FeatureConversion(nn.Module):
+ """
+ Integrates into the adjacent Dual-Path layer.
+
+ Args:
+ channels (int): Number of input channels.
+ inverse (bool): If True, uses ifft; otherwise, uses rfft.
+ """
+
+ def __init__(self, channels, inverse):
+ super().__init__()
+ self.inverse = inverse
+ self.channels = channels
+
+ def forward(self, x):
+ # B, C, F, T = x.shape
+ if self.inverse:
+ x = x.float()
+ x_r = x[:, : self.channels // 2, :, :]
+ x_i = x[:, self.channels // 2 :, :, :]
+ x = torch.complex(x_r, x_i)
+ x = torch.fft.irfft(x, dim=3, norm="ortho")
+ else:
+ x = x.float()
+ x = torch.fft.rfft(x, dim=3, norm="ortho")
+ x_real = x.real
+ x_imag = x.imag
+ x = torch.cat([x_real, x_imag], dim=1)
+ return x
+
+
+class DualPathRNN(nn.Module):
+ """
+ Dual-Path RNN in Separation Network.
+
+ Args:
+ d_model (int): The number of expected features in the input (input_size).
+ expand (int): Expansion factor used to calculate the hidden_size of LSTM.
+ bidirectional (bool): If True, becomes a bidirectional LSTM.
+ """
+
+ def __init__(self, d_model, expand, bidirectional=True):
+ super(DualPathRNN, self).__init__()
+
+ self.d_model = d_model
+ self.hidden_size = d_model * expand
+ self.bidirectional = bidirectional
+ # Initialize LSTM layers and normalization layers
+ self.lstm_layers = nn.ModuleList(
+ [self._init_lstm_layer(self.d_model, self.hidden_size) for _ in range(2)]
+ )
+ self.linear_layers = nn.ModuleList(
+ [nn.Linear(self.hidden_size * 2, self.d_model) for _ in range(2)]
+ )
+ self.norm_layers = nn.ModuleList([nn.GroupNorm(1, d_model) for _ in range(2)])
+
+ def _init_lstm_layer(self, d_model, hidden_size):
+ return LSTM(
+ d_model,
+ hidden_size,
+ num_layers=1,
+ bidirectional=self.bidirectional,
+ batch_first=True,
+ )
+
+ def forward(self, x):
+ B, C, F, T = x.shape
+
+ # Process dual-path rnn
+ original_x = x
+ # Frequency-path
+ x = self.norm_layers[0](x)
+ x = x.transpose(1, 3).contiguous().view(B * T, F, C)
+ x, _ = self.lstm_layers[0](x)
+ x = self.linear_layers[0](x)
+ x = x.view(B, T, F, C).transpose(1, 3)
+ x = x + original_x
+
+ original_x = x
+ # Time-path
+ x = self.norm_layers[1](x)
+ x = x.transpose(1, 2).contiguous().view(B * F, C, T).transpose(1, 2)
+ x, _ = self.lstm_layers[1](x)
+ x = self.linear_layers[1](x)
+ x = x.transpose(1, 2).contiguous().view(B, F, C, T).transpose(1, 2)
+ x = x + original_x
+
+ return x
+
+
+class SeparationNet(nn.Module):
+ """
+ Implements the separation network: a stack of dual-path RNN layers with
+ feature conversion between time- and frequency-domain representations.
+
+ Args:
+ - channels (int): Number of input channels.
+ - expand (int): Expansion factor used to calculate the hidden_size of LSTM.
+ - num_layers (int): Number of dual-path layers.
+ """
+
+ def __init__(self, channels, expand=1, num_layers=6):
+ super(SeparationNet, self).__init__()
+
+ self.num_layers = num_layers
+
+ self.dp_modules = nn.ModuleList(
+ [
+ DualPathRNN(channels * (2 if i % 2 == 1 else 1), expand)
+ for i in range(num_layers)
+ ]
+ )
+
+ self.feature_conversion = nn.ModuleList(
+ [
+ FeatureConversion(channels * 2, inverse=False if i % 2 == 0 else True)
+ for i in range(num_layers)
+ ]
+ )
+
+ def forward(self, x):
+ for i in range(self.num_layers):
+ x = self.dp_modules[i](x)
+ x = self.feature_conversion[i](x)
+ return x
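+
+# Shape sketch (illustrative): FeatureConversion with inverse=False applies an
+# RFFT along the last (time) axis and stacks real/imag parts on the channel
+# axis, which is why every second DualPathRNN above is built with twice the
+# channel count:
+#
+#   x = torch.randn(1, 128, 32, 100)               # (B, C, F, T), T even
+#   z = torch.fft.rfft(x, dim=3, norm="ortho")     # (B, C, F, T // 2 + 1), complex
+#   y = torch.cat([z.real, z.imag], dim=1)         # (B, 2 * C, F, T // 2 + 1)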
diff --git a/mvsepless/models/scnet_unofficial/__init__.py b/mvsepless/models/scnet_unofficial/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..298d9939f5177c6b24cca743c83a351a84a6ffce
--- /dev/null
+++ b/mvsepless/models/scnet_unofficial/__init__.py
@@ -0,0 +1 @@
+from .scnet import SCNet
diff --git a/mvsepless/models/scnet_unofficial/modules/__init__.py b/mvsepless/models/scnet_unofficial/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..69617bb15044d9bbfd0211fcdfa0fa605b01c048
--- /dev/null
+++ b/mvsepless/models/scnet_unofficial/modules/__init__.py
@@ -0,0 +1,3 @@
+from .dualpath_rnn import DualPathRNN
+from .sd_encoder import SDBlock
+from .su_decoder import SUBlock
diff --git a/mvsepless/models/scnet_unofficial/modules/dualpath_rnn.py b/mvsepless/models/scnet_unofficial/modules/dualpath_rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..644d05a19cc83798402ee08f269b08a52eaeac09
--- /dev/null
+++ b/mvsepless/models/scnet_unofficial/modules/dualpath_rnn.py
@@ -0,0 +1,238 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as Func
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ return Func.normalize(x, dim=-1) * self.scale * self.gamma
+
+
+class MambaModule(nn.Module):
+ def __init__(self, d_model, d_state, d_conv, d_expand):
+ super().__init__()
+ # mamba_ssm is an optional dependency; import it lazily here so this
+ # module can still be imported when Mamba is not installed.
+ from mamba_ssm.modules.mamba_simple import Mamba
+ self.norm = RMSNorm(dim=d_model)
+ self.mamba = Mamba(
+ d_model=d_model, d_state=d_state, d_conv=d_conv, d_expand=d_expand
+ )
+
+ def forward(self, x):
+ x = x + self.mamba(self.norm(x))
+ return x
+
+
+class RNNModule(nn.Module):
+ """
+ RNNModule class implements a recurrent neural network module with LSTM cells.
+
+ Args:
+ - input_dim (int): Dimensionality of the input features.
+ - hidden_dim (int): Dimensionality of the hidden state of the LSTM.
+ - bidirectional (bool, optional): If True, uses bidirectional LSTM. Defaults to True.
+
+ Shapes:
+ - Input: (B, T, D) where
+ B is batch size,
+ T is sequence length,
+ D is input dimensionality.
+ - Output: (B, T, D) where
+ B is batch size,
+ T is sequence length,
+ D is input dimensionality.
+ """
+
+ def __init__(self, input_dim: int, hidden_dim: int, bidirectional: bool = True):
+ """
+ Initializes RNNModule with input dimension, hidden dimension, and bidirectional flag.
+ """
+ super().__init__()
+ self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=input_dim)
+ self.rnn = nn.LSTM(
+ input_dim, hidden_dim, batch_first=True, bidirectional=bidirectional
+ )
+ self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, input_dim)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Performs forward pass through the RNNModule.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, T, D).
+
+ Returns:
+ - torch.Tensor: Output tensor of shape (B, T, D).
+ """
+ x = x.transpose(1, 2)
+ x = self.groupnorm(x)
+ x = x.transpose(1, 2)
+
+ x, (hidden, _) = self.rnn(x)
+ x = self.fc(x)
+ return x
+
+
+class RFFTModule(nn.Module):
+ """
+ RFFTModule class implements a module for performing real-valued Fast Fourier Transform (FFT)
+ or its inverse on input tensors.
+
+ Args:
+ - inverse (bool, optional): If False, performs forward FFT. If True, performs inverse FFT. Defaults to False.
+
+ Shapes:
+ - Input: (B, F, T, D) where
+ B is batch size,
+ F is the number of features,
+ T is sequence length,
+ D is input dimensionality.
+ - Output: (B, F, T // 2 + 1, D * 2) if performing forward FFT.
+ (B, F, time_dim, D // 2) if performing inverse FFT.
+ """
+
+ def __init__(self, inverse: bool = False):
+ """
+ Initializes RFFTModule with inverse flag.
+ """
+ super().__init__()
+ self.inverse = inverse
+
+ def forward(self, x: torch.Tensor, time_dim: int) -> torch.Tensor:
+ """
+ Performs forward or inverse FFT on the input tensor x.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, F, T, D).
+ - time_dim (int): Input size of time dimension.
+
+ Returns:
+ - torch.Tensor: Output tensor after FFT or its inverse operation.
+ """
+ dtype = x.dtype
+ B, F, T, D = x.shape
+
+ # RuntimeError: cuFFT only supports dimensions whose sizes are powers of two when computing in half precision
+ x = x.float()
+
+ if not self.inverse:
+ x = torch.fft.rfft(x, dim=2)
+ x = torch.view_as_real(x)
+ x = x.reshape(B, F, T // 2 + 1, D * 2)
+ else:
+ x = x.reshape(B, F, T, D // 2, 2)
+ x = torch.view_as_complex(x)
+ x = torch.fft.irfft(x, n=time_dim, dim=2)
+
+ x = x.to(dtype)
+ return x
+
+ def extra_repr(self) -> str:
+ """
+ Returns extra representation string with module's configuration.
+ """
+ return f"inverse={self.inverse}"
+
+
+class DualPathRNN(nn.Module):
+ """
+ DualPathRNN class implements a neural network with alternating layers of RNNModule and RFFTModule.
+
+ Args:
+ - n_layers (int): Number of layers in the network.
+ - input_dim (int): Dimensionality of the input features.
+ - hidden_dim (int): Dimensionality of the hidden state of the RNNModule.
+
+ Shapes:
+ - Input: (B, F, T, D) where
+ B is batch size,
+ F is the number of features (frequency dimension),
+ T is sequence length (time dimension),
+ D is input dimensionality (channel dimension).
+ - Output: (B, F, T, D) where
+ B is batch size,
+ F is the number of features (frequency dimension),
+ T is sequence length (time dimension),
+ D is input dimensionality (channel dimension).
+ """
+
+ def __init__(
+ self,
+ n_layers: int,
+ input_dim: int,
+ hidden_dim: int,
+ use_mamba: bool = False,
+ d_state: int = 16,
+ d_conv: int = 4,
+ d_expand: int = 2,
+ ):
+ """
+ Initializes DualPathRNN with the specified number of layers, input dimension, and hidden dimension.
+ """
+ super().__init__()
+
+ if use_mamba:
+ from mamba_ssm.modules.mamba_simple import Mamba
+
+ net = MambaModule
+ dkwargs = {
+ "d_model": input_dim,
+ "d_state": d_state,
+ "d_conv": d_conv,
+ "d_expand": d_expand,
+ }
+ ukwargs = {
+ "d_model": input_dim * 2,
+ "d_state": d_state,
+ "d_conv": d_conv,
+ "d_expand": d_expand * 2,
+ }
+ else:
+ net = RNNModule
+ dkwargs = {"input_dim": input_dim, "hidden_dim": hidden_dim}
+ ukwargs = {"input_dim": input_dim * 2, "hidden_dim": hidden_dim * 2}
+
+ self.layers = nn.ModuleList()
+ for i in range(1, n_layers + 1):
+ kwargs = dkwargs if i % 2 == 1 else ukwargs
+ layer = nn.ModuleList(
+ [
+ net(**kwargs),
+ net(**kwargs),
+ RFFTModule(inverse=(i % 2 == 0)),
+ ]
+ )
+ self.layers.append(layer)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Performs forward pass through the DualPathRNN.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, F, T, D).
+
+ Returns:
+ - torch.Tensor: Output tensor of shape (B, F, T, D).
+ """
+
+ time_dim = x.shape[2]
+
+ for time_layer, freq_layer, rfft_layer in self.layers:
+ B, F, T, D = x.shape
+
+ x = x.reshape((B * F), T, D)
+ x = time_layer(x)
+ x = x.reshape(B, F, T, D)
+ x = x.permute(0, 2, 1, 3)
+
+ x = x.reshape((B * T), F, D)
+ x = freq_layer(x)
+ x = x.reshape(B, T, F, D)
+ x = x.permute(0, 2, 1, 3)
+
+ x = rfft_layer(x, time_dim)
+
+ return x
diff --git a/mvsepless/models/scnet_unofficial/modules/sd_encoder.py b/mvsepless/models/scnet_unofficial/modules/sd_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..742577f480693671437dc50358a1a65d251b6e9b
--- /dev/null
+++ b/mvsepless/models/scnet_unofficial/modules/sd_encoder.py
@@ -0,0 +1,285 @@
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+
+from ..utils import create_intervals
+
+
+class Downsample(nn.Module):
+ """
+ Downsample class implements a module for downsampling input tensors using 2D convolution.
+
+ Args:
+ - input_dim (int): Dimensionality of the input channels.
+ - output_dim (int): Dimensionality of the output channels.
+ - stride (int): Stride value for the convolution operation.
+
+ Shapes:
+ - Input: (B, C_in, F, T) where
+ B is batch size,
+ C_in is the number of input channels,
+ F is the frequency dimension,
+ T is the time dimension.
+ - Output: (B, C_out, F // stride, T) where
+ B is batch size,
+ C_out is the number of output channels,
+ F // stride is the downsampled frequency dimension.
+
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ stride: int,
+ ):
+ """
+ Initializes Downsample with input dimension, output dimension, and stride.
+ """
+ super().__init__()
+ self.conv = nn.Conv2d(input_dim, output_dim, 1, (stride, 1))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Performs forward pass through the Downsample module.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, C_in, F, T).
+
+ Returns:
+ - torch.Tensor: Downsampled tensor of shape (B, C_out, F // stride, T).
+ """
+ return self.conv(x)
+
+
+class ConvolutionModule(nn.Module):
+ """
+ ConvolutionModule class implements a module with a sequence of convolutional layers similar to Conformer.
+
+ Args:
+ - input_dim (int): Dimensionality of the input features.
+ - hidden_dim (int): Dimensionality of the hidden features.
+ - kernel_sizes (List[int]): List of kernel sizes for the convolutional layers.
+ - bias (bool, optional): If True, adds a learnable bias to the output. Default is False.
+
+ Shapes:
+ - Input: (B, T, D) where
+ B is batch size,
+ T is sequence length,
+ D is input dimensionality.
+ - Output: (B, T, D) where
+ B is batch size,
+ T is sequence length,
+ D is input dimensionality.
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ kernel_sizes: List[int],
+ bias: bool = False,
+ ) -> None:
+ """
+ Initializes ConvolutionModule with input dimension, hidden dimension, kernel sizes, and bias.
+ """
+ super().__init__()
+ self.sequential = nn.Sequential(
+ nn.GroupNorm(num_groups=1, num_channels=input_dim),
+ nn.Conv1d(
+ input_dim,
+ 2 * hidden_dim,
+ kernel_sizes[0],
+ stride=1,
+ padding=(kernel_sizes[0] - 1) // 2,
+ bias=bias,
+ ),
+ nn.GLU(dim=1),
+ nn.Conv1d(
+ hidden_dim,
+ hidden_dim,
+ kernel_sizes[1],
+ stride=1,
+ padding=(kernel_sizes[1] - 1) // 2,
+ groups=hidden_dim,
+ bias=bias,
+ ),
+ nn.GroupNorm(num_groups=1, num_channels=hidden_dim),
+ nn.SiLU(),
+ nn.Conv1d(
+ hidden_dim,
+ input_dim,
+ kernel_sizes[2],
+ stride=1,
+ padding=(kernel_sizes[2] - 1) // 2,
+ bias=bias,
+ ),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Performs forward pass through the ConvolutionModule.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, T, D).
+
+ Returns:
+ - torch.Tensor: Output tensor of shape (B, T, D).
+ """
+ x = x.transpose(1, 2)
+ x = x + self.sequential(x)
+ x = x.transpose(1, 2)
+ return x
+
+
+class SDLayer(nn.Module):
+ """
+ SDLayer class implements a subband decomposition layer with downsampling and convolutional modules.
+
+ Args:
+ - subband_interval (Tuple[float, float]): Tuple representing the frequency interval for subband decomposition.
+ - input_dim (int): Dimensionality of the input channels.
+ - output_dim (int): Dimensionality of the output channels after downsampling.
+ - downsample_stride (int): Stride value for the downsampling operation.
+ - n_conv_modules (int): Number of convolutional modules.
+ - kernel_sizes (List[int]): List of kernel sizes for the convolutional layers.
+ - bias (bool, optional): If True, adds a learnable bias to the convolutional layers. Default is True.
+
+ Shapes:
+ - Input: (B, Fi, T, Ci) where
+ B is batch size,
+ Fi is the number of input subbands,
+ T is sequence length, and
+ Ci is the number of input channels.
+ - Output: (B, Fi+1, T, Ci+1) where
+ B is batch size,
+ Fi+1 is the number of output subbands,
+ T is sequence length,
+ Ci+1 is the number of output channels.
+ """
+
+ def __init__(
+ self,
+ subband_interval: Tuple[float, float],
+ input_dim: int,
+ output_dim: int,
+ downsample_stride: int,
+ n_conv_modules: int,
+ kernel_sizes: List[int],
+ bias: bool = True,
+ ):
+ """
+ Initializes SDLayer with subband interval, input dimension,
+ output dimension, downsample stride, number of convolutional modules, kernel sizes, and bias.
+ """
+ super().__init__()
+ self.subband_interval = subband_interval
+ self.downsample = Downsample(input_dim, output_dim, downsample_stride)
+ self.activation = nn.GELU()
+ conv_modules = [
+ ConvolutionModule(
+ input_dim=output_dim,
+ hidden_dim=output_dim // 4,
+ kernel_sizes=kernel_sizes,
+ bias=bias,
+ )
+ for _ in range(n_conv_modules)
+ ]
+ self.conv_modules = nn.Sequential(*conv_modules)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Performs forward pass through the SDLayer.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci).
+
+ Returns:
+ - torch.Tensor: Output tensor of shape (B, Fi+1, T, Ci+1).
+ """
+ B, F, T, C = x.shape
+ x = x[:, int(self.subband_interval[0] * F) : int(self.subband_interval[1] * F)]
+ x = x.permute(0, 3, 1, 2)
+ x = self.downsample(x)
+ x = self.activation(x)
+ x = x.permute(0, 2, 3, 1)
+
+ B, F, T, C = x.shape
+ x = x.reshape((B * F), T, C)
+ x = self.conv_modules(x)
+ x = x.reshape(B, F, T, C)
+
+ return x
+
+
+class SDBlock(nn.Module):
+ """
+ SDBlock class implements a block with subband decomposition layers and global convolution.
+
+ Args:
+ - input_dim (int): Dimensionality of the input channels.
+ - output_dim (int): Dimensionality of the output channels.
+ - bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands.
+ - downsample_strides (List[int]): List of stride values for downsampling in each subband layer.
+ - n_conv_modules (List[int]): List specifying the number of convolutional modules in each subband layer.
+ - kernel_sizes (List[int], optional): List of kernel sizes for the convolutional layers. Default is None.
+
+ Shapes:
+ - Input: (B, Fi, T, Ci) where
+ B is batch size,
+ Fi is the number of input subbands,
+ T is sequence length,
+ Ci is the number of input channels.
+ - Output: (B, Fi+1, T, Ci+1) where
+ B is batch size,
+ Fi+1 is the number of output subbands,
+ T is sequence length,
+ Ci+1 is the number of output channels.
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ bandsplit_ratios: List[float],
+ downsample_strides: List[int],
+ n_conv_modules: List[int],
+ kernel_sizes: List[int] = None,
+ ):
+ """
+ Initializes SDBlock with input dimension, output dimension, band split ratios, downsample strides, number of convolutional modules, and kernel sizes.
+ """
+ super().__init__()
+ if kernel_sizes is None:
+ kernel_sizes = [3, 3, 1]
+ assert sum(bandsplit_ratios) == 1, "The split ratios must sum up to 1."
+ subband_intervals = create_intervals(bandsplit_ratios)
+ self.sd_layers = nn.ModuleList(
+ SDLayer(
+ input_dim=input_dim,
+ output_dim=output_dim,
+ subband_interval=sbi,
+ downsample_stride=dss,
+ n_conv_modules=ncm,
+ kernel_sizes=kernel_sizes,
+ )
+ for sbi, dss, ncm in zip(
+ subband_intervals, downsample_strides, n_conv_modules
+ )
+ )
+ self.global_conv2d = nn.Conv2d(output_dim, output_dim, 1, 1)
+
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Performs forward pass through the SDBlock.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci).
+
+ Returns:
+ - Tuple[torch.Tensor, torch.Tensor]: Output tensor and skip connection tensor.
+ """
+ x_skip = torch.concat([layer(x) for layer in self.sd_layers], dim=1)
+ x = self.global_conv2d(x_skip.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
+ return x, x_skip
diff --git a/mvsepless/models/scnet_unofficial/modules/su_decoder.py b/mvsepless/models/scnet_unofficial/modules/su_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..660c1fa6cbfd9b43bed73204a0bb6593524de272
--- /dev/null
+++ b/mvsepless/models/scnet_unofficial/modules/su_decoder.py
@@ -0,0 +1,241 @@
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+
+from ..utils import get_convtranspose_output_padding
+
+
+class FusionLayer(nn.Module):
+ """
+ FusionLayer class implements a module for fusing two input tensors using convolutional operations.
+
+ Args:
+ - input_dim (int): Dimensionality of the input channels.
+ - kernel_size (int, optional): Kernel size for the convolutional layer. Default is 3.
+ - stride (int, optional): Stride value for the convolutional layer. Default is 1.
+ - padding (int, optional): Padding value for the convolutional layer. Default is 1.
+
+ Shapes:
+ - Input: (B, F, T, C) and (B, F, T, C) where
+ B is batch size,
+ F is the number of features,
+ T is sequence length,
+ C is input dimensionality.
+ - Output: (B, F, T, C) where
+ B is batch size,
+ F is the number of features,
+ T is sequence length,
+ C is input dimensionality.
+ """
+
+ def __init__(
+ self, input_dim: int, kernel_size: int = 3, stride: int = 1, padding: int = 1
+ ):
+ """
+ Initializes FusionLayer with input dimension, kernel size, stride, and padding.
+ """
+ super().__init__()
+ self.conv = nn.Conv2d(
+ input_dim * 2,
+ input_dim * 2,
+ kernel_size=(kernel_size, 1),
+ stride=(stride, 1),
+ padding=(padding, 0),
+ )
+ self.activation = nn.GLU()
+
+ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+ """
+ Performs forward pass through the FusionLayer.
+
+ Args:
+ - x1 (torch.Tensor): First input tensor of shape (B, F, T, C).
+ - x2 (torch.Tensor): Second input tensor of shape (B, F, T, C).
+
+ Returns:
+ - torch.Tensor: Output tensor of shape (B, F, T, C).
+ """
+ x = x1 + x2
+ x = x.repeat(1, 1, 1, 2)
+ x = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
+ x = self.activation(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ Upsample class implements a module for upsampling input tensors using transposed 2D convolution.
+
+ Args:
+ - input_dim (int): Dimensionality of the input channels.
+ - output_dim (int): Dimensionality of the output channels.
+ - stride (int): Stride value for the transposed convolution operation.
+ - output_padding (int): Output padding value for the transposed convolution operation.
+
+ Shapes:
+ - Input: (B, C_in, F, T) where
+ B is batch size,
+ C_in is the number of input channels,
+ F is the frequency dimension,
+ T is the time dimension.
+ - Output: (B, C_out, F * stride + output_padding, T) where
+ B is batch size,
+ C_out is the number of output channels,
+ F * stride + output_padding is the upsampled frequency dimension.
+ """
+
+ def __init__(
+ self, input_dim: int, output_dim: int, stride: int, output_padding: int
+ ):
+ """
+ Initializes Upsample with input dimension, output dimension, stride, and output padding.
+ """
+ super().__init__()
+ self.conv = nn.ConvTranspose2d(
+ input_dim, output_dim, 1, (stride, 1), output_padding=(output_padding, 0)
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Performs forward pass through the Upsample module.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, C_in, F, T).
+
+ Returns:
+ - torch.Tensor: Output tensor of shape (B, C_out, F * stride + output_padding, T).
+ """
+ return self.conv(x)
+
+
+class SULayer(nn.Module):
+ """
+ SULayer class implements a subband upsampling layer using transposed convolution.
+
+ Args:
+ - input_dim (int): Dimensionality of the input channels.
+ - output_dim (int): Dimensionality of the output channels.
+ - upsample_stride (int): Stride value for the upsampling operation.
+ - subband_shape (int): Shape of the subband.
+ - sd_interval (Tuple[int, int]): Start and end indices of the subband interval.
+
+ Shapes:
+ - Input: (B, F, T, C) where
+ B is batch size,
+ F is the number of features,
+ T is sequence length,
+ C is input dimensionality.
+ - Output: (B, F, T, C) where
+ B is batch size,
+ F is the number of features,
+ T is sequence length,
+ C is input dimensionality.
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ upsample_stride: int,
+ subband_shape: int,
+ sd_interval: Tuple[int, int],
+ ):
+ """
+ Initializes SULayer with input dimension, output dimension, upsample stride, subband shape, and subband interval.
+ """
+ super().__init__()
+ sd_shape = sd_interval[1] - sd_interval[0]
+ upsample_output_padding = get_convtranspose_output_padding(
+ input_shape=sd_shape, output_shape=subband_shape, stride=upsample_stride
+ )
+ self.upsample = Upsample(
+ input_dim=input_dim,
+ output_dim=output_dim,
+ stride=upsample_stride,
+ output_padding=upsample_output_padding,
+ )
+ self.sd_interval = sd_interval
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Performs forward pass through the SULayer.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, F, T, C).
+
+ Returns:
+ - torch.Tensor: Output tensor of shape (B, F, T, C).
+ """
+ x = x[:, self.sd_interval[0] : self.sd_interval[1]]
+ x = x.permute(0, 3, 1, 2)
+ x = self.upsample(x)
+ x = x.permute(0, 2, 3, 1)
+ return x
+
+
+class SUBlock(nn.Module):
+ """
+ SUBlock class implements a block with fusion layer and subband upsampling layers.
+
+ Args:
+ - input_dim (int): Dimensionality of the input channels.
+ - output_dim (int): Dimensionality of the output channels.
+ - upsample_strides (List[int]): List of stride values for the upsampling operations.
+ - subband_shapes (List[int]): List of shapes for the subbands.
+ - sd_intervals (List[Tuple[int, int]]): List of intervals for subband decomposition.
+
+ Shapes:
+ - Input: (B, Fi-1, T, Ci-1) and (B, Fi-1, T, Ci-1) where
+ B is batch size,
+ Fi-1 is the number of input subbands,
+ T is sequence length,
+ Ci-1 is the number of input channels.
+ - Output: (B, Fi, T, Ci) where
+ B is batch size,
+ Fi is the number of output subbands,
+ T is sequence length,
+ Ci is the number of output channels.
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ upsample_strides: List[int],
+ subband_shapes: List[int],
+ sd_intervals: List[Tuple[int, int]],
+ ):
+ """
+ Initializes SUBlock with input dimension, output dimension,
+ upsample strides, subband shapes, and subband intervals.
+ """
+ super().__init__()
+ self.fusion_layer = FusionLayer(input_dim=input_dim)
+ self.su_layers = nn.ModuleList(
+ SULayer(
+ input_dim=input_dim,
+ output_dim=output_dim,
+ upsample_stride=uss,
+ subband_shape=sbs,
+ sd_interval=sdi,
+ )
+ for i, (uss, sbs, sdi) in enumerate(
+ zip(upsample_strides, subband_shapes, sd_intervals)
+ )
+ )
+
+ def forward(self, x: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
+ """
+ Performs forward pass through the SUBlock.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, Fi-1, T, Ci-1).
+ - x_skip (torch.Tensor): Input skip connection tensor of shape (B, Fi-1, T, Ci-1).
+
+ Returns:
+ - torch.Tensor: Output tensor of shape (B, Fi, T, Ci).
+ """
+ x = self.fusion_layer(x, x_skip)
+ x = torch.concat([layer(x) for layer in self.su_layers], dim=1)
+ return x
diff --git a/mvsepless/models/scnet_unofficial/scnet.py b/mvsepless/models/scnet_unofficial/scnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6dcf7285da2358b6bf2dcd2b9a177fed6022a0d
--- /dev/null
+++ b/mvsepless/models/scnet_unofficial/scnet.py
@@ -0,0 +1,246 @@
+"""
+SCNet - great paper, great implementation
+https://arxiv.org/pdf/2401.13276.pdf
+https://github.com/amanteur/SCNet-PyTorch
+"""
+
+from typing import List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchaudio
+
+from .modules import DualPathRNN, SDBlock, SUBlock
+from .utils import compute_sd_layer_shapes, compute_gcr
+
+from einops import rearrange, pack, unpack
+from functools import partial
+
+from beartype.typing import Tuple, Optional, List, Callable
+from beartype import beartype
+
+
+def exists(val):
+ return val is not None
+
+
+def default(v, d):
+ return v if exists(v) else d
+
+
+def pack_one(t, pattern):
+ return pack([t], pattern)
+
+
+def unpack_one(t, ps, pattern):
+ return unpack(t, ps, pattern)[0]
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
+
+
+class BandSplit(nn.Module):
+ @beartype
+ def __init__(self, dim, dim_inputs: Tuple[int, ...]):
+ super().__init__()
+ self.dim_inputs = dim_inputs
+ self.to_features = nn.ModuleList([])
+
+ for dim_in in dim_inputs:
+ net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
+
+ self.to_features.append(net)
+
+ def forward(self, x):
+ x = x.split(self.dim_inputs, dim=-1)
+
+ outs = []
+ for split_input, to_feature in zip(x, self.to_features):
+ split_output = to_feature(split_input)
+ outs.append(split_output)
+
+ return torch.stack(outs, dim=-2)
+
+
+class SCNet(nn.Module):
+ """
+ SCNet class implements a source separation network,
+ which explicitly splits the spectrogram of the mixture into several subbands
+ and introduces a sparsity-based encoder to model different frequency bands.
+
+ Paper: "SCNET: SPARSE COMPRESSION NETWORK FOR MUSIC SOURCE SEPARATION"
+ Authors: Weinan Tong, Jiaxu Zhu et al.
+ Link: https://arxiv.org/abs/2401.13276.pdf
+
+ Args:
+ - n_fft (int): FFT size, which determines the frequency dimension of the input.
+ - dims (List[int]): List of channel dimensions for each block.
+ - bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands.
+ - downsample_strides (List[int]): List of stride values for downsampling in each block.
+ - n_conv_modules (List[int]): List specifying the number of convolutional modules in each block.
+ - n_rnn_layers (int): Number of recurrent layers in the dual path RNN.
+ - rnn_hidden_dim (int): Dimensionality of the hidden state in the dual path RNN.
+ - n_sources (int, optional): Number of sources to be separated. Default is 4.
+
+ Shapes:
+ - Input: (B, C, T) where
+ B is batch size,
+ C is channel dim (mono / stereo),
+ T is time dim
+ - Output: (B, N, C, T) where
+ B is batch size,
+ N is the number of sources.
+ C is channel dim (mono / stereo),
+ T is sequence length,
+ """
+
+ @beartype
+ def __init__(
+ self,
+ n_fft: int,
+ dims: List[int],
+ bandsplit_ratios: List[float],
+ downsample_strides: List[int],
+ n_conv_modules: List[int],
+ n_rnn_layers: int,
+ rnn_hidden_dim: int,
+ n_sources: int = 4,
+ hop_length: int = 1024,
+ win_length: int = 4096,
+ stft_window_fn: Optional[Callable] = None,
+ stft_normalized: bool = False,
+ **kwargs,
+ ):
+ """
+ Initializes SCNet with input parameters.
+ """
+ super().__init__()
+ self.assert_input_data(
+ bandsplit_ratios,
+ downsample_strides,
+ n_conv_modules,
+ )
+
+ n_blocks = len(dims) - 1
+ n_freq_bins = n_fft // 2 + 1
+ subband_shapes, sd_intervals = compute_sd_layer_shapes(
+ input_shape=n_freq_bins,
+ bandsplit_ratios=bandsplit_ratios,
+ downsample_strides=downsample_strides,
+ n_layers=n_blocks,
+ )
+ self.sd_blocks = nn.ModuleList(
+ SDBlock(
+ input_dim=dims[i],
+ output_dim=dims[i + 1],
+ bandsplit_ratios=bandsplit_ratios,
+ downsample_strides=downsample_strides,
+ n_conv_modules=n_conv_modules,
+ )
+ for i in range(n_blocks)
+ )
+ self.dualpath_blocks = DualPathRNN(
+ n_layers=n_rnn_layers,
+ input_dim=dims[-1],
+ hidden_dim=rnn_hidden_dim,
+ **kwargs,
+ )
+ self.su_blocks = nn.ModuleList(
+ SUBlock(
+ input_dim=dims[i + 1],
+ output_dim=dims[i] if i != 0 else dims[i] * n_sources,
+ subband_shapes=subband_shapes[i],
+ sd_intervals=sd_intervals[i],
+ upsample_strides=downsample_strides,
+ )
+ for i in reversed(range(n_blocks))
+ )
+ self.gcr = compute_gcr(subband_shapes)
+
+ self.stft_kwargs = dict(
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ normalized=stft_normalized,
+ )
+
+ self.stft_window_fn = partial(
+ default(stft_window_fn, torch.hann_window), win_length
+ )
+ self.n_sources = n_sources
+ self.hop_length = hop_length
+
+ @staticmethod
+ def assert_input_data(*args):
+ """
+ Asserts that the shapes of input features are equal.
+ """
+ for arg1 in args:
+ for arg2 in args:
+ if len(arg1) != len(arg2):
+ raise ValueError(
+ f"Shapes of input features {arg1} and {arg2} are not equal."
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Performs forward pass through the SCNet.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, C, T).
+
+ Returns:
+ - torch.Tensor: Output tensor of shape (B, N, C, T).
+ """
+
+ device = x.device
+ stft_window = self.stft_window_fn(device=device)
+
+ if x.ndim == 2:
+ x = rearrange(x, "b t -> b 1 t")
+
+ c = x.shape[1]
+
+ stft_pad = self.hop_length - x.shape[-1] % self.hop_length
+ x = F.pad(x, (0, stft_pad))
+
+ # stft
+ x, ps = pack_one(x, "* t")
+ x = torch.stft(x, **self.stft_kwargs, window=stft_window, return_complex=True)
+ x = torch.view_as_real(x)
+ x = unpack_one(x, ps, "* c f t")
+ x = rearrange(x, "b c f t r -> b f t (c r)")
+
+ # encoder part
+ x_skips = []
+ for sd_block in self.sd_blocks:
+ x, x_skip = sd_block(x)
+ x_skips.append(x_skip)
+
+ # separation part
+ x = self.dualpath_blocks(x)
+
+ # decoder part
+ for su_block, x_skip in zip(self.su_blocks, reversed(x_skips)):
+ x = su_block(x, x_skip)
+
+ # istft
+ x = rearrange(x, "b f t (c r n) -> b n c f t r", c=c, n=self.n_sources, r=2)
+ x = x.contiguous()
+
+ x = torch.view_as_complex(x)
+ x = rearrange(x, "b n c f t -> (b n c) f t")
+ x = torch.istft(x, **self.stft_kwargs, window=stft_window, return_complex=False)
+ x = rearrange(x, "(b n c) t -> b n c t", c=c, n=self.n_sources)
+
+ x = x[..., :-stft_pad]
+
+ return x
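+
+# Usage sketch (illustrative hyperparameters only, not a released preset):
+#
+#   model = SCNet(
+#       n_fft=4096, dims=[4, 32, 64, 128],
+#       bandsplit_ratios=[0.25, 0.25, 0.5],    # must sum to 1 (see SDBlock assert)
+#       downsample_strides=[1, 4, 16],
+#       n_conv_modules=[3, 2, 1],
+#       n_rnn_layers=6, rnn_hidden_dim=128,
+#   )
+#   stems = model(torch.randn(1, 2, 44100))    # -> (B, n_sources, C, T)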
diff --git a/mvsepless/models/scnet_unofficial/utils.py b/mvsepless/models/scnet_unofficial/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d236d499322a5db4ae4a813e75901dcfe28f7993
--- /dev/null
+++ b/mvsepless/models/scnet_unofficial/utils.py
@@ -0,0 +1,135 @@
+"""
+SCNet - great paper, great implementation
+https://arxiv.org/pdf/2401.13276.pdf
+https://github.com/amanteur/SCNet-PyTorch
+"""
+
+from typing import List, Tuple, Union
+
+import torch
+
+
+def create_intervals(
+ splits: List[Union[float, int]],
+) -> List[Union[Tuple[float, float], Tuple[int, int]]]:
+ """
+ Create intervals based on splits provided.
+
+ Args:
+ - splits (List[Union[float, int]]): List of floats or integers representing splits.
+
+ Returns:
+ - List[Union[Tuple[float, float], Tuple[int, int]]]: List of tuples representing intervals.
+ """
+ start = 0
+ return [(start, start := start + split) for split in splits]
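+
+# Example (illustrative): the split ratios become cumulative intervals, e.g.
+# create_intervals([0.25, 0.25, 0.5]) == [(0, 0.25), (0.25, 0.5), (0.5, 1.0)]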
+
+
+def get_conv_output_shape(
+ input_shape: int,
+ kernel_size: int = 1,
+ padding: int = 0,
+ dilation: int = 1,
+ stride: int = 1,
+) -> int:
+ """
+ Compute the output shape of a convolutional layer.
+
+ Args:
+ - input_shape (int): Input shape.
+ - kernel_size (int, optional): Kernel size of the convolution. Default is 1.
+ - padding (int, optional): Padding size. Default is 0.
+ - dilation (int, optional): Dilation factor. Default is 1.
+ - stride (int, optional): Stride value. Default is 1.
+
+ Returns:
+ - int: Output shape.
+ """
+ return int(
+ (input_shape + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
+ )
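+
+# Example (illustrative): a stride-4 downsampling convolution with the default
+# kernel_size=1 maps 87 bins to get_conv_output_shape(87, stride=4) == 22 bins.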
+
+
+def get_convtranspose_output_padding(
+ input_shape: int,
+ output_shape: int,
+ kernel_size: int = 1,
+ padding: int = 0,
+ dilation: int = 1,
+ stride: int = 1,
+) -> int:
+ """
+ Compute the output padding for a convolution transpose operation.
+
+ Args:
+ - input_shape (int): Input shape.
+ - output_shape (int): Desired output shape.
+ - kernel_size (int, optional): Kernel size of the convolution. Default is 1.
+ - padding (int, optional): Padding size. Default is 0.
+ - dilation (int, optional): Dilation factor. Default is 1.
+ - stride (int, optional): Stride value. Default is 1.
+
+ Returns:
+ - int: Output padding.
+ """
+ return (
+ output_shape
+ - (input_shape - 1) * stride
+ + 2 * padding
+ - dilation * (kernel_size - 1)
+ - 1
+ )
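+
+# Example (illustrative): to make a stride-4 transposed convolution with the
+# default kernel_size=1 recover 87 bins from 22,
+# get_convtranspose_output_padding(22, 87, stride=4) == 87 - (22 - 1) * 4 - 1 == 2.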
+
+
+def compute_sd_layer_shapes(
+ input_shape: int,
+ bandsplit_ratios: List[float],
+ downsample_strides: List[int],
+ n_layers: int,
+) -> Tuple[List[List[int]], List[List[Tuple[int, int]]]]:
+ """
+ Compute the shapes for the subband layers.
+
+ Args:
+ - input_shape (int): Input shape.
+ - bandsplit_ratios (List[float]): Ratios for splitting the frequency bands.
+ - downsample_strides (List[int]): Strides for downsampling in each layer.
+ - n_layers (int): Number of layers.
+
+ Returns:
+ - Tuple[List[List[int]], List[List[Tuple[int, int]]]]: Tuple containing subband shapes and convolution shapes.
+ """
+ bandsplit_shapes_list = []
+ conv2d_shapes_list = []
+ for _ in range(n_layers):
+ bandsplit_intervals = create_intervals(bandsplit_ratios)
+ bandsplit_shapes = [
+ int(right * input_shape) - int(left * input_shape)
+ for left, right in bandsplit_intervals
+ ]
+ conv2d_shapes = [
+ get_conv_output_shape(bs, stride=ds)
+ for bs, ds in zip(bandsplit_shapes, downsample_strides)
+ ]
+ input_shape = sum(conv2d_shapes)
+ bandsplit_shapes_list.append(bandsplit_shapes)
+ conv2d_shapes_list.append(create_intervals(conv2d_shapes))
+
+ return bandsplit_shapes_list, conv2d_shapes_list
+
+
+def compute_gcr(subband_shapes: List[List[int]]) -> float:
+ """
+ Compute the global compression ratio.
+
+ Args:
+ - subband_shapes (List[List[int]]): List of subband shapes.
+
+ Returns:
+ - float: Global compression ratio.
+ """
+ t = torch.Tensor(subband_shapes)
+ gcr = torch.stack(
+ [(1 - t[i + 1] / t[i]).mean() for i in range(0, len(t) - 1)]
+ ).mean()
+ return float(gcr)
diff --git a/mvsepless/models/vr_arch/__init__.py b/mvsepless/models/vr_arch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc75c42ef220d5787ba8371835a0cd48109806bf
--- /dev/null
+++ b/mvsepless/models/vr_arch/__init__.py
@@ -0,0 +1,321 @@
+import os
+import math
+import sys
+import json
+import torch
+import torch.nn as nn
+import librosa
+import numpy as np
+from tqdm import tqdm
+
+from . import spec_utils, nets, nets_new
+from .model_param_init import ModelParameters
+
+VOCAL_STEM = "vocals"
+INST_STEM = "instrumental"
+OTHER_STEM = "other"
+BASS_STEM = "bass"
+DRUM_STEM = "drums"
+GUITAR_STEM = "guitar"
+PIANO_STEM = "piano"
+SYNTH_STEM = "synthesizer"
+STRINGS_STEM = "strings"
+WOODWINDS_STEM = "woodwinds"
+BRASS_STEM = "brass"
+WIND_INST_STEM = "wind_inst"
+
+NON_ACCOM_STEMS = (
+ VOCAL_STEM,
+ OTHER_STEM,
+ BASS_STEM,
+ DRUM_STEM,
+ GUITAR_STEM,
+ PIANO_STEM,
+ SYNTH_STEM,
+ STRINGS_STEM,
+ WOODWINDS_STEM,
+ BRASS_STEM,
+ WIND_INST_STEM,
+)
+
+class VRNet:
+ def __init__(
+ self,
+ model_params={},
+ nout=None,
+ nout_lstm=None,
+ ):
+ self.torch_device_mps = torch.backends.mps.is_available()
+ self.enable_post_process = False
+ self.post_process_threshold = 0.2
+ self.batch_size = 1
+ self.window_size = 512
+ self.high_end_process = False
+ self.primary_stem = "Instrumental"
+ self.secondary_stem = "Vocals"
+ self.model_capacity = 32, 128
+ self.is_vr_51_model = False
+ if nout and nout_lstm:
+ self.model_capacity = nout, nout_lstm
+ self.is_vr_51_model = True
+ self.model_params = ModelParameters(model_params)
+ self.input_high_end_h = None
+ self.input_high_end = None
+ self.enable_tta = False
+ self.model_samplerate = self.model_params.param["sr"]
+ self.model_run = lambda *args, **kwargs: print(
+ "Model run method is not initialised yet."
+ )
+
+ def load_checkpoint(self, checkpoint_path: str, device: torch.device):
+ nn_arch_sizes = [
+ 31191,
+ 33966,
+ 56817,
+ 123821,
+ 123812,
+ 129605,
+ 218409,
+ 537238,
+ 537227,
+ ] # default
+ vr_5_1_models = [56817, 218409]
+ model_size = math.ceil(os.stat(checkpoint_path).st_size / 1024)
+ nn_arch_size = min(nn_arch_sizes, key=lambda x: abs(x - model_size))
+
+ if nn_arch_size in vr_5_1_models or self.is_vr_51_model:
+ self.model_run = nets_new.CascadedNet(
+ self.model_params.param["bins"] * 2,
+ nn_arch_size,
+ nout=self.model_capacity[0],
+ nout_lstm=self.model_capacity[1],
+ )
+ self.is_vr_51_model = True
+ else:
+ self.model_run = nets.determine_model_capacity(
+ self.model_params.param["bins"] * 2, nn_arch_size
+ )
+
+ self.model_run.load_state_dict(
+ torch.load(checkpoint_path, map_location=device)
+ )
+ self.model_run.to(device)
+
+ def loading_mix(self, numpy_array, orig_sr=44100):
+ X_wave, X_spec_s = {}, {}
+
+ bands_n = len(self.model_params.param["band"])
+
+ audio_file = numpy_array
+
+ for d in tqdm(range(bands_n, 0, -1)):
+ bp = self.model_params.param["band"][d]
+
+ wav_resolution = bp["res_type"]
+
+ if self.torch_device_mps:
+ wav_resolution = "polyphase"
+
+ if d == bands_n: # high-end band
+ # librosa.resample returns a single array, so there is no tuple to unpack
+ X_wave[d] = librosa.resample(
+ y=numpy_array,
+ orig_sr=orig_sr,
+ target_sr=bp["sr"],
+ res_type=wav_resolution,
+ )
+ X_spec_s[d] = spec_utils.wave_to_spectrogram(
+ X_wave[d],
+ bp["hl"],
+ bp["n_fft"],
+ self.model_params,
+ band=d,
+ is_v51_model=self.is_vr_51_model,
+ )
+
+ if X_wave[d].ndim == 1:
+ X_wave[d] = np.asarray([X_wave[d], X_wave[d]])
+ else: # lower bands
+ X_wave[d] = librosa.resample(
+ X_wave[d + 1],
+ orig_sr=self.model_params.param["band"][d + 1]["sr"],
+ target_sr=bp["sr"],
+ res_type=wav_resolution,
+ )
+ X_spec_s[d] = spec_utils.wave_to_spectrogram(
+ X_wave[d],
+ bp["hl"],
+ bp["n_fft"],
+ self.model_params,
+ band=d,
+ is_v51_model=self.is_vr_51_model,
+ )
+
+ if d == bands_n and self.high_end_process:
+ self.input_high_end_h = (bp["n_fft"] // 2 - bp["crop_stop"]) + (
+ self.model_params.param["pre_filter_stop"]
+ - self.model_params.param["pre_filter_start"]
+ )
+ self.input_high_end = X_spec_s[d][
+ :, bp["n_fft"] // 2 - self.input_high_end_h : bp["n_fft"] // 2, :
+ ]
+
+ X_spec = spec_utils.combine_spectrograms(
+ X_spec_s, self.model_params, is_v51_model=self.is_vr_51_model
+ )
+
+ del X_wave, X_spec_s
+
+ return X_spec
+
+ def inference_vr(self, X_spec, device, aggressiveness):
+ def _execute(X_mag_pad, roi_size):
+ X_dataset = []
+ patches = (X_mag_pad.shape[2] - 2 * self.model_run.offset) // roi_size
+ total = patches
+ for i in tqdm(range(patches)):
+ processed = min(i + self.batch_size, patches)
+ sys.stdout.write(json.dumps({"processing": {"processed": processed, "total": total}}, ensure_ascii=False) + '\n')
+ start = i * roi_size
+ X_mag_window = X_mag_pad[:, :, start : start + self.window_size]
+ X_dataset.append(X_mag_window)
+
+ total_iterations = (
+ patches // self.batch_size
+ if not self.enable_tta
+ else (patches // self.batch_size) * 2
+ )
+
+ X_dataset = np.asarray(X_dataset)
+ self.model_run.eval()
+ with torch.no_grad():
+ mask = []
+
+ for i in tqdm(range(0, patches, self.batch_size)):
+ processed = min(i + self.batch_size, patches)
+ sys.stdout.write(json.dumps({"processing": {"processed": processed, "total": total}}, ensure_ascii=False) + '\n')
+ X_batch = X_dataset[i : i + self.batch_size]
+ X_batch = torch.from_numpy(X_batch).to(device)
+ pred = self.model_run.predict_mask(X_batch)
+ if not pred.size()[3] > 0:
+ raise ValueError(
+ f"Window size error: h1_shape[3] must be greater than h2_shape[3]"
+ )
+ pred = pred.detach().cpu().numpy()
+ pred = np.concatenate(pred, axis=2)
+ mask.append(pred)
+ if len(mask) == 0:
+ raise ValueError(
+ f"Window size error: h1_shape[3] must be greater than h2_shape[3]"
+ )
+
+ mask = np.concatenate(mask, axis=2)
+ return mask
+
+ def postprocess(mask, X_mag, X_phase):
+ is_non_accom_stem = False
+ for stem in NON_ACCOM_STEMS:
+ if stem == self.primary_stem:
+ is_non_accom_stem = True
+
+ mask = spec_utils.adjust_aggr(mask, is_non_accom_stem, aggressiveness)
+
+ if self.enable_post_process:
+ mask = spec_utils.merge_artifacts(
+ mask, thres=self.post_process_threshold
+ )
+
+ y_spec = mask * X_mag * np.exp(1.0j * X_phase)
+ v_spec = (1 - mask) * X_mag * np.exp(1.0j * X_phase)
+
+ return y_spec, v_spec
+
+ X_mag, X_phase = spec_utils.preprocess(X_spec)
+ n_frame = X_mag.shape[2]
+ pad_l, pad_r, roi_size = spec_utils.make_padding(
+ n_frame, self.window_size, self.model_run.offset
+ )
+ X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
+ X_mag_pad /= X_mag_pad.max()
+ mask = _execute(X_mag_pad, roi_size)
+
+ if self.enable_tta:
+ pad_l += roi_size // 2
+ pad_r += roi_size // 2
+ X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
+ X_mag_pad /= X_mag_pad.max()
+ mask_tta = _execute(X_mag_pad, roi_size)
+ mask_tta = mask_tta[:, :, roi_size // 2 :]
+ mask = (mask[:, :, :n_frame] + mask_tta[:, :, :n_frame]) * 0.5
+ else:
+ mask = mask[:, :, :n_frame]
+
+ y_spec, v_spec = postprocess(mask, X_mag, X_phase)
+
+ return y_spec, v_spec
+
+ def spec_to_wav(self, spec):
+ if (
+ self.high_end_process
+ and isinstance(self.input_high_end, np.ndarray)
+ and self.input_high_end_h
+ ):
+ input_high_end_ = spec_utils.mirroring(
+ "mirroring", spec, self.input_high_end, self.model_params
+ )
+ wav = spec_utils.cmb_spectrogram_to_wave(
+ spec,
+ self.model_params,
+ self.input_high_end_h,
+ input_high_end_,
+ is_v51_model=self.is_vr_51_model,
+ )
+ else:
+ wav = spec_utils.cmb_spectrogram_to_wave(
+ spec, self.model_params, is_v51_model=self.is_vr_51_model
+ )
+
+ return wav
+
+ def settings(self,
+ enable_post_process=False,
+ post_process_threshold=0.2,
+ batch_size=1,
+ window_size=512,
+ high_end_process=False,
+ primary_stem="Instrumental",
+ secondary_stem="Vocals"
+ ):
+ self.enable_post_process = enable_post_process
+ self.post_process_threshold = post_process_threshold
+ self.batch_size = batch_size
+ self.window_size = window_size
+ self.high_end_process = high_end_process
+ self.primary_stem = primary_stem
+ self.secondary_stem = secondary_stem
+
+ def demix(self, numpy_array, sr, device, aggression):
+ aggr = float(int(aggression) / 100)
+ aggressiveness = { # this belongs in demix
+ "value": aggr,
+ "split_bin": self.model_params.param["band"][1]["crop_stop"],
+ "aggr_correction": self.model_params.param.get("aggr_correction"),
+ }
+ y_spec, v_spec = self.inference_vr(
+ self.loading_mix(numpy_array, sr), device, aggressiveness
+ )
+ y_spec = np.nan_to_num(y_spec, nan=0.0, posinf=0.0, neginf=0.0)
+ v_spec = np.nan_to_num(v_spec, nan=0.0, posinf=0.0, neginf=0.0)
+ primary_stem_array = self.spec_to_wav(y_spec).T
+ primary_stem_array = librosa.resample(
+ primary_stem_array.T,
+ orig_sr=self.model_samplerate,
+ target_sr=44100,
+ ).T
+ secondary_stem_array = self.spec_to_wav(v_spec).T
+ secondary_stem_array = librosa.resample(
+ secondary_stem_array.T,
+ orig_sr=self.model_samplerate,
+ target_sr=44100,
+ ).T
+ return {self.primary_stem: primary_stem_array, self.secondary_stem: secondary_stem_array}
+
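+# Usage sketch (illustrative only): a minimal, hedged example of how the methods
+# above chain together, assuming `separator` is an instance of this class whose
+# model_params, model_run and model_samplerate were already populated by the loader,
+# and that `mix` is a (channels, samples) numpy array at 44.1 kHz:
+#
+#   separator.settings(batch_size=4, window_size=512,
+#                      primary_stem="Instrumental", secondary_stem="Vocals")
+#   stems = separator.demix(mix, 44100, device, aggression=5)
+#   instrumental, vocals = stems["Instrumental"], stems["Vocals"]
+#
+# `device` stands in for a torch.device; none of these names are defined in this file.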
diff --git a/mvsepless/models/vr_arch/layers.py b/mvsepless/models/vr_arch/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..fde999c98fc9832b7161a949f27ce4b268d06213
--- /dev/null
+++ b/mvsepless/models/vr_arch/layers.py
@@ -0,0 +1,329 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from . import spec_utils
+
+
+class Conv2DBNActiv(nn.Module):
+ """
+ This class implements a convolutional layer followed by batch normalization and an activation function.
+ It is a common pattern in deep learning for processing images or feature maps. The convolutional layer
+ applies a set of learnable filters to the input. Batch normalization then normalizes the output of the
+ convolution, and finally, an activation function introduces non-linearity to the model, allowing it to
+ learn more complex patterns.
+
+ Attributes:
+ conv (nn.Sequential): A sequential container of Conv2d, BatchNorm2d, and an activation layer.
+
+ Args:
+ nin (int): Number of input channels.
+ nout (int): Number of output channels.
+ ksize (int, optional): Size of the kernel. Defaults to 3.
+ stride (int, optional): Stride of the convolution. Defaults to 1.
+ pad (int, optional): Padding added to all sides of the input. Defaults to 1.
+ dilation (int, optional): Spacing between kernel elements. Defaults to 1.
+ activ (callable, optional): The activation function to use. Defaults to nn.ReLU.
+ """
+
+ def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
+ super(Conv2DBNActiv, self).__init__()
+
+ # The nn.Sequential container allows us to stack the Conv2d, BatchNorm2d, and activation layers
+ # into a single module, simplifying the forward pass.
+ self.conv = nn.Sequential(
+ nn.Conv2d(
+ nin,
+ nout,
+ kernel_size=ksize,
+ stride=stride,
+ padding=pad,
+ dilation=dilation,
+ bias=False,
+ ),
+ nn.BatchNorm2d(nout),
+ activ(),
+ )
+
+ def __call__(self, input_tensor):
+ # Defines the computation performed at every call.
+ # Simply passes the input through the sequential container.
+ return self.conv(input_tensor)
+
+
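+# Shape sketch (illustrative): with the defaults ksize=3, stride=1, pad=1, dilation=1
+# the spatial dimensions are preserved and only the channel count changes, e.g.
+#   block = Conv2DBNActiv(2, 16)
+#   block(torch.randn(1, 2, 64, 128)).shape   # -> torch.Size([1, 16, 64, 128])
+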
+class SeperableConv2DBNActiv(nn.Module):
+ """
+ This class implements a separable convolutional layer followed by batch normalization and an activation function.
+ Separable convolutions are a type of convolution that splits the convolution operation into two simpler operations:
+ a depthwise convolution and a pointwise convolution. This can reduce the number of parameters and computational cost,
+ making the network more efficient while maintaining similar performance.
+
+ The depthwise convolution applies a single filter per input channel (input depth). The pointwise convolution,
+ which follows, applies a 1x1 convolution to combine the outputs of the depthwise convolution across channels.
+ Batch normalization is then applied to stabilize learning and reduce internal covariate shift. Finally,
+ an activation function introduces non-linearity, allowing the network to learn complex patterns.
+ Attributes:
+ conv (nn.Sequential): A sequential container of depthwise Conv2d, pointwise Conv2d, BatchNorm2d, and an activation layer.
+
+ Args:
+ nin (int): Number of input channels.
+ nout (int): Number of output channels.
+ ksize (int, optional): Size of the kernel for the depthwise convolution. Defaults to 3.
+ stride (int, optional): Stride of the convolution. Defaults to 1.
+ pad (int, optional): Padding added to all sides of the input for the depthwise convolution. Defaults to 1.
+ dilation (int, optional): Spacing between kernel elements for the depthwise convolution. Defaults to 1.
+ activ (callable, optional): The activation function to use. Defaults to nn.ReLU.
+ """
+
+ def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
+ super(SeperableConv2DBNActiv, self).__init__()
+
+ # Initialize the sequential container with the depthwise convolution.
+ # The number of groups in the depthwise convolution is set to num_input_channels, which means each input channel is treated separately.
+ # The pointwise convolution then combines these separate channels into num_output_channels channels.
+ # Batch normalization is applied to the output of the pointwise convolution.
+ # Finally, the activation function is applied to introduce non-linearity.
+ self.conv = nn.Sequential(
+ nn.Conv2d(
+ nin,
+ nin, # For depthwise convolution, in_channels = out_channels = num_input_channels
+ kernel_size=ksize,
+ stride=stride,
+ padding=pad,
+ dilation=dilation,
+ groups=nin, # This makes it a depthwise convolution
+ bias=False, # Bias is not used because it will be handled by BatchNorm2d
+ ),
+ nn.Conv2d(
+ nin,
+ nout, # Pointwise convolution to combine channels
+ kernel_size=1, # Kernel size of 1 for pointwise convolution
+ bias=False, # Bias is not used because it will be handled by BatchNorm2d
+ ),
+ nn.BatchNorm2d(nout), # Normalize the output of the pointwise convolution
+ activ(), # Apply the activation function
+ )
+
+ def __call__(self, input_tensor):
+ # Pass the input through the sequential container.
+ # This performs the depthwise convolution, followed by the pointwise convolution,
+ # batch normalization, and finally applies the activation function.
+ return self.conv(input_tensor)
+
+
+class Encoder(nn.Module):
+ """
+ The Encoder class is a part of the neural network architecture that is responsible for processing the input data.
+ It consists of two convolutional layers, each followed by batch normalization and an activation function.
+ The purpose of the Encoder is to transform the input data into a higher-level, abstract representation.
+ This is achieved by applying filters (through convolutions) that can capture patterns or features in the data.
+ The Encoder can be thought of as a feature extractor that prepares the data for further processing by the network.
+ Attributes:
+ conv1 (Conv2DBNActiv): The first convolutional layer in the encoder.
+ conv2 (Conv2DBNActiv): The second convolutional layer in the encoder.
+
+ Args:
+ nin (int): Number of input channels for the first convolutional layer.
+ nout (int): Number of output channels for the convolutional layers.
+ ksize (int): Kernel size for the convolutional layers.
+ stride (int): Stride for the convolutional operations.
+ pad (int): Padding added to all sides of the input for the convolutional layers.
+ activ (callable): The activation function to use after each convolutional layer.
+ """
+
+ def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
+ super(Encoder, self).__init__()
+
+ # The first convolutional layer takes the input and applies a convolution,
+ # followed by batch normalization and an activation function specified by `activation_function`.
+ # This layer is responsible for capturing the initial set of features from the input data.
+ self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
+
+ # The second convolutional layer further processes the output from the first layer,
+ # applying another set of convolution, batch normalization, and activation.
+ # This layer helps in capturing more complex patterns in the data by building upon the initial features extracted by conv1.
+ self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ)
+
+ def __call__(self, input_tensor):
+ # The input data `input_tensor` is passed through the first convolutional layer.
+ # The output of this layer serves as a 'skip connection' that can be used later in the network to preserve spatial information.
+ skip = self.conv1(input_tensor)
+
+ # The output from the first layer is then passed through the second convolutional layer.
+ # This processed data `hidden` is the final output of the Encoder, representing the abstracted features of the input.
+ hidden = self.conv2(skip)
+
+ # The Encoder returns two outputs: `hidden`, the abstracted feature representation, and `skip`, the intermediate representation from conv1.
+ return hidden, skip
+
+
+class Decoder(nn.Module):
+ """
+ The Decoder class is part of the neural network architecture, specifically designed to perform the inverse operation of an encoder.
+ Its main role is to reconstruct or generate data from encoded representations, which is crucial in tasks like image segmentation or audio processing.
+ This class uses upsampling, convolution, optional dropout for regularization, and concatenation of skip connections to achieve its goal.
+
+ Attributes:
+ conv (Conv2DBNActiv): A convolutional layer with batch normalization and activation function.
+ dropout (nn.Dropout2d): An optional dropout layer for regularization to prevent overfitting.
+
+ Args:
+ nin (int): Number of input channels for the convolutional layer.
+ nout (int): Number of output channels for the convolutional layer.
+ ksize (int): Kernel size for the convolutional layer.
+ stride (int): Stride for the convolutional operations.
+ pad (int): Padding added to all sides of the input for the convolutional layer.
+ activ (callable): The activation function to use after the convolutional layer.
+ dropout (bool): Whether to include a dropout layer for regularization.
+ """
+
+ def __init__(
+ self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False
+ ):
+ super(Decoder, self).__init__()
+
+ # Initialize the convolutional layer with specified parameters.
+ self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
+
+ # Initialize the dropout layer if include_dropout is set to True
+ self.dropout = nn.Dropout2d(0.1) if dropout else None
+
+ def __call__(self, input_tensor, skip=None):
+ # Upsample the input tensor to a higher resolution using bilinear interpolation.
+ input_tensor = F.interpolate(
+ input_tensor, scale_factor=2, mode="bilinear", align_corners=True
+ )
+ # If a skip connection is provided, crop it to match the size of input_tensor and concatenate them along the channel dimension.
+ if skip is not None:
+ skip = spec_utils.crop_center(
+ skip, input_tensor
+ ) # Crop skip_connection to match input_tensor's dimensions.
+ input_tensor = torch.cat(
+ [input_tensor, skip], dim=1
+ ) # Concatenate input_tensor and skip_connection along the channel dimension.
+
+ # Pass the concatenated tensor (or just input_tensor if no skip_connection is provided) through the convolutional layer.
+ output_tensor = self.conv(input_tensor)
+
+ # If dropout is enabled, apply it to the output of the convolutional layer.
+ if self.dropout is not None:
+ output_tensor = self.dropout(output_tensor)
+
+ # Return the final output tensor.
+ return output_tensor
+
+
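+# Skip-connection sketch (illustrative): Encoder keeps its pre-stride activation as a
+# skip output, which the matching Decoder upsamples back to and concatenates with, e.g.
+#   enc = Encoder(2, 16, 3, 2, 1)
+#   dec = Decoder(16 + 16, 16, 3, 1, 1)
+#   h, skip = enc(torch.randn(1, 2, 64, 128))   # h: (1, 16, 32, 64), skip: (1, 16, 64, 128)
+#   dec(h, skip).shape                           # -> torch.Size([1, 16, 64, 128])
+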
+class ASPPModule(nn.Module):
+ """
+ Atrous Spatial Pyramid Pooling (ASPP) Module is designed for capturing multi-scale context by applying
+ atrous convolution at multiple rates. This is particularly useful in segmentation tasks where capturing
+ objects at various scales is beneficial. The module applies several parallel dilated convolutions with
+ different dilation rates to the input feature map, allowing it to efficiently capture information at
+ multiple scales.
+
+ Attributes:
+ conv1 (nn.Sequential): Applies adaptive average pooling followed by a 1x1 convolution.
+ nn_architecture (int): Identifier for the neural network architecture being used.
+ six_layer (list): List containing architecture identifiers that require six layers.
+ seven_layer (list): List containing architecture identifiers that require seven layers.
+ conv2-conv7 (nn.Module): Convolutional layers with varying dilation rates for multi-scale feature extraction.
+ bottleneck (nn.Sequential): A 1x1 convolutional layer that combines all features followed by dropout for regularization.
+ """
+
+ def __init__(self, nn_architecture, nin, nout, dilations=(4, 8, 16), activ=nn.ReLU):
+ """
+ Initializes the ASPP module with specified parameters.
+
+ Args:
+ nn_architecture (int): Identifier for the neural network architecture.
+ nin (int): Number of input channels.
+ nout (int): Number of output channels.
+ dilations (tuple): Tuple of dilation rates for the atrous convolutions.
+ activ (callable): Activation function to use after convolutional layers.
+ """
+ super(ASPPModule, self).__init__()
+
+ # Adaptive average pooling reduces the spatial dimensions to 1x1, focusing on global context,
+ # followed by a 1x1 convolution to project back to the desired channel dimension.
+ self.conv1 = nn.Sequential(
+ nn.AdaptiveAvgPool2d((1, None)),
+ Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ),
+ )
+
+ self.nn_architecture = nn_architecture
+ # Architecture identifiers for models requiring additional layers.
+ self.six_layer = [129605]
+ self.seven_layer = [537238, 537227, 33966]
+
+ # Extra convolutional layer used for six and seven layer configurations.
+ extra_conv = SeperableConv2DBNActiv(
+ nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
+ )
+
+ # Standard 1x1 convolution for channel reduction.
+ self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ)
+
+ # Separable convolutions with different dilation rates for multi-scale feature extraction.
+ self.conv3 = SeperableConv2DBNActiv(
+ nin, nin, 3, 1, dilations[0], dilations[0], activ=activ
+ )
+ self.conv4 = SeperableConv2DBNActiv(
+ nin, nin, 3, 1, dilations[1], dilations[1], activ=activ
+ )
+ self.conv5 = SeperableConv2DBNActiv(
+ nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
+ )
+
+ # Depending on the architecture, include the extra convolutional layers.
+ if self.nn_architecture in self.six_layer:
+ self.conv6 = extra_conv
+ nin_x = 6
+ elif self.nn_architecture in self.seven_layer:
+ self.conv6 = extra_conv
+ self.conv7 = extra_conv
+ nin_x = 7
+ else:
+ nin_x = 5
+
+ # Bottleneck layer combines all the multi-scale features into the desired number of output channels.
+ self.bottleneck = nn.Sequential(
+ Conv2DBNActiv(nin * nin_x, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1)
+ )
+
+ def forward(self, input_tensor):
+ """
+ Forward pass of the ASPP module.
+
+ Args:
+ input_tensor (Tensor): Input tensor.
+
+ Returns:
+ Tensor: Output tensor after applying ASPP.
+ """
+ _, _, h, w = input_tensor.size()
+
+ # Apply the first convolutional sequence and upsample to the original resolution.
+ feat1 = F.interpolate(
+ self.conv1(input_tensor), size=(h, w), mode="bilinear", align_corners=True
+ )
+
+ # Apply the remaining convolutions directly on the input.
+ feat2 = self.conv2(input_tensor)
+ feat3 = self.conv3(input_tensor)
+ feat4 = self.conv4(input_tensor)
+ feat5 = self.conv5(input_tensor)
+
+ # Concatenate features from all layers. Depending on the architecture, include the extra features.
+ if self.nn_architecture in self.six_layer:
+ feat6 = self.conv6(input_tensor)
+ out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6), dim=1)
+ elif self.nn_architecture in self.seven_layer:
+ feat6 = self.conv6(input_tensor)
+ feat7 = self.conv7(input_tensor)
+ out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6, feat7), dim=1)
+ else:
+ out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
+
+ # Apply the bottleneck layer to combine and reduce the channel dimensions.
+ bottleneck_output = self.bottleneck(out)
+ return bottleneck_output
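+# ASPP sketch (illustrative): every branch preserves the (h, w) spatial size -- the pooled
+# branch is re-interpolated and each dilated separable convolution pads by its dilation --
+# so only the channel count changes. For an architecture id outside the six/seven-layer
+# lists (123821 is used here purely as an assumed example value):
+#   aspp = ASPPModule(123821, 64, 128)
+#   aspp(torch.randn(1, 64, 32, 256)).shape   # -> torch.Size([1, 128, 32, 256])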
diff --git a/mvsepless/models/vr_arch/layers_new.py b/mvsepless/models/vr_arch/layers_new.py
new file mode 100644
index 0000000000000000000000000000000000000000..b245ee25234e4058a0989a1a864aa0d8539d7f63
--- /dev/null
+++ b/mvsepless/models/vr_arch/layers_new.py
@@ -0,0 +1,180 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from . import spec_utils
+
+
+class Conv2DBNActiv(nn.Module):
+ """
+ Conv2DBNActiv Class:
+ This class implements a convolutional layer followed by batch normalization and an activation function.
+ It is a fundamental building block for constructing neural networks, especially useful in image and audio processing tasks.
+ The class encapsulates the pattern of applying a convolution, normalizing the output, and then applying a non-linear activation.
+ """
+
+ def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
+ super(Conv2DBNActiv, self).__init__()
+
+ # Sequential model combining Conv2D, BatchNorm, and activation function into a single module
+ self.conv = nn.Sequential(
+ nn.Conv2d(
+ nin,
+ nout,
+ kernel_size=ksize,
+ stride=stride,
+ padding=pad,
+ dilation=dilation,
+ bias=False,
+ ),
+ nn.BatchNorm2d(nout),
+ activ(),
+ )
+
+ def __call__(self, input_tensor):
+ # Forward pass through the sequential model
+ return self.conv(input_tensor)
+
+
+class Encoder(nn.Module):
+ """
+ Encoder Class:
+ This class defines an encoder module typically used in autoencoder architectures.
+ It consists of two convolutional layers, each followed by batch normalization and an activation function.
+ """
+
+ def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
+ super(Encoder, self).__init__()
+
+ # First convolutional layer of the encoder
+ self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ)
+ # Second convolutional layer of the encoder
+ self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
+
+ def __call__(self, input_tensor):
+ # Applying the first and then the second convolutional layers
+ hidden = self.conv1(input_tensor)
+ hidden = self.conv2(hidden)
+
+ return hidden
+
+
+class Decoder(nn.Module):
+ """
+ Decoder Class:
+ This class defines a decoder module, which is the counterpart of the Encoder class in autoencoder architectures.
+ It applies a convolutional layer followed by batch normalization and an activation function, with an optional dropout layer for regularization.
+ """
+
+ def __init__(
+ self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False
+ ):
+ super(Decoder, self).__init__()
+ # Convolutional layer with optional dropout for regularization
+ self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
+ # self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
+ self.dropout = nn.Dropout2d(0.1) if dropout else None
+
+ def __call__(self, input_tensor, skip=None):
+ # Forward pass through the convolutional layer and optional dropout
+ input_tensor = F.interpolate(
+ input_tensor, scale_factor=2, mode="bilinear", align_corners=True
+ )
+
+ if skip is not None:
+ skip = spec_utils.crop_center(skip, input_tensor)
+ input_tensor = torch.cat([input_tensor, skip], dim=1)
+
+ hidden = self.conv1(input_tensor)
+ # hidden = self.conv2(hidden)
+
+ if self.dropout is not None:
+ hidden = self.dropout(hidden)
+
+ return hidden
+
+
+class ASPPModule(nn.Module):
+ """
+ ASPPModule Class:
+ This class implements the Atrous Spatial Pyramid Pooling (ASPP) module, which is useful for semantic image segmentation tasks.
+ It captures multi-scale contextual information by applying convolutions at multiple dilation rates.
+ """
+
+ def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False):
+ super(ASPPModule, self).__init__()
+
+ # Global context convolution captures the overall context
+ self.conv1 = nn.Sequential(
+ nn.AdaptiveAvgPool2d((1, None)),
+ Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ),
+ )
+ self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
+ self.conv3 = Conv2DBNActiv(
+ nin, nout, 3, 1, dilations[0], dilations[0], activ=activ
+ )
+ self.conv4 = Conv2DBNActiv(
+ nin, nout, 3, 1, dilations[1], dilations[1], activ=activ
+ )
+ self.conv5 = Conv2DBNActiv(
+ nin, nout, 3, 1, dilations[2], dilations[2], activ=activ
+ )
+ self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ)
+ self.dropout = nn.Dropout2d(0.1) if dropout else None
+
+ def forward(self, input_tensor):
+ _, _, h, w = input_tensor.size()
+
+ # Upsample global context to match input size and combine with local and multi-scale features
+ feat1 = F.interpolate(
+ self.conv1(input_tensor), size=(h, w), mode="bilinear", align_corners=True
+ )
+ feat2 = self.conv2(input_tensor)
+ feat3 = self.conv3(input_tensor)
+ feat4 = self.conv4(input_tensor)
+ feat5 = self.conv5(input_tensor)
+ out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
+ out = self.bottleneck(out)
+
+ if self.dropout is not None:
+ out = self.dropout(out)
+
+ return out
+
+
+class LSTMModule(nn.Module):
+ """
+ LSTMModule Class:
+ This class defines a module that combines convolutional feature extraction with a bidirectional LSTM for sequence modeling.
+ It is useful for tasks that require understanding temporal dynamics in data, such as speech and audio processing.
+ """
+
+ def __init__(self, nin_conv, nin_lstm, nout_lstm):
+ super(LSTMModule, self).__init__()
+ # Convolutional layer for initial feature extraction
+ self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0)
+
+ # Bidirectional LSTM for capturing temporal dynamics
+ self.lstm = nn.LSTM(
+ input_size=nin_lstm, hidden_size=nout_lstm // 2, bidirectional=True
+ )
+
+ # Dense layer for output dimensionality matching
+ self.dense = nn.Sequential(
+ nn.Linear(nout_lstm, nin_lstm), nn.BatchNorm1d(nin_lstm), nn.ReLU()
+ )
+
+ def forward(self, input_tensor):
+ N, _, nbins, nframes = input_tensor.size()
+
+ # Extract features and prepare for LSTM
+ hidden = self.conv(input_tensor)[:, 0] # N, nbins, nframes
+ hidden = hidden.permute(2, 0, 1) # nframes, N, nbins
+ hidden, _ = self.lstm(hidden)
+
+ # Apply dense layer and reshape to match expected output format
+ hidden = self.dense(hidden.reshape(-1, hidden.size()[-1])) # nframes * N, nbins
+ hidden = hidden.reshape(nframes, N, 1, nbins)
+ hidden = hidden.permute(1, 2, 3, 0)
+
+ return hidden
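+# LSTMModule sketch (illustrative): the channel axis is collapsed with a 1x1 convolution,
+# a bidirectional LSTM runs over the frame axis and a dense layer projects back, so a
+# (N, C, nbins, nframes) input becomes (N, 1, nbins, nframes); note that nin_lstm must
+# equal the number of frequency bins fed in, e.g.
+#   m = LSTMModule(nin_conv=16, nin_lstm=64, nout_lstm=128)
+#   m(torch.randn(2, 16, 64, 100)).shape   # -> torch.Size([2, 1, 64, 100])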
diff --git a/mvsepless/models/vr_arch/model_param_init.py b/mvsepless/models/vr_arch/model_param_init.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cfde44a8c958e989854e71613bfecc8c139bdfd
--- /dev/null
+++ b/mvsepless/models/vr_arch/model_param_init.py
@@ -0,0 +1,76 @@
+import json
+
+default_param = {}
+default_param["bins"] = -1
+default_param["unstable_bins"] = -1 # training only
+default_param["stable_bins"] = -1 # training only
+default_param["sr"] = 44100
+default_param["pre_filter_start"] = -1
+default_param["pre_filter_stop"] = -1
+default_param["band"] = {}
+
+N_BINS = "n_bins"
+
+
+def int_keys(d):
+ """
+ Converts string keys that represent integers into actual integer keys.
+
+ This function is particularly useful when dealing with JSON data that may represent
+ integer keys as strings due to the nature of JSON encoding. By converting these keys
+ back to integers, it ensures that the data can be used in a manner consistent with
+ its original representation, especially in contexts where the distinction between
+ string and integer keys is important.
+
+ Args:
+ d (list of tuples): A list of (key, value) pairs where keys are strings
+ that may represent integers.
+
+ Returns:
+ dict: A dictionary with keys converted to integers where applicable.
+ """
+ # Initialize an empty dictionary to hold the converted key-value pairs.
+ result_dict = {}
+ # Iterate through each key-value pair in the input list.
+ for key, value in d:
+ # Check if the key is a digit (i.e., represents an integer).
+ if key.isdigit():
+ # Convert the key from a string to an integer.
+ key = int(key)
+ result_dict[key] = value
+ return result_dict
+
+
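+# Usage sketch (illustrative): int_keys is meant to be passed as json's object_pairs_hook
+# so that band indices serialized as strings come back as integer keys; the values below
+# are made-up example parameters, not a real model file:
+#   raw = '{"bins": 672, "band": {"1": {"sr": 7350}}}'
+#   params = json.loads(raw, object_pairs_hook=int_keys)
+#   params["band"][1]["sr"]   # -> 7350
+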
+class ModelParameters(object):
+ """
+ A class to manage model parameters supplied as a dictionary of band/bin settings.
+
+ Attributes:
+ param (dict): Dictionary holding all parameters for the model.
+ """
+
+ def __init__(self, model_params=None):
+ """
+ Initializes the ModelParameters object from a dictionary of model parameters.
+
+ Args:
+ model_params (dict): Dictionary holding the VR model parameters (bands, bins, filters, etc.).
+ """
+
+ self.param = model_params
+
+ # Ensure certain parameters are set to False if not specified in the configuration.
+ for k in [
+ "mid_side",
+ "mid_side_b",
+ "mid_side_b2",
+ "stereo_w",
+ "stereo_n",
+ "reverse",
+ ]:
+ if k not in self.param:
+ self.param[k] = False
+
+ # If 'n_bins' is specified in the parameters, it's used as the value for 'bins'.
+ if N_BINS in self.param:
+ self.param["bins"] = self.param[N_BINS]
diff --git a/mvsepless/models/vr_arch/nets.py b/mvsepless/models/vr_arch/nets.py
new file mode 100644
index 0000000000000000000000000000000000000000..53fbf64c5d2b8dcae60ac75017a4144030c0ec81
--- /dev/null
+++ b/mvsepless/models/vr_arch/nets.py
@@ -0,0 +1,223 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from . import layers
+
+
+class BaseASPPNet(nn.Module):
+ """
+ BaseASPPNet Class:
+ This class defines the base architecture for an Atrous Spatial Pyramid Pooling (ASPP) network.
+ It is designed to extract features from input data at multiple scales by using dilated convolutions.
+ This is particularly useful for tasks that benefit from understanding context at different resolutions,
+ such as semantic segmentation. The network consists of a series of encoder layers for downsampling and feature extraction,
+ followed by an ASPP module for multi-scale feature extraction, and finally a series of decoder layers for upsampling.
+ """
+
+ def __init__(self, nn_architecture, nin, ch, dilations=(4, 8, 16)):
+ super(BaseASPPNet, self).__init__()
+ self.nn_architecture = nn_architecture
+
+ # Encoder layers progressively increase the number of channels while reducing spatial dimensions.
+ self.enc1 = layers.Encoder(nin, ch, 3, 2, 1)
+ self.enc2 = layers.Encoder(ch, ch * 2, 3, 2, 1)
+ self.enc3 = layers.Encoder(ch * 2, ch * 4, 3, 2, 1)
+ self.enc4 = layers.Encoder(ch * 4, ch * 8, 3, 2, 1)
+
+ # Depending on the network architecture, an additional encoder layer and a specific ASPP module are initialized.
+ if self.nn_architecture == 129605:
+ self.enc5 = layers.Encoder(ch * 8, ch * 16, 3, 2, 1)
+ self.aspp = layers.ASPPModule(nn_architecture, ch * 16, ch * 32, dilations)
+ self.dec5 = layers.Decoder(ch * (16 + 32), ch * 16, 3, 1, 1)
+ else:
+ self.aspp = layers.ASPPModule(nn_architecture, ch * 8, ch * 16, dilations)
+
+ # Decoder layers progressively decrease the number of channels while increasing spatial dimensions.
+ self.dec4 = layers.Decoder(ch * (8 + 16), ch * 8, 3, 1, 1)
+ self.dec3 = layers.Decoder(ch * (4 + 8), ch * 4, 3, 1, 1)
+ self.dec2 = layers.Decoder(ch * (2 + 4), ch * 2, 3, 1, 1)
+ self.dec1 = layers.Decoder(ch * (1 + 2), ch, 3, 1, 1)
+
+ def __call__(self, input_tensor):
+ # The input tensor is passed through a series of encoder layers.
+ hidden_state, encoder_output1 = self.enc1(input_tensor)
+ hidden_state, encoder_output2 = self.enc2(hidden_state)
+ hidden_state, encoder_output3 = self.enc3(hidden_state)
+ hidden_state, encoder_output4 = self.enc4(hidden_state)
+
+ # Depending on the network architecture, the hidden state is processed by an additional encoder layer and the ASPP module.
+ if self.nn_architecture == 129605:
+ hidden_state, encoder_output5 = self.enc5(hidden_state)
+ hidden_state = self.aspp(hidden_state)
+ # The decoder layers use skip connections from the encoder layers for better feature integration.
+ hidden_state = self.dec5(hidden_state, encoder_output5)
+ else:
+ hidden_state = self.aspp(hidden_state)
+
+ # The hidden state is further processed by the decoder layers, using skip connections for feature integration.
+ hidden_state = self.dec4(hidden_state, encoder_output4)
+ hidden_state = self.dec3(hidden_state, encoder_output3)
+ hidden_state = self.dec2(hidden_state, encoder_output2)
+ hidden_state = self.dec1(hidden_state, encoder_output1)
+
+ return hidden_state
+
+
+def determine_model_capacity(n_fft_bins, nn_architecture):
+ """
+ The determine_model_capacity function is designed to select the appropriate model configuration
+ based on the frequency bins and network architecture. It maps specific architectures to predefined
+ model capacities, which dictate the structure and parameters of the CascadedASPPNet model.
+ """
+
+ # Predefined model architectures categorized by their precision level.
+ sp_model_arch = [31191, 33966, 129605]
+ hp_model_arch = [123821, 123812]
+ hp2_model_arch = [537238, 537227]
+
+ # Mapping network architectures to their corresponding model capacity data.
+ if nn_architecture in sp_model_arch:
+ model_capacity_data = [
+ (2, 16),
+ (2, 16),
+ (18, 8, 1, 1, 0),
+ (8, 16),
+ (34, 16, 1, 1, 0),
+ (16, 32),
+ (32, 2, 1),
+ (16, 2, 1),
+ (16, 2, 1),
+ ]
+
+ if nn_architecture in hp_model_arch:
+ model_capacity_data = [
+ (2, 32),
+ (2, 32),
+ (34, 16, 1, 1, 0),
+ (16, 32),
+ (66, 32, 1, 1, 0),
+ (32, 64),
+ (64, 2, 1),
+ (32, 2, 1),
+ (32, 2, 1),
+ ]
+
+ if nn_architecture in hp2_model_arch:
+ model_capacity_data = [
+ (2, 64),
+ (2, 64),
+ (66, 32, 1, 1, 0),
+ (32, 64),
+ (130, 64, 1, 1, 0),
+ (64, 128),
+ (128, 2, 1),
+ (64, 2, 1),
+ (64, 2, 1),
+ ]
+
+ # Initializing the CascadedASPPNet model with the selected model capacity data.
+ cascaded = CascadedASPPNet
+ model = cascaded(n_fft_bins, model_capacity_data, nn_architecture)
+
+ return model
+
+
+class CascadedASPPNet(nn.Module):
+ """
+ CascadedASPPNet Class:
+ This class implements a cascaded version of the ASPP network, designed for processing audio signals
+ for tasks such as vocal removal. It consists of multiple stages, each with its own ASPP network,
+ to process different frequency bands of the input signal. This allows the model to effectively
+ handle the full spectrum of audio frequencies by focusing on different frequency bands separately.
+ """
+
+ def __init__(self, n_fft, model_capacity_data, nn_architecture):
+ super(CascadedASPPNet, self).__init__()
+ # The first stage processes the low and high frequency bands separately.
+ self.stg1_low_band_net = BaseASPPNet(nn_architecture, *model_capacity_data[0])
+ self.stg1_high_band_net = BaseASPPNet(nn_architecture, *model_capacity_data[1])
+
+ # Bridge layers connect different stages of the network.
+ self.stg2_bridge = layers.Conv2DBNActiv(*model_capacity_data[2])
+ self.stg2_full_band_net = BaseASPPNet(nn_architecture, *model_capacity_data[3])
+
+ self.stg3_bridge = layers.Conv2DBNActiv(*model_capacity_data[4])
+ self.stg3_full_band_net = BaseASPPNet(nn_architecture, *model_capacity_data[5])
+
+ # Output layers for the final mask prediction and auxiliary outputs.
+ self.out = nn.Conv2d(*model_capacity_data[6], bias=False)
+ self.aux1_out = nn.Conv2d(*model_capacity_data[7], bias=False)
+ self.aux2_out = nn.Conv2d(*model_capacity_data[8], bias=False)
+
+ # Parameters for handling the frequency bins of the input signal.
+ self.max_bin = n_fft // 2
+ self.output_bin = n_fft // 2 + 1
+
+ self.offset = 128
+
+ def forward(self, input_tensor):
+ # The forward pass processes the input tensor through each stage of the network,
+ # combining the outputs of different frequency bands and stages to produce the final mask.
+ mix = input_tensor.detach()
+ input_tensor = input_tensor.clone()
+
+ # Preparing the input tensor by keeping only the frequency bins up to the maximum bin.
+ input_tensor = input_tensor[:, :, : self.max_bin]
+
+ # Processing the low and high frequency bands separately in the first stage.
+ bandwidth = input_tensor.size()[2] // 2
+ aux1 = torch.cat(
+ [
+ self.stg1_low_band_net(input_tensor[:, :, :bandwidth]),
+ self.stg1_high_band_net(input_tensor[:, :, bandwidth:]),
+ ],
+ dim=2,
+ )
+
+ # Combining the outputs of the first stage and passing through the second stage.
+ hidden_state = torch.cat([input_tensor, aux1], dim=1)
+ aux2 = self.stg2_full_band_net(self.stg2_bridge(hidden_state))
+
+ # Further processing the combined outputs through the third stage.
+ hidden_state = torch.cat([input_tensor, aux1, aux2], dim=1)
+ hidden_state = self.stg3_full_band_net(self.stg3_bridge(hidden_state))
+
+ # Applying the final output layer to produce the mask.
+ mask = torch.sigmoid(self.out(hidden_state))
+
+ # Padding the mask to match the output frequency bin size.
+ mask = F.pad(
+ input=mask,
+ pad=(0, 0, 0, self.output_bin - mask.size()[2]),
+ mode="replicate",
+ )
+
+ # During training, auxiliary outputs are also produced and padded accordingly.
+ if self.training:
+ aux1 = torch.sigmoid(self.aux1_out(aux1))
+ aux1 = F.pad(
+ input=aux1,
+ pad=(0, 0, 0, self.output_bin - aux1.size()[2]),
+ mode="replicate",
+ )
+ aux2 = torch.sigmoid(self.aux2_out(aux2))
+ aux2 = F.pad(
+ input=aux2,
+ pad=(0, 0, 0, self.output_bin - aux2.size()[2]),
+ mode="replicate",
+ )
+ return mask * mix, aux1 * mix, aux2 * mix
+ else:
+ return mask # * mix
+
+ def predict_mask(self, input_tensor):
+ # This method predicts the mask for the input tensor by calling the forward method
+ # and applying any necessary padding adjustments.
+ mask = self.forward(input_tensor)
+
+ # Adjusting the mask by removing padding offsets if present.
+ if self.offset > 0:
+ mask = mask[:, :, :, self.offset : -self.offset]
+
+ return mask
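+# Usage sketch (illustrative): determine_model_capacity picks the layer widths from the
+# architecture id and hands them to CascadedASPPNet, whose forward pass expects magnitude
+# spectrograms shaped (batch, 2, n_fft // 2 + 1, frames). The id 123821 below is simply
+# one of the predefined hp_model_arch values, chosen for illustration:
+#   model = determine_model_capacity(1024, 123821).eval()
+#   with torch.no_grad():
+#       mask = model.predict_mask(torch.rand(1, 2, 513, 512))
+#   mask.shape   # -> torch.Size([1, 2, 513, 512 - 2 * model.offset])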
diff --git a/mvsepless/models/vr_arch/nets_new.py b/mvsepless/models/vr_arch/nets_new.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f20cebac69542d085211d4ca02b6822a08ff34e
--- /dev/null
+++ b/mvsepless/models/vr_arch/nets_new.py
@@ -0,0 +1,182 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from . import layers_new as layers
+
+
+class BaseNet(nn.Module):
+ """
+ BaseNet Class:
+ This class defines the base network architecture for vocal removal. It includes a series of encoders for feature extraction,
+ an ASPP module for capturing multi-scale context, and a series of decoders for reconstructing the output. Additionally,
+ it incorporates an LSTM module for capturing temporal dependencies.
+ """
+
+ def __init__(
+ self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6))
+ ):
+ super(BaseNet, self).__init__()
+ # Initialize the encoder layers with increasing output channels for hierarchical feature extraction.
+ self.enc1 = layers.Conv2DBNActiv(nin, nout, 3, 1, 1)
+ self.enc2 = layers.Encoder(nout, nout * 2, 3, 2, 1)
+ self.enc3 = layers.Encoder(nout * 2, nout * 4, 3, 2, 1)
+ self.enc4 = layers.Encoder(nout * 4, nout * 6, 3, 2, 1)
+ self.enc5 = layers.Encoder(nout * 6, nout * 8, 3, 2, 1)
+
+ # ASPP module for capturing multi-scale features with different dilation rates.
+ self.aspp = layers.ASPPModule(nout * 8, nout * 8, dilations, dropout=True)
+
+ # Decoder layers for upscaling and merging features from different levels of the encoder and ASPP module.
+ self.dec4 = layers.Decoder(nout * (6 + 8), nout * 6, 3, 1, 1)
+ self.dec3 = layers.Decoder(nout * (4 + 6), nout * 4, 3, 1, 1)
+ self.dec2 = layers.Decoder(nout * (2 + 4), nout * 2, 3, 1, 1)
+
+ # LSTM module for capturing temporal dependencies in the sequence of features.
+ self.lstm_dec2 = layers.LSTMModule(nout * 2, nin_lstm, nout_lstm)
+ self.dec1 = layers.Decoder(nout * (1 + 2) + 1, nout * 1, 3, 1, 1)
+
+ def __call__(self, input_tensor):
+ # Sequentially pass the input through the encoder layers.
+ encoded1 = self.enc1(input_tensor)
+ encoded2 = self.enc2(encoded1)
+ encoded3 = self.enc3(encoded2)
+ encoded4 = self.enc4(encoded3)
+ encoded5 = self.enc5(encoded4)
+
+ # Pass the deepest encoder output through the ASPP module.
+ bottleneck = self.aspp(encoded5)
+
+ # Sequentially upscale and merge the features using the decoder layers.
+ bottleneck = self.dec4(bottleneck, encoded4)
+ bottleneck = self.dec3(bottleneck, encoded3)
+ bottleneck = self.dec2(bottleneck, encoded2)
+ # Concatenate the LSTM module output for temporal feature enhancement.
+ bottleneck = torch.cat([bottleneck, self.lstm_dec2(bottleneck)], dim=1)
+ bottleneck = self.dec1(bottleneck, encoded1)
+
+ return bottleneck
+
+
+class CascadedNet(nn.Module):
+ """
+ CascadedNet Class:
+ This class defines a cascaded network architecture that processes input in multiple stages, each stage focusing on different frequency bands.
+ It utilizes the BaseNet for processing, and combines outputs from different stages to produce the final mask for vocal removal.
+ """
+
+ def __init__(self, n_fft, nn_arch_size=51000, nout=32, nout_lstm=128):
+ super(CascadedNet, self).__init__()
+ # Calculate frequency bins based on FFT size.
+ self.max_bin = n_fft // 2
+ self.output_bin = n_fft // 2 + 1
+ self.nin_lstm = self.max_bin // 2
+ self.offset = 64
+ # Adjust output channels based on the architecture size.
+ nout = 64 if nn_arch_size == 218409 else nout
+
+ # print(nout, nout_lstm, n_fft)
+
+ # Initialize the network stages, each focusing on different frequency bands and progressively refining the output.
+ self.stg1_low_band_net = nn.Sequential(
+ BaseNet(2, nout // 2, self.nin_lstm // 2, nout_lstm),
+ layers.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0),
+ )
+ self.stg1_high_band_net = BaseNet(
+ 2, nout // 4, self.nin_lstm // 2, nout_lstm // 2
+ )
+
+ self.stg2_low_band_net = nn.Sequential(
+ BaseNet(nout // 4 + 2, nout, self.nin_lstm // 2, nout_lstm),
+ layers.Conv2DBNActiv(nout, nout // 2, 1, 1, 0),
+ )
+ self.stg2_high_band_net = BaseNet(
+ nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2
+ )
+
+ self.stg3_full_band_net = BaseNet(
+ 3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm
+ )
+
+ # Output layer for generating the final mask.
+ self.out = nn.Conv2d(nout, 2, 1, bias=False)
+ # Auxiliary output layer for intermediate supervision during training.
+ self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False)
+
+ def forward(self, input_tensor):
+ # Preprocess input tensor to match the maximum frequency bin.
+ input_tensor = input_tensor[:, :, : self.max_bin]
+
+ # Split the input into low and high frequency bands.
+ bandw = input_tensor.size()[2] // 2
+ l1_in = input_tensor[:, :, :bandw]
+ h1_in = input_tensor[:, :, bandw:]
+
+ # Process each band through the first stage networks.
+ l1 = self.stg1_low_band_net(l1_in)
+ h1 = self.stg1_high_band_net(h1_in)
+
+ # Combine the outputs for auxiliary supervision.
+ aux1 = torch.cat([l1, h1], dim=2)
+
+ # Prepare inputs for the second stage by concatenating the original and processed bands.
+ l2_in = torch.cat([l1_in, l1], dim=1)
+ h2_in = torch.cat([h1_in, h1], dim=1)
+
+ # Process through the second stage networks.
+ l2 = self.stg2_low_band_net(l2_in)
+ h2 = self.stg2_high_band_net(h2_in)
+
+ # Combine the outputs for auxiliary supervision.
+ aux2 = torch.cat([l2, h2], dim=2)
+
+ # Prepare input for the third stage by concatenating all previous outputs with the original input.
+ f3_in = torch.cat([input_tensor, aux1, aux2], dim=1)
+
+ # Process through the third stage network.
+ f3 = self.stg3_full_band_net(f3_in)
+
+ # Apply the output layer to generate the final mask and apply sigmoid for normalization.
+ mask = torch.sigmoid(self.out(f3))
+
+ # Pad the mask to match the output frequency bin size.
+ mask = F.pad(
+ input=mask,
+ pad=(0, 0, 0, self.output_bin - mask.size()[2]),
+ mode="replicate",
+ )
+
+ # During training, generate and pad the auxiliary output for additional supervision.
+ if self.training:
+ aux = torch.cat([aux1, aux2], dim=1)
+ aux = torch.sigmoid(self.aux_out(aux))
+ aux = F.pad(
+ input=aux,
+ pad=(0, 0, 0, self.output_bin - aux.size()[2]),
+ mode="replicate",
+ )
+ return mask, aux
+ else:
+ return mask
+
+ # Method for predicting the mask given an input tensor.
+ def predict_mask(self, input_tensor):
+ mask = self.forward(input_tensor)
+
+ # If an offset is specified, crop the mask to remove edge artifacts.
+ if self.offset > 0:
+ mask = mask[:, :, :, self.offset : -self.offset]
+ assert mask.size()[3] > 0
+
+ return mask
+
+ # Method for applying the predicted mask to the input tensor to obtain the predicted magnitude.
+ def predict(self, input_tensor):
+ mask = self.forward(input_tensor)
+ pred_mag = input_tensor * mask
+
+ # If an offset is specified, crop the predicted magnitude to remove edge artifacts.
+ if self.offset > 0:
+ pred_mag = pred_mag[:, :, :, self.offset : -self.offset]
+ assert pred_mag.size()[3] > 0
+
+ return pred_mag
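+# Usage sketch (illustrative): CascadedNet consumes magnitude spectrograms shaped
+# (batch, 2, n_fft // 2 + 1, frames) and predict_mask() returns a sigmoid mask with
+# `offset` frames trimmed from each end, e.g. with the default nn_arch_size:
+#   net = CascadedNet(n_fft=2048).eval()
+#   with torch.no_grad():
+#       mask = net.predict_mask(torch.rand(1, 2, 1025, 512))
+#   mask.shape   # -> torch.Size([1, 2, 1025, 512 - 2 * net.offset])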
diff --git a/mvsepless/models/vr_arch/spec_utils.py b/mvsepless/models/vr_arch/spec_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cdf421b38fb437cc412780987c19c257b71e5ce
--- /dev/null
+++ b/mvsepless/models/vr_arch/spec_utils.py
@@ -0,0 +1,1494 @@
+import audioread
+import librosa
+import numpy as np
+import soundfile as sf
+import math
+import platform
+import traceback
+from scipy.signal import correlate, hilbert
+import io
+
+OPERATING_SYSTEM = platform.system()
+SYSTEM_ARCH = platform.platform()
+SYSTEM_PROC = platform.processor()
+ARM = "arm"
+
+AUTO_PHASE = "Automatic"
+POSITIVE_PHASE = "Positive Phase"
+NEGATIVE_PHASE = "Negative Phase"
+NONE_P = ("None",)
+LOW_P = ("Shifts: Low",)
+MED_P = ("Shifts: Medium",)
+HIGH_P = ("Shifts: High",)
+VHIGH_P = "Shifts: Very High"
+MAXIMUM_P = "Shifts: Maximum"
+
+progress_value = 0
+last_update_time = 0
+is_macos = False
+
+
+if OPERATING_SYSTEM == "Darwin":
+ wav_resolution = (
+ "polyphase" if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else "sinc_fastest"
+ )
+ wav_resolution_float_resampling = (
+ "kaiser_best" if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else wav_resolution
+ )
+ is_macos = True
+else:
+ wav_resolution = "sinc_fastest"
+ wav_resolution_float_resampling = wav_resolution
+
+MAX_SPEC = "Max Spec"
+MIN_SPEC = "Min Spec"
+LIN_ENSE = "Linear Ensemble"
+
+MAX_WAV = MAX_SPEC
+MIN_WAV = MIN_SPEC
+
+AVERAGE = "Average"
+
+
+def crop_center(h1, h2):
+ """
+ This function crops the center of the first input tensor to match the size of the second input tensor.
+ It is used to ensure that the two tensors have the same size in the time dimension.
+ """
+ h1_shape = h1.size()
+ h2_shape = h2.size()
+
+ # If the time dimensions are already equal, return the first tensor as is
+ if h1_shape[3] == h2_shape[3]:
+ return h1
+ # If the time dimension of the first tensor is smaller, raise an error
+ elif h1_shape[3] < h2_shape[3]:
+ raise ValueError("h1_shape[3] must be greater than h2_shape[3]")
+
+ # Calculate the start and end indices for cropping
+ s_time = (h1_shape[3] - h2_shape[3]) // 2
+ e_time = s_time + h2_shape[3]
+ # Crop the first tensor
+ h1 = h1[:, :, :, s_time:e_time]
+
+ return h1
+
+
+def preprocess(X_spec):
+ """
+ This function preprocesses a spectrogram by separating it into magnitude and phase components.
+ This is a common preprocessing step in audio processing tasks.
+ """
+ X_mag = np.abs(X_spec)
+ X_phase = np.angle(X_spec)
+
+ return X_mag, X_phase
+
+
+def make_padding(width, cropsize, offset):
+ """
+ This function calculates the padding needed to make the width of an image divisible by the crop size.
+ It is used in the process of splitting an image into smaller patches.
+ """
+ left = offset
+ roi_size = cropsize - offset * 2
+ if roi_size == 0:
+ roi_size = cropsize
+ right = roi_size - (width % roi_size) + left
+
+ return left, right, roi_size
+
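+# Worked example (illustrative): with width=1000, cropsize=512 and offset=128 the region
+# of interest is roi_size = 512 - 2 * 128 = 256, left = 128 and
+# right = 256 - (1000 % 256) + 128 = 152, so left + width + right = 1280 = 5 * 256 and the
+# padded signal splits into an integer number of roi_size patches.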
+
+def normalize(wave, max_peak=1.0, min_peak=None):
+ """Normalize (or amplify) audio waveform to a specified peak value.
+
+ Args:
+ wave (array-like): Audio waveform.
+ max_peak (float): Maximum peak value for normalization.
+
+ Returns:
+ array-like: Normalized or original waveform.
+ """
+ maxv = np.abs(wave).max()
+ if maxv > max_peak:
+ wave *= max_peak / maxv
+ elif min_peak is not None and maxv < min_peak:
+ wave *= min_peak / maxv
+
+ return wave
+
+
+def auto_transpose(audio_array: np.ndarray):
+ """
+ Ensure that the audio array is in the (channels, samples) format.
+
+ Parameters:
+ audio_array (ndarray): Input audio array.
+
+ Returns:
+ ndarray: Transposed audio array if necessary.
+ """
+
+ # If the second dimension is 2 (indicating stereo channels), transpose the array
+ if audio_array.shape[1] == 2:
+ return audio_array.T
+ return audio_array
+
+
+def write_array_to_mem(audio_data, subtype):
+ if isinstance(audio_data, np.ndarray):
+ audio_buffer = io.BytesIO()
+ sf.write(audio_buffer, audio_data, 44100, subtype=subtype, format="WAV")
+ audio_buffer.seek(0)
+ return audio_buffer
+ else:
+ return audio_data
+
+
+def spectrogram_to_image(spec, mode="magnitude"):
+ if mode == "magnitude":
+ if np.iscomplexobj(spec):
+ y = np.abs(spec)
+ else:
+ y = spec
+ y = np.log10(y**2 + 1e-8)
+ elif mode == "phase":
+ if np.iscomplexobj(spec):
+ y = np.angle(spec)
+ else:
+ y = spec
+
+ y -= y.min()
+ y *= 255 / y.max()
+ img = np.uint8(y)
+
+ if y.ndim == 3:
+ img = img.transpose(1, 2, 0)
+ img = np.concatenate([np.max(img, axis=2, keepdims=True), img], axis=2)
+
+ return img
+
+
+def reduce_vocal_aggressively(X, y, softmask):
+ v = X - y
+ y_mag_tmp = np.abs(y)
+ v_mag_tmp = np.abs(v)
+
+ v_mask = v_mag_tmp > y_mag_tmp
+ y_mag = np.clip(y_mag_tmp - v_mag_tmp * v_mask * softmask, 0, np.inf)
+
+ return y_mag * np.exp(1.0j * np.angle(y))
+
+
+def merge_artifacts(y_mask, thres=0.01, min_range=64, fade_size=32):
+ mask = y_mask
+
+ try:
+ if min_range < fade_size * 2:
+ raise ValueError("min_range must be >= fade_size * 2")
+
+ idx = np.where(y_mask.min(axis=(0, 1)) > thres)[0]
+ start_idx = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0])
+ end_idx = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1])
+ artifact_idx = np.where(end_idx - start_idx > min_range)[0]
+ weight = np.zeros_like(y_mask)
+ if len(artifact_idx) > 0:
+ start_idx = start_idx[artifact_idx]
+ end_idx = end_idx[artifact_idx]
+ old_e = None
+ for s, e in zip(start_idx, end_idx):
+ if old_e is not None and s - old_e < fade_size:
+ s = old_e - fade_size * 2
+
+ if s != 0:
+ weight[:, :, s : s + fade_size] = np.linspace(0, 1, fade_size)
+ else:
+ s -= fade_size
+
+ if e != y_mask.shape[2]:
+ weight[:, :, e - fade_size : e] = np.linspace(1, 0, fade_size)
+ else:
+ e += fade_size
+
+ weight[:, :, s + fade_size : e - fade_size] = 1
+ old_e = e
+
+ v_mask = 1 - y_mask
+ y_mask += weight * v_mask
+
+ mask = y_mask
+ except Exception as e:
+ error_name = f"{type(e).__name__}"
+ traceback_text = "".join(traceback.format_tb(e.__traceback__))
+ message = f'{error_name}: "{e}"\n{traceback_text}'
+ print("Post Process Failed: ", message)
+
+ return mask
+
+
+def align_wave_head_and_tail(a, b):
+ l = min([a[0].size, b[0].size])
+
+ return a[:l, :l], b[:l, :l]
+
+
+def convert_channels(spec, mp, band):
+ cc = mp.param["band"][band].get("convert_channels")
+
+ if "mid_side_c" == cc:
+ spec_left = np.add(spec[0], spec[1] * 0.25)
+ spec_right = np.subtract(spec[1], spec[0] * 0.25)
+ elif "mid_side" == cc:
+ spec_left = np.add(spec[0], spec[1]) / 2
+ spec_right = np.subtract(spec[0], spec[1])
+ elif "stereo_n" == cc:
+ spec_left = np.add(spec[0], spec[1] * 0.25) / 0.9375
+ spec_right = np.add(spec[1], spec[0] * 0.25) / 0.9375
+ else:
+ return spec
+
+ return np.asfortranarray([spec_left, spec_right])
+
+
+def combine_spectrograms(specs, mp, is_v51_model=False):
+ l = min([specs[i].shape[2] for i in specs])
+ spec_c = np.zeros(shape=(2, mp.param["bins"] + 1, l), dtype=np.complex64)
+ offset = 0
+ bands_n = len(mp.param["band"])
+
+ for d in range(1, bands_n + 1):
+ h = mp.param["band"][d]["crop_stop"] - mp.param["band"][d]["crop_start"]
+ spec_c[:, offset : offset + h, :l] = specs[d][
+ :, mp.param["band"][d]["crop_start"] : mp.param["band"][d]["crop_stop"], :l
+ ]
+ offset += h
+
+ if offset > mp.param["bins"]:
+ raise ValueError("Too much bins")
+
+ # lowpass filter
+
+ if mp.param["pre_filter_start"] > 0:
+ if is_v51_model:
+ spec_c *= get_lp_filter_mask(
+ spec_c.shape[1],
+ mp.param["pre_filter_start"],
+ mp.param["pre_filter_stop"],
+ )
+ else:
+ if bands_n == 1:
+ spec_c = fft_lp_filter(
+ spec_c, mp.param["pre_filter_start"], mp.param["pre_filter_stop"]
+ )
+ else:
+ gp = 1
+ for b in range(
+ mp.param["pre_filter_start"] + 1, mp.param["pre_filter_stop"]
+ ):
+ g = math.pow(
+ 10, -(b - mp.param["pre_filter_start"]) * (3.5 - gp) / 20.0
+ )
+ gp = g
+ spec_c[:, b, :] *= g
+
+ return np.asfortranarray(spec_c)
+
+
+def wave_to_spectrogram(wave, hop_length, n_fft, mp, band, is_v51_model=False):
+
+ if wave.ndim == 1:
+ wave = np.asfortranarray([wave, wave])
+
+ if not is_v51_model:
+ if mp.param["reverse"]:
+ wave_left = np.flip(np.asfortranarray(wave[0]))
+ wave_right = np.flip(np.asfortranarray(wave[1]))
+ elif mp.param["mid_side"]:
+ wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2)
+ wave_right = np.asfortranarray(np.subtract(wave[0], wave[1]))
+ elif mp.param["mid_side_b2"]:
+ wave_left = np.asfortranarray(np.add(wave[1], wave[0] * 0.5))
+ wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * 0.5))
+ else:
+ wave_left = np.asfortranarray(wave[0])
+ wave_right = np.asfortranarray(wave[1])
+ else:
+ wave_left = np.asfortranarray(wave[0])
+ wave_right = np.asfortranarray(wave[1])
+
+ spec_left = librosa.stft(wave_left, n_fft=n_fft, hop_length=hop_length)
+ spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length)
+
+ spec = np.asfortranarray([spec_left, spec_right])
+
+ if is_v51_model:
+ spec = convert_channels(spec, mp, band)
+
+ return spec
+
+
+def spectrogram_to_wave(spec, hop_length=1024, mp={}, band=0, is_v51_model=True):
+ spec_left = np.asfortranarray(spec[0])
+ spec_right = np.asfortranarray(spec[1])
+
+ wave_left = librosa.istft(spec_left, hop_length=hop_length)
+ wave_right = librosa.istft(spec_right, hop_length=hop_length)
+
+ if is_v51_model:
+ cc = mp.param["band"][band].get("convert_channels")
+ if "mid_side_c" == cc:
+ return np.asfortranarray(
+ [
+ np.subtract(wave_left / 1.0625, wave_right / 4.25),
+ np.add(wave_right / 1.0625, wave_left / 4.25),
+ ]
+ )
+ elif "mid_side" == cc:
+ return np.asfortranarray(
+ [
+ np.add(wave_left, wave_right / 2),
+ np.subtract(wave_left, wave_right / 2),
+ ]
+ )
+ elif "stereo_n" == cc:
+ return np.asfortranarray(
+ [
+ np.subtract(wave_left, wave_right * 0.25),
+ np.subtract(wave_right, wave_left * 0.25),
+ ]
+ )
+ else:
+ if mp.param["reverse"]:
+ return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)])
+ elif mp.param["mid_side"]:
+ return np.asfortranarray(
+ [
+ np.add(wave_left, wave_right / 2),
+ np.subtract(wave_left, wave_right / 2),
+ ]
+ )
+ elif mp.param["mid_side_b2"]:
+ return np.asfortranarray(
+ [
+ np.add(wave_right / 1.25, 0.4 * wave_left),
+ np.subtract(wave_left / 1.25, 0.4 * wave_right),
+ ]
+ )
+
+ return np.asfortranarray([wave_left, wave_right])
+
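+# Consistency note (illustrative): for the non-v51 "mid_side" path the two helpers above
+# are inverses of each other (up to STFT round-trip edge effects) -- encoding stores
+# mid = (L + R) / 2 and side = L - R, and decoding recovers L = mid + side / 2 and
+# R = mid - side / 2.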
+
+def cmb_spectrogram_to_wave(
+ spec_m, mp, extra_bins_h=None, extra_bins=None, is_v51_model=False
+):
+ bands_n = len(mp.param["band"])
+ offset = 0
+
+ for d in range(1, bands_n + 1):
+ bp = mp.param["band"][d]
+ spec_s = np.zeros(
+ shape=(2, bp["n_fft"] // 2 + 1, spec_m.shape[2]), dtype=complex
+ )
+ h = bp["crop_stop"] - bp["crop_start"]
+ spec_s[:, bp["crop_start"] : bp["crop_stop"], :] = spec_m[
+ :, offset : offset + h, :
+ ]
+
+ offset += h
+ if d == bands_n: # higher
+ if extra_bins_h: # if --high_end_process bypass
+ max_bin = bp["n_fft"] // 2
+ spec_s[:, max_bin - extra_bins_h : max_bin, :] = extra_bins[
+ :, :extra_bins_h, :
+ ]
+ if bp["hpf_start"] > 0:
+ if is_v51_model:
+ spec_s *= get_hp_filter_mask(
+ spec_s.shape[1], bp["hpf_start"], bp["hpf_stop"] - 1
+ )
+ else:
+ spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1)
+ if bands_n == 1:
+ wave = spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model)
+ else:
+ wave = np.add(
+ wave, spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model)
+ )
+ else:
+ sr = mp.param["band"][d + 1]["sr"]
+ if d == 1: # lower
+ if is_v51_model:
+ spec_s *= get_lp_filter_mask(
+ spec_s.shape[1], bp["lpf_start"], bp["lpf_stop"]
+ )
+ else:
+ spec_s = fft_lp_filter(spec_s, bp["lpf_start"], bp["lpf_stop"])
+
+ try:
+ wave = librosa.resample(
+ spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model),
+ orig_sr=bp["sr"],
+ target_sr=sr,
+ res_type=wav_resolution,
+ )
+ except ValueError as e:
+ print(f"Error during resampling: {e}")
+ print(
+ f"Spec_s shape: {spec_s.shape}, SR: {sr}, Res type: {wav_resolution}"
+ )
+
+ else: # mid
+ if is_v51_model:
+ spec_s *= get_hp_filter_mask(
+ spec_s.shape[1], bp["hpf_start"], bp["hpf_stop"] - 1
+ )
+ spec_s *= get_lp_filter_mask(
+ spec_s.shape[1], bp["lpf_start"], bp["lpf_stop"]
+ )
+ else:
+ spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1)
+ spec_s = fft_lp_filter(spec_s, bp["lpf_start"], bp["lpf_stop"])
+
+ wave2 = np.add(
+ wave, spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model)
+ )
+
+ try:
+ wave = librosa.resample(
+ wave2, orig_sr=bp["sr"], target_sr=sr, res_type=wav_resolution
+ )
+ except ValueError as e:
+ print(f"Error during resampling: {e}")
+ print(
+ f"Spec_s shape: {spec_s.shape}, SR: {sr}, Res type: {wav_resolution}"
+ )
+
+ return wave
+
+
+def get_lp_filter_mask(n_bins, bin_start, bin_stop):
+ mask = np.concatenate(
+ [
+ np.ones((bin_start - 1, 1)),
+ np.linspace(1, 0, bin_stop - bin_start + 1)[:, None],
+ np.zeros((n_bins - bin_stop, 1)),
+ ],
+ axis=0,
+ )
+
+ return mask
+
+
+def get_hp_filter_mask(n_bins, bin_start, bin_stop):
+ mask = np.concatenate(
+ [
+ np.zeros((bin_stop + 1, 1)),
+ np.linspace(0, 1, 1 + bin_start - bin_stop)[:, None],
+ np.ones((n_bins - bin_start - 2, 1)),
+ ],
+ axis=0,
+ )
+
+ return mask
+
+
+def fft_lp_filter(spec, bin_start, bin_stop):
+ g = 1.0
+ for b in range(bin_start, bin_stop):
+ g -= 1 / (bin_stop - bin_start)
+ spec[:, b, :] = g * spec[:, b, :]
+
+ spec[:, bin_stop:, :] *= 0
+
+ return spec
+
+
+def fft_hp_filter(spec, bin_start, bin_stop):
+ g = 1.0
+ for b in range(bin_start, bin_stop, -1):
+ g -= 1 / (bin_start - bin_stop)
+ spec[:, b, :] = g * spec[:, b, :]
+
+ spec[:, 0 : bin_stop + 1, :] *= 0
+
+ return spec
+
+
+def spectrogram_to_wave_old(spec, hop_length=1024):
+ if spec.ndim == 2:
+ wave = librosa.istft(spec, hop_length=hop_length)
+ elif spec.ndim == 3:
+ spec_left = np.asfortranarray(spec[0])
+ spec_right = np.asfortranarray(spec[1])
+
+ wave_left = librosa.istft(spec_left, hop_length=hop_length)
+ wave_right = librosa.istft(spec_right, hop_length=hop_length)
+ wave = np.asfortranarray([wave_left, wave_right])
+
+ return wave
+
+
+def wave_to_spectrogram_old(wave, hop_length, n_fft):
+ wave_left = np.asfortranarray(wave[0])
+ wave_right = np.asfortranarray(wave[1])
+
+ spec_left = librosa.stft(wave_left, n_fft=n_fft, hop_length=hop_length)
+ spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length)
+
+ spec = np.asfortranarray([spec_left, spec_right])
+
+ return spec
+
+
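+# High-end restoration for the --high_end_process modes: magnitudes just below the
+# pre-filter cutoff are flipped ("mirrored") upward, re-phased with the original
+# high end, and the quieter of the two is kept per bin ("mirroring2" instead scales
+# the mirror by the boosted original high end before comparing).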
+def mirroring(a, spec_m, input_high_end, mp):
+ if "mirroring" == a:
+ mirror = np.flip(
+ np.abs(
+ spec_m[
+ :,
+ mp.param["pre_filter_start"]
+ - 10
+ - input_high_end.shape[1] : mp.param["pre_filter_start"]
+ - 10,
+ :,
+ ]
+ ),
+ 1,
+ )
+ mirror = mirror * np.exp(1.0j * np.angle(input_high_end))
+
+ return np.where(
+ np.abs(input_high_end) <= np.abs(mirror), input_high_end, mirror
+ )
+
+ if "mirroring2" == a:
+ mirror = np.flip(
+ np.abs(
+ spec_m[
+ :,
+ mp.param["pre_filter_start"]
+ - 10
+ - input_high_end.shape[1] : mp.param["pre_filter_start"]
+ - 10,
+ :,
+ ]
+ ),
+ 1,
+ )
+ mi = np.multiply(mirror, input_high_end * 1.7)
+
+ return np.where(np.abs(input_high_end) <= np.abs(mi), input_high_end, mi)
+
+
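+# Sharpens the soft mask with a per-channel power law: bins below split_bin get the
+# gentler exponent (1 + aggr/3), bins above get (1 + aggr); for non-accompaniment
+# stems the adjustment is flipped to (1 - aggr).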
+def adjust_aggr(mask, is_non_accom_stem, aggressiveness):
+ aggr = aggressiveness["value"] * 2
+
+ if aggr != 0:
+ if is_non_accom_stem:
+ aggr = 1 - aggr
+
+ if np.any(aggr > 10) or np.any(aggr < -10):
+ print(f"Warning: Extreme aggressiveness values detected: {aggr}")
+
+ aggr = [aggr, aggr]
+
+ if aggressiveness["aggr_correction"] is not None:
+ aggr[0] += aggressiveness["aggr_correction"]["left"]
+ aggr[1] += aggressiveness["aggr_correction"]["right"]
+
+ for ch in range(2):
+ mask[ch, : aggressiveness["split_bin"]] = np.power(
+ mask[ch, : aggressiveness["split_bin"]], 1 + aggr[ch] / 3
+ )
+ mask[ch, aggressiveness["split_bin"] :] = np.power(
+ mask[ch, aggressiveness["split_bin"] :], 1 + aggr[ch]
+ )
+
+ return mask
+
+
+def stft(wave, nfft, hl):
+ wave_left = np.asfortranarray(wave[0])
+ wave_right = np.asfortranarray(wave[1])
+ spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
+ spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
+ spec = np.asfortranarray([spec_left, spec_right])
+
+ return spec
+
+
+def istft(spec, hl):
+ spec_left = np.asfortranarray(spec[0])
+ spec_right = np.asfortranarray(spec[1])
+ wave_left = librosa.istft(spec_left, hop_length=hl)
+ wave_right = librosa.istft(spec_right, hop_length=hl)
+ wave = np.asfortranarray([wave_left, wave_right])
+
+ return wave
+
+
+def spec_effects(wave, algorithm="Default", value=None):
+ if np.isnan(wave).any() or np.isinf(wave).any():
+ print(
+ f"Warning: Detected NaN or infinite values in wave input. Shape: {wave.shape}"
+ )
+
+ spec = [stft(wave[0], 2048, 1024), stft(wave[1], 2048, 1024)]
+ if algorithm == "Min_Mag":
+ v_spec_m = np.where(np.abs(spec[1]) <= np.abs(spec[0]), spec[1], spec[0])
+ wave = istft(v_spec_m, 1024)
+ elif algorithm == "Max_Mag":
+ v_spec_m = np.where(np.abs(spec[1]) >= np.abs(spec[0]), spec[1], spec[0])
+ wave = istft(v_spec_m, 1024)
+ elif algorithm == "Default":
+ wave = (wave[1] * value) + (wave[0] * (1 - value))
+ elif algorithm == "Invert_p":
+ X_mag = np.abs(spec[0])
+ y_mag = np.abs(spec[1])
+ max_mag = np.where(X_mag >= y_mag, X_mag, y_mag)
+ v_spec = spec[1] - max_mag * np.exp(1.0j * np.angle(spec[0]))
+ wave = istft(v_spec, 1024)
+
+ return wave
+
+
+def spectrogram_to_wave_no_mp(spec, n_fft=2048, hop_length=1024):
+ wave = librosa.istft(spec, n_fft=n_fft, hop_length=hop_length)
+
+ if wave.ndim == 1:
+ wave = np.asfortranarray([wave, wave])
+
+ return wave
+
+
+def wave_to_spectrogram_no_mp(wave):
+
+ spec = librosa.stft(wave, n_fft=2048, hop_length=1024)
+
+ if spec.ndim == 1:
+ spec = np.asfortranarray([spec, spec])
+
+ return spec
+
+
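+# Spectral subtraction used by invert_stem below: with invert_p the larger of the
+# two magnitudes, re-phased with the first spectrogram, is subtracted from the
+# second; otherwise the second spectrogram is softened (reduce_vocal_aggressively)
+# and subtracted from the first.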
+def invert_audio(specs, invert_p=True):
+
+ ln = min([specs[0].shape[2], specs[1].shape[2]])
+ specs[0] = specs[0][:, :, :ln]
+ specs[1] = specs[1][:, :, :ln]
+
+ if invert_p:
+ X_mag = np.abs(specs[0])
+ y_mag = np.abs(specs[1])
+ max_mag = np.where(X_mag >= y_mag, X_mag, y_mag)
+ v_spec = specs[1] - max_mag * np.exp(1.0j * np.angle(specs[0]))
+ else:
+ specs[1] = reduce_vocal_aggressively(specs[0], specs[1], 0.2)
+ v_spec = specs[0] - specs[1]
+
+ return v_spec
+
+
+def invert_stem(mixture, stem):
+ mixture = wave_to_spectrogram_no_mp(mixture)
+ stem = wave_to_spectrogram_no_mp(stem)
+ output = spectrogram_to_wave_no_mp(invert_audio([mixture, stem]))
+
+ return -output.T
+
+
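+# Per-bin spectrogram ensembling: MIN_SPEC keeps the smallest magnitude across the
+# inputs, MAX_SPEC the largest; inputs are trimmed to the shortest length first.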
+def ensembling(a, inputs, is_wavs=False):
+
+ for i in range(1, len(inputs)):
+ if i == 1:
+ input = inputs[0]
+
+ if is_wavs:
+ ln = min([input.shape[1], inputs[i].shape[1]])
+ input = input[:, :ln]
+ inputs[i] = inputs[i][:, :ln]
+ else:
+ ln = min([input.shape[2], inputs[i].shape[2]])
+ input = input[:, :, :ln]
+ inputs[i] = inputs[i][:, :, :ln]
+
+ if MIN_SPEC == a:
+ input = np.where(np.abs(inputs[i]) <= np.abs(input), inputs[i], input)
+ if MAX_SPEC == a:
+ # input = np.array(np.where(np.greater_equal(np.abs(inputs[i]), np.abs(input)), inputs[i], input), dtype=object)
+ input = np.where(np.abs(inputs[i]) >= np.abs(input), inputs[i], input)
+ # max_spec = np.array([np.where(np.greater_equal(np.abs(inputs[i]), np.abs(input)), s, specs[0]) for s in specs[1:]], dtype=object)[-1]
+
+ # linear_ensemble
+ # input = ensemble_wav(inputs, split_size=1)
+
+ return input
+
+
+def ensemble_for_align(waves):
+
+ specs = []
+
+ for wav in waves:
+ spec = wave_to_spectrogram_no_mp(wav.T)
+ specs.append(spec)
+
+ wav_aligned = spectrogram_to_wave_no_mp(ensembling(MIN_SPEC, specs)).T
+ wav_aligned = match_array_shapes(wav_aligned, waves[1], is_swap=True)
+
+ return wav_aligned
+
+
+def ensemble_inputs(
+ audio_input,
+ algorithm,
+ is_normalization,
+ wav_type_set,
+ save_path,
+ is_wave=False,
+ is_array=False,
+):
+
+ wavs_ = []
+
+ if algorithm == AVERAGE:
+ output = average_audio(audio_input)
+ samplerate = 44100
+ else:
+ specs = []
+
+ for i in range(len(audio_input)):
+ wave, samplerate = librosa.load(audio_input[i], mono=False, sr=44100)
+ wavs_.append(wave)
+ spec = wave if is_wave else wave_to_spectrogram_no_mp(wave)
+ specs.append(spec)
+
+ wave_shapes = [w.shape[1] for w in wavs_]
+ target_shape = wavs_[wave_shapes.index(max(wave_shapes))]
+
+ if is_wave:
+ output = ensembling(algorithm, specs, is_wavs=True)
+ else:
+ output = spectrogram_to_wave_no_mp(ensembling(algorithm, specs))
+
+ output = to_shape(output, target_shape.shape)
+
+ sf.write(
+ save_path,
+ normalize(output.T, is_normalization),
+ samplerate,
+ subtype=wav_type_set,
+ )
+
+
+def to_shape(x, target_shape):
+ padding_list = []
+ for x_dim, target_dim in zip(x.shape, target_shape):
+ pad_value = target_dim - x_dim
+ pad_tuple = (0, pad_value)
+ padding_list.append(pad_tuple)
+
+ return np.pad(x, tuple(padding_list), mode="constant")
+
+
+def to_shape_minimize(x: np.ndarray, target_shape):
+
+ padding_list = []
+ for x_dim, target_dim in zip(x.shape, target_shape):
+ pad_value = target_dim - x_dim
+ pad_tuple = (0, pad_value)
+ padding_list.append(pad_tuple)
+
+ return np.pad(x, tuple(padding_list), mode="constant")
+
+
+def detect_leading_silence(audio, sr, silence_threshold=0.007, frame_length=1024):
+ """
+ Detect silence at the beginning of an audio signal.
+
+ :param audio: np.array, audio signal
+ :param sr: int, sample rate
+ :param silence_threshold: float, magnitude threshold below which is considered silence
+ :param frame_length: int, the number of samples to consider for each check
+
+ :return: float, duration of the leading silence in milliseconds
+ """
+
+ if len(audio.shape) == 2:
+ # If stereo, pick the channel with more energy to determine the silence
+ channel = np.argmax(np.sum(np.abs(audio), axis=1))
+ audio = audio[channel]
+
+ for i in range(0, len(audio), frame_length):
+ if np.max(np.abs(audio[i : i + frame_length])) > silence_threshold:
+ return (i / sr) * 1000
+
+ return (len(audio) / sr) * 1000
+
+
+def adjust_leading_silence(
+ target_audio, reference_audio, silence_threshold=0.01, frame_length=1024
+):
+ """
+ Adjust the leading silence of the target_audio to match the leading silence of the reference_audio.
+
+ :param target_audio: np.array, audio signal that will have its silence adjusted
+ :param reference_audio: np.array, audio signal used as a reference
+ :param sr: int, sample rate
+ :param silence_threshold: float, magnitude threshold below which is considered silence
+ :param frame_length: int, the number of samples to consider for each check
+
+ :return: np.array, target_audio adjusted to have the same leading silence as reference_audio
+ """
+
+ def find_silence_end(audio):
+ if len(audio.shape) == 2:
+ # If stereo, pick the channel with more energy to determine the silence
+ channel = np.argmax(np.sum(np.abs(audio), axis=1))
+ audio_mono = audio[channel]
+ else:
+ audio_mono = audio
+
+ for i in range(0, len(audio_mono), frame_length):
+ if np.max(np.abs(audio_mono[i : i + frame_length])) > silence_threshold:
+ return i
+ return len(audio_mono)
+
+ ref_silence_end = find_silence_end(reference_audio)
+ target_silence_end = find_silence_end(target_audio)
+ silence_difference = ref_silence_end - target_silence_end
+
+ try:
+ ref_silence_end_p = (ref_silence_end / 44100) * 1000
+ target_silence_end_p = (target_silence_end / 44100) * 1000
+ silence_difference_p = ref_silence_end_p - target_silence_end_p
+ print("silence_difference: ", silence_difference_p)
+    except Exception:
+        pass
+
+ if silence_difference > 0: # Add silence to target_audio
+ if len(target_audio.shape) == 2: # stereo
+ silence_to_add = np.zeros((target_audio.shape[0], silence_difference))
+ else: # mono
+ silence_to_add = np.zeros(silence_difference)
+ return np.hstack((silence_to_add, target_audio))
+ elif silence_difference < 0: # Remove silence from target_audio
+ if len(target_audio.shape) == 2: # stereo
+ return target_audio[:, -silence_difference:]
+ else: # mono
+ return target_audio[-silence_difference:]
+ else: # No adjustment needed
+ return target_audio
+
+
+def match_array_shapes(array_1: np.ndarray, array_2: np.ndarray, is_swap=False):
+
+ if is_swap:
+ array_1, array_2 = array_1.T, array_2.T
+
+ # print("before", array_1.shape, array_2.shape)
+ if array_1.shape[1] > array_2.shape[1]:
+ array_1 = array_1[:, : array_2.shape[1]]
+ elif array_1.shape[1] < array_2.shape[1]:
+ padding = array_2.shape[1] - array_1.shape[1]
+ array_1 = np.pad(array_1, ((0, 0), (0, padding)), "constant", constant_values=0)
+
+ # print("after", array_1.shape, array_2.shape)
+
+ if is_swap:
+ array_1, array_2 = array_1.T, array_2.T
+
+ return array_1
+
+
+def match_mono_array_shapes(array_1: np.ndarray, array_2: np.ndarray):
+
+ if len(array_1) > len(array_2):
+ array_1 = array_1[: len(array_2)]
+ elif len(array_1) < len(array_2):
+ padding = len(array_2) - len(array_1)
+ array_1 = np.pad(array_1, (0, padding), "constant", constant_values=0)
+
+ return array_1
+
+
+def change_pitch_semitones(y, sr, semitone_shift):
+ factor = 2 ** (
+ semitone_shift / 12
+ ) # Convert semitone shift to factor for resampling
+ y_pitch_tuned = []
+ for y_channel in y:
+ y_pitch_tuned.append(
+ librosa.resample(
+ y_channel,
+ orig_sr=sr,
+ target_sr=sr * factor,
+ res_type=wav_resolution_float_resampling,
+ )
+ )
+ y_pitch_tuned = np.array(y_pitch_tuned)
+ new_sr = sr * factor
+ return y_pitch_tuned, new_sr
+
+def average_audio(audio):
+
+ waves = []
+ wave_shapes = []
+ final_waves = []
+
+ for i in range(len(audio)):
+ wave = librosa.load(audio[i], sr=44100, mono=False)
+ waves.append(wave[0])
+ wave_shapes.append(wave[0].shape[1])
+
+ wave_shapes_index = wave_shapes.index(max(wave_shapes))
+ target_shape = waves[wave_shapes_index]
+ waves.pop(wave_shapes_index)
+ final_waves.append(target_shape)
+
+ for n_array in waves:
+ wav_target = to_shape(n_array, target_shape.shape)
+ final_waves.append(wav_target)
+
+ waves = sum(final_waves)
+ waves = waves / len(audio)
+
+ return waves
+
+
+def average_dual_sources(wav_1, wav_2, value):
+
+ if wav_1.shape > wav_2.shape:
+ wav_2 = to_shape(wav_2, wav_1.shape)
+ if wav_1.shape < wav_2.shape:
+ wav_1 = to_shape(wav_1, wav_2.shape)
+
+ wave = (wav_1 * value) + (wav_2 * (1 - value))
+
+ return wave
+
+
+def reshape_sources(wav_1: np.ndarray, wav_2: np.ndarray):
+
+ if wav_1.shape > wav_2.shape:
+ wav_2 = to_shape(wav_2, wav_1.shape)
+ if wav_1.shape < wav_2.shape:
+ ln = min([wav_1.shape[1], wav_2.shape[1]])
+ wav_2 = wav_2[:, :ln]
+
+ ln = min([wav_1.shape[1], wav_2.shape[1]])
+ wav_1 = wav_1[:, :ln]
+ wav_2 = wav_2[:, :ln]
+
+ return wav_2
+
+
+def reshape_sources_ref(wav_1_shape, wav_2: np.ndarray):
+
+ if wav_1_shape > wav_2.shape:
+ wav_2 = to_shape(wav_2, wav_1_shape)
+
+ return wav_2
+
+
+def combine_arrarys(audio_sources, is_swap=False):
+ source = np.zeros_like(max(audio_sources, key=np.size))
+
+ for v in audio_sources:
+ v = match_array_shapes(v, source, is_swap=is_swap)
+ source += v
+
+ return source
+
+
+def combine_audio(
+ paths: list, audio_file_base=None, wav_type_set="FLOAT", save_format=None
+):
+
+ source = combine_arrarys([load_audio(i) for i in paths])
+ save_path = f"{audio_file_base}_combined.wav"
+ sf.write(save_path, source.T, 44100, subtype=wav_type_set)
+ save_format(save_path)
+
+
+def reduce_mix_bv(inst_source, voc_source, reduction_rate=0.9):
+ # Reduce the volume
+ inst_source = inst_source * (1 - reduction_rate)
+
+ mix_reduced = combine_arrarys([inst_source, voc_source], is_swap=True)
+
+ return mix_reduced
+
+
+def organize_inputs(inputs):
+ input_list = {"target": None, "reference": None, "reverb": None, "inst": None}
+
+ for i in inputs:
+ if i.endswith("_(Vocals).wav"):
+ input_list["reference"] = i
+ elif "_RVC_" in i:
+ input_list["target"] = i
+ elif i.endswith("reverbed_stem.wav"):
+ input_list["reverb"] = i
+ elif i.endswith("_(Instrumental).wav"):
+ input_list["inst"] = i
+
+ return input_list
+
+
+def check_if_phase_inverted(wav1, wav2, is_mono=False):
+ # Load the audio files
+ if not is_mono:
+ wav1 = np.mean(wav1, axis=0)
+ wav2 = np.mean(wav2, axis=0)
+
+ # Compute the correlation
+ correlation = np.corrcoef(wav1[:1000], wav2[:1000])
+
+ return correlation[0, 1] < 0
+
+
+def align_audio(
+ file1,
+ file2,
+ file2_aligned,
+ file_subtracted,
+ wav_type_set,
+ is_save_aligned,
+ command_Text,
+ save_format,
+ align_window: list,
+ align_intro_val: list,
+ db_analysis: tuple,
+ set_progress_bar,
+ phase_option,
+ phase_shifts,
+ is_match_silence,
+ is_spec_match,
+):
+
+ global progress_value
+ progress_value = 0
+ is_mono = False
+
+ def get_diff(a, b):
+ corr = np.correlate(a, b, "full")
+ diff = corr.argmax() - (b.shape[0] - 1)
+
+ return diff
+
+ def progress_bar(length):
+ global progress_value
+ progress_value += 1
+
+ if (0.90 / length * progress_value) >= 0.9:
+ length = progress_value + 1
+
+ set_progress_bar(0.1, (0.9 / length * progress_value))
+
+ # read tracks
+
+ if file1.endswith(".mp3") and is_macos:
+ length1 = rerun_mp3(file1)
+ wav1, sr1 = librosa.load(file1, duration=length1, sr=44100, mono=False)
+ else:
+ wav1, sr1 = librosa.load(file1, sr=44100, mono=False)
+
+ if file2.endswith(".mp3") and is_macos:
+ length2 = rerun_mp3(file2)
+ wav2, sr2 = librosa.load(file2, duration=length2, sr=44100, mono=False)
+ else:
+ wav2, sr2 = librosa.load(file2, sr=44100, mono=False)
+
+ if wav1.ndim == 1 and wav2.ndim == 1:
+ is_mono = True
+ elif wav1.ndim == 1:
+ wav1 = np.asfortranarray([wav1, wav1])
+ elif wav2.ndim == 1:
+ wav2 = np.asfortranarray([wav2, wav2])
+
+ # Check if phase is inverted
+ if phase_option == AUTO_PHASE:
+ if check_if_phase_inverted(wav1, wav2, is_mono=is_mono):
+ wav2 = -wav2
+ elif phase_option == POSITIVE_PHASE:
+ wav2 = +wav2
+ elif phase_option == NEGATIVE_PHASE:
+ wav2 = -wav2
+
+ if is_match_silence:
+ wav2 = adjust_leading_silence(wav2, wav1)
+
+ wav1_length = int(librosa.get_duration(y=wav1, sr=44100))
+ wav2_length = int(librosa.get_duration(y=wav2, sr=44100))
+
+ if not is_mono:
+ wav1 = wav1.transpose()
+ wav2 = wav2.transpose()
+
+ wav2_org = wav2.copy()
+
+ command_Text("Processing files... \n")
+ seconds_length = min(wav1_length, wav2_length)
+
+ wav2_aligned_sources = []
+
+ for sec_len in align_intro_val:
+ # pick a position at 1 second in and get diff
+ sec_seg = 1 if sec_len == 1 else int(seconds_length // sec_len)
+ index = sr1 * sec_seg # 1 second in, assuming sr1 = sr2 = 44100
+
+ if is_mono:
+ samp1, samp2 = wav1[index : index + sr1], wav2[index : index + sr1]
+ diff = get_diff(samp1, samp2)
+ # print(f"Estimated difference: {diff}\n")
+ else:
+ index = sr1 * sec_seg # 1 second in, assuming sr1 = sr2 = 44100
+ samp1, samp2 = wav1[index : index + sr1, 0], wav2[index : index + sr1, 0]
+ samp1_r, samp2_r = (
+ wav1[index : index + sr1, 1],
+ wav2[index : index + sr1, 1],
+ )
+ diff, diff_r = get_diff(samp1, samp2), get_diff(samp1_r, samp2_r)
+ # print(f"Estimated difference Left Channel: {diff}\nEstimated difference Right Channel: {diff_r}\n")
+
+ # make aligned track 2
+ if diff > 0:
+ zeros_to_append = np.zeros(diff) if is_mono else np.zeros((diff, 2))
+ wav2_aligned = np.append(zeros_to_append, wav2_org, axis=0)
+ elif diff < 0:
+ wav2_aligned = wav2_org[-diff:]
+ else:
+ wav2_aligned = wav2_org
+ # command_Text(f"Audio files already aligned.\n")
+
+ if not any(
+ np.array_equal(wav2_aligned, source) for source in wav2_aligned_sources
+ ):
+ wav2_aligned_sources.append(wav2_aligned)
+
+ # print("Unique Sources: ", len(wav2_aligned_sources))
+
+ unique_sources = len(wav2_aligned_sources)
+
+ sub_mapper_big_mapper = {}
+
+ for s in wav2_aligned_sources:
+ wav2_aligned = (
+ match_mono_array_shapes(s, wav1)
+ if is_mono
+ else match_array_shapes(s, wav1, is_swap=True)
+ )
+
+ if align_window:
+ wav_sub = time_correction(
+ wav1,
+ wav2_aligned,
+ seconds_length,
+ align_window=align_window,
+ db_analysis=db_analysis,
+ progress_bar=progress_bar,
+ unique_sources=unique_sources,
+ phase_shifts=phase_shifts,
+ )
+ wav_sub_size = np.abs(wav_sub).mean()
+ sub_mapper_big_mapper = {**sub_mapper_big_mapper, **{wav_sub_size: wav_sub}}
+ else:
+ wav2_aligned = wav2_aligned * np.power(10, db_analysis[0] / 20)
+ db_range = db_analysis[1]
+
+ for db_adjustment in db_range:
+ # Adjust the dB of track2
+ s_adjusted = wav2_aligned * (10 ** (db_adjustment / 20))
+ wav_sub = wav1 - s_adjusted
+ wav_sub_size = np.abs(wav_sub).mean()
+ sub_mapper_big_mapper = {
+ **sub_mapper_big_mapper,
+ **{wav_sub_size: wav_sub},
+ }
+
+ # print(sub_mapper_big_mapper.keys(), min(sub_mapper_big_mapper.keys()))
+
+ sub_mapper_value_list = list(sub_mapper_big_mapper.values())
+
+ if is_spec_match and len(sub_mapper_value_list) >= 2:
+ # print("using spec ensemble with align")
+ wav_sub = ensemble_for_align(list(sub_mapper_big_mapper.values()))
+ else:
+ # print("using linear ensemble with align")
+ wav_sub = ensemble_wav(list(sub_mapper_big_mapper.values()))
+
+ # print(f"Mix Mean: {np.abs(wav1).mean()}\nInst Mean: {np.abs(wav2).mean()}")
+ # print('Final: ', np.abs(wav_sub).mean())
+ wav_sub = np.clip(wav_sub, -1, +1)
+
+ command_Text(f"Saving inverted track... ")
+
+ if is_save_aligned or is_spec_match:
+ wav1 = (
+ match_mono_array_shapes(wav1, wav_sub)
+ if is_mono
+ else match_array_shapes(wav1, wav_sub, is_swap=True)
+ )
+ wav2_aligned = wav1 - wav_sub
+
+ if is_spec_match:
+ if wav1.ndim == 1 and wav2.ndim == 1:
+ wav2_aligned = np.asfortranarray([wav2_aligned, wav2_aligned]).T
+ wav1 = np.asfortranarray([wav1, wav1]).T
+
+ wav2_aligned = ensemble_for_align([wav2_aligned, wav1])
+ wav_sub = wav1 - wav2_aligned
+
+ if is_save_aligned:
+ sf.write(file2_aligned, wav2_aligned, sr1, subtype=wav_type_set)
+ save_format(file2_aligned)
+
+ sf.write(file_subtracted, wav_sub, sr1, subtype=wav_type_set)
+ save_format(file_subtracted)
+
+
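+# Phase-shift helpers for alignment: hilbert() builds the analytic signal so a
+# waveform can be rotated by an arbitrary phase angle, and get_phase_shifted_tracks
+# enumerates the candidate rotations (plus the plain and inverted copies) that
+# time_correction tries against the mix.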
+def phase_shift_hilbert(signal, degree):
+ analytic_signal = hilbert(signal)
+ return (
+ np.cos(np.radians(degree)) * analytic_signal.real
+ - np.sin(np.radians(degree)) * analytic_signal.imag
+ )
+
+
+def get_phase_shifted_tracks(track, phase_shift):
+ if phase_shift == 180:
+ return [track, -track]
+
+ step = phase_shift
+ end = 180 - (180 % step) if 180 % step == 0 else 181
+ phase_range = range(step, end, step)
+
+ flipped_list = [track, -track]
+ for i in phase_range:
+ flipped_list.extend(
+ [phase_shift_hilbert(track, i), phase_shift_hilbert(track, -i)]
+ )
+
+ return flipped_list
+
+
+def time_correction(
+ mix: np.ndarray,
+ instrumental: np.ndarray,
+ seconds_length,
+ align_window,
+ db_analysis,
+ sr=44100,
+ progress_bar=None,
+ unique_sources=None,
+ phase_shifts=NONE_P,
+):
+ # Function to align two tracks using cross-correlation
+
+ def align_tracks(track1, track2):
+ # A dictionary to store each version of track2_shifted and its mean absolute value
+ shifted_tracks = {}
+
+ # Loop to adjust dB of track2
+ track2 = track2 * np.power(10, db_analysis[0] / 20)
+ db_range = db_analysis[1]
+
+ if phase_shifts == 190:
+ track2_flipped = [track2]
+ else:
+ track2_flipped = get_phase_shifted_tracks(track2, phase_shifts)
+
+ for db_adjustment in db_range:
+ for t in track2_flipped:
+ # Adjust the dB of track2
+ track2_adjusted = t * (10 ** (db_adjustment / 20))
+ corr = correlate(track1, track2_adjusted)
+ delay = np.argmax(np.abs(corr)) - (len(track1) - 1)
+ track2_shifted = np.roll(track2_adjusted, shift=delay)
+
+ # Compute the mean absolute value of track2_shifted
+ track2_shifted_sub = track1 - track2_shifted
+ mean_abs_value = np.abs(track2_shifted_sub).mean()
+
+ # Store track2_shifted and its mean absolute value in the dictionary
+ shifted_tracks[mean_abs_value] = track2_shifted
+
+ # Return the version of track2_shifted with the smallest mean absolute value
+
+ return shifted_tracks[min(shifted_tracks.keys())]
+
+ # Make sure the audio files have the same shape
+
+ assert (
+ mix.shape == instrumental.shape
+ ), f"Audio files must have the same shape - Mix: {mix.shape}, Inst: {instrumental.shape}"
+
+ seconds_length = seconds_length // 2
+
+ sub_mapper = {}
+
+ progress_update_interval = 120
+ total_iterations = 0
+
+ if len(align_window) > 2:
+ progress_update_interval = 320
+
+ for secs in align_window:
+ step = secs / 2
+ window_size = int(sr * secs)
+ step_size = int(sr * step)
+
+ if len(mix.shape) == 1:
+ total_mono = (
+ len(range(0, len(mix) - window_size, step_size))
+ // progress_update_interval
+ ) * unique_sources
+ total_iterations += total_mono
+ else:
+ total_stereo_ = len(range(0, len(mix[:, 0]) - window_size, step_size)) * 2
+ total_stereo = (total_stereo_ // progress_update_interval) * unique_sources
+ total_iterations += total_stereo
+
+ # print(total_iterations)
+
+ for secs in align_window:
+ sub = np.zeros_like(mix)
+ divider = np.zeros_like(mix)
+ step = secs / 2
+ window_size = int(sr * secs)
+ step_size = int(sr * step)
+ window = np.hanning(window_size)
+
+ # For the mono case:
+ if len(mix.shape) == 1:
+ # The files are mono
+ counter = 0
+ for i in range(0, len(mix) - window_size, step_size):
+ counter += 1
+ if counter % progress_update_interval == 0:
+ progress_bar(total_iterations)
+ window_mix = mix[i : i + window_size] * window
+ window_instrumental = instrumental[i : i + window_size] * window
+ window_instrumental_aligned = align_tracks(
+ window_mix, window_instrumental
+ )
+ sub[i : i + window_size] += window_mix - window_instrumental_aligned
+ divider[i : i + window_size] += window
+ else:
+ # The files are stereo
+ counter = 0
+ for ch in range(mix.shape[1]):
+ for i in range(0, len(mix[:, ch]) - window_size, step_size):
+ counter += 1
+ if counter % progress_update_interval == 0:
+ progress_bar(total_iterations)
+ window_mix = mix[i : i + window_size, ch] * window
+ window_instrumental = instrumental[i : i + window_size, ch] * window
+ window_instrumental_aligned = align_tracks(
+ window_mix, window_instrumental
+ )
+ sub[i : i + window_size, ch] += (
+ window_mix - window_instrumental_aligned
+ )
+ divider[i : i + window_size, ch] += window
+
+ # Normalize the result by the overlap count
+ sub = np.where(divider > 1e-6, sub / divider, sub)
+ sub_size = np.abs(sub).mean()
+ sub_mapper = {**sub_mapper, **{sub_size: sub}}
+
+ # print("SUB_LEN", len(list(sub_mapper.values())))
+
+ sub = ensemble_wav(list(sub_mapper.values()), split_size=12)
+
+ return sub
+
+
+def ensemble_wav(waveforms, split_size=240):
+ # Create a dictionary to hold the thirds of each waveform and their mean absolute values
+ waveform_thirds = {
+ i: np.array_split(waveform, split_size) for i, waveform in enumerate(waveforms)
+ }
+
+ # Initialize the final waveform
+ final_waveform = []
+
+ # For chunk
+ for third_idx in range(split_size):
+ # Compute the mean absolute value of each third from each waveform
+ means = [
+ np.abs(waveform_thirds[i][third_idx]).mean() for i in range(len(waveforms))
+ ]
+
+ # Find the index of the waveform with the lowest mean absolute value for this third
+ min_index = np.argmin(means)
+
+ # Add the least noisy third to the final waveform
+ final_waveform.append(waveform_thirds[min_index][third_idx])
+
+ # Concatenate all the thirds to create the final waveform
+ final_waveform = np.concatenate(final_waveform)
+
+ return final_waveform
+
+
+def ensemble_wav_min(waveforms):
+ for i in range(1, len(waveforms)):
+ if i == 1:
+ wave = waveforms[0]
+
+ ln = min(len(wave), len(waveforms[i]))
+ wave = wave[:ln]
+ waveforms[i] = waveforms[i][:ln]
+
+ wave = np.where(np.abs(waveforms[i]) <= np.abs(wave), waveforms[i], wave)
+
+ return wave
+
+
+def align_audio_test(wav1, wav2, sr1=44100):
+ def get_diff(a, b):
+ corr = np.correlate(a, b, "full")
+ diff = corr.argmax() - (b.shape[0] - 1)
+ return diff
+
+ # read tracks
+ wav1 = wav1.transpose()
+ wav2 = wav2.transpose()
+
+ # print(f"Audio file shapes: {wav1.shape} / {wav2.shape}\n")
+
+ wav2_org = wav2.copy()
+
+ # pick a position at 1 second in and get diff
+ index = sr1 # *seconds_length # 1 second in, assuming sr1 = sr2 = 44100
+ samp1 = wav1[index : index + sr1, 0] # currently use left channel
+ samp2 = wav2[index : index + sr1, 0]
+ diff = get_diff(samp1, samp2)
+
+ # make aligned track 2
+ if diff > 0:
+ wav2_aligned = np.append(np.zeros((diff, 1)), wav2_org, axis=0)
+ elif diff < 0:
+ wav2_aligned = wav2_org[-diff:]
+ else:
+ wav2_aligned = wav2_org
+
+ return wav2_aligned
+
+
+def load_audio(audio_file):
+ wav, sr = librosa.load(audio_file, sr=44100, mono=False)
+
+ if wav.ndim == 1:
+ wav = np.asfortranarray([wav, wav])
+
+ return wav
+
+
+def rerun_mp3(audio_file):
+ with audioread.audio_open(audio_file) as f:
+ track_length = int(f.duration)
+
+ return track_length
diff --git a/mvsepless/models/windowed_roformer/flex_attention_utils.py b/mvsepless/models/windowed_roformer/flex_attention_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ef9b0e06ebb2087bf3cf0a4ce965a820eaa74e9
--- /dev/null
+++ b/mvsepless/models/windowed_roformer/flex_attention_utils.py
@@ -0,0 +1,110 @@
+from functools import lru_cache
+from typing import Optional
+import torch
+import torch.nn as nn
+from torch.nn.attention.flex_attention import (
+ flex_attention,
+ create_block_mask,
+ _mask_mod_signature,
+)
+
+@lru_cache(maxsize=128)
+def create_block_mask_cached(
+ mask_mod: _mask_mod_signature,
+ B: Optional[int],
+ H: Optional[int],
+ Q_LEN: int,
+ KV_LEN: int,
+ device: torch.device
+):
+ block_mask = create_block_mask(
+ mask_mod,
+ B=B,
+ H=H,
+ Q_LEN=Q_LEN,
+ KV_LEN=KV_LEN,
+ device=device
+ )
+ return block_mask
+
+
+def get_compiled_flex_attention(compile: bool = True, mode: str = "default"):
+ if compile:
+ return torch.compile(flex_attention, dynamic=False, mode=mode)
+ return flex_attention
+
+def generate_sliding_window_with_sinks(
+ window_size: int,
+ num_sink_tokens: int
+) -> _mask_mod_signature:
+
+ half_window = window_size // 2
+
+ def sliding_window_with_global_sinks(b, h, q_idx, kv_idx):
+ is_query_sink = q_idx < num_sink_tokens
+ is_kv_sink = kv_idx < num_sink_tokens
+ is_in_window = torch.abs(q_idx - kv_idx) <= half_window
+
+ # Query sinks can attend to everything
+ # Regular queries can attend to: sinks OR tokens within sliding window
+ return is_query_sink | is_kv_sink | is_in_window
+
+ sliding_window_with_global_sinks.__name__ = f"sliding_window_w{window_size}_sinks{num_sink_tokens}"
+ return sliding_window_with_global_sinks
+
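+# Minimal usage sketch (illustrative values; assumes PyTorch >= 2.5 with flex
+# attention available on the target device):
+#   mask_mod = generate_sliding_window_with_sinks(window_size=64, num_sink_tokens=8)
+#   attn = FlexAttention(mask_mod=mask_mod, compile=False)
+#   q = k = v = torch.randn(1, 8, 256, 64)  # (batch, heads, seq_len, head_dim)
+#   out = attn(q, k, v)                     # same shape as q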
+class FlexAttention(nn.Module):
+ def __init__(
+ self,
+ mask_mod: _mask_mod_signature,
+ dropout: float = 0.0,
+ scale: Optional[float] = None,
+ compile: bool = False,
+ compile_mode: str = "max-autotune"
+ ):
+ super().__init__()
+ self.mask_mod = mask_mod
+ self.scale = scale
+ self.dropout = dropout
+ self.attn_dropout = nn.Dropout(dropout) if dropout > 0 else None
+
+ # Get compiled or uncompiled flex_attention
+ self.flex_attn_fn = get_compiled_flex_attention(compile, mode=compile_mode)
+
+ def forward(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ ) -> torch.Tensor:
+
+ batch_size, num_heads, seq_len, head_dim = q.shape
+ device = q.device
+
+ # Create block mask (cached based on mask_mod and dimensions)
+ # Note: We use None for B and H to allow broadcasting across batches/heads
+ block_mask = create_block_mask_cached(
+ self.mask_mod,
+ B=None,
+ H=None,
+ Q_LEN=seq_len,
+ KV_LEN=seq_len,
+ device=device.type
+ )
+ # Apply scale if specified
+ if self.scale is not None:
+ # Adjust query by scale factor
+ default_scale = head_dim ** -0.5
+ q = q * (self.scale / default_scale)
+
+ # Apply flex attention with block mask
+ out = self.flex_attn_fn(
+ q, k, v,
+ block_mask=block_mask,
+ scale=self.scale
+ )
+
+ # Apply dropout if training
+ if self.training and self.attn_dropout is not None:
+ out = self.attn_dropout(out)
+
+ return out
\ No newline at end of file
diff --git a/mvsepless/models/windowed_roformer/model.py b/mvsepless/models/windowed_roformer/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fde4704b643c772dd46e0ee006128d723abbe7a9
--- /dev/null
+++ b/mvsepless/models/windowed_roformer/model.py
@@ -0,0 +1,340 @@
+import torch
+import torch.nn as nn
+
+from rotary_embedding_torch import RotaryEmbedding
+
+from einops import rearrange, pack, unpack, reduce, repeat
+from librosa import filters
+from .modules import BandSplit, MaskEstimator, Transformer, pack_one, unpack_one
+from functools import partial
+
+class MelBandRoformerWSA(nn.Module):
+ def __init__(
+ self,
+ dim: int = 384,
+ depth: int = 6,
+ # num_stems: int = 1,
+ num_bands: int = 60,
+ dim_head: int = 64,
+ heads: int = 8,
+ attn_dropout: float = 0.0,
+ ff_dropout: float = 0.0,
+ sample_rate: int = 44100,
+ stft_n_fft: int = 2048,
+ stft_hop_length: int = 441,
+ stft_win_length: int = 2048,
+ stft_normalized: bool = False,
+ wsa_window_len: int = 10,
+ n_wsa_sinks: int = 8,
+ **kwargs
+ ):
+ super().__init__()
+
+ # Store configuration parameters
+ self.audio_channels = 2
+ # self.num_stems = num_stems
+ self.n_wsa_sinks = n_wsa_sinks
+ self.dim = dim
+ self.depth = depth
+ self.num_bands = num_bands
+ self.sample_rate = sample_rate
+ self.stft_n_fft = stft_n_fft
+ self.stft_hop_length = stft_hop_length
+ self.stft_win_length = stft_win_length
+ self.stft_normalized = stft_normalized
+ self.wsa_window_len = wsa_window_len
+
+ # Store transformer configuration
+ self.time_transformer_kwargs = dict(
+ dim=dim,
+ heads=heads,
+ dim_head=dim_head,
+ attn_dropout=attn_dropout,
+ ff_dropout=ff_dropout,
+ wsa_window_len=wsa_window_len,
+ n_wsa_sinks=n_wsa_sinks
+ )
+ self.freq_transformer_kwargs = dict(
+ dim=dim,
+ heads=heads,
+ dim_head=dim_head,
+ attn_dropout=attn_dropout,
+ ff_dropout=ff_dropout,
+ )
+
+ # Build components
+ self._build_stft_config()
+ self._build_mel_filter_bank()
+ self._build_transformer_layers()
+ self._build_band_split_and_mask_estimators()
+ self._build_sink_tokens()
+
+ # Initialize weights
+ self.apply(self._init_weights)
+
+ def _init_weights(self, module):
+ """Initialize weights for linear and conv layers"""
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
+ nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+ def _build_stft_config(self):
+ """Build STFT configuration and window function"""
+ self.stft_window_fn = partial(torch.hann_window, self.stft_win_length)
+ self.stft_kwargs = dict(
+ n_fft=self.stft_n_fft,
+ hop_length=self.stft_hop_length,
+ win_length=self.stft_win_length,
+ normalized=self.stft_normalized
+ )
+
+ def _build_mel_filter_bank(self):
+ """Build mel filter bank and frequency indices"""
+ # Get number of frequencies from STFT
+ freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(self.stft_n_fft), return_complex=True).shape[1]
+
+ # Create mel filter bank with librosa.filters.mel as in section 2 of paper
+ mel_filter_bank_numpy = filters.mel(sr=self.sample_rate, n_fft=self.stft_n_fft, n_mels=self.num_bands)
+ mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
+
+ # Fix edge cases
+ mel_filter_bank[0][0] = 1. # First frequency
+ mel_filter_bank[-1, -1] = 1. # Last frequency
+
+ # Binary mask as in paper (estimated masks are averaged for overlapping regions)
+ freqs_per_band = mel_filter_bank > 0
+ assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'
+
+ # Create frequency indices for band splitting
+ repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=self.num_bands)
+ freq_indices = repeated_freq_indices[freqs_per_band]
+
+ # if self.stereo:
+ freq_indices = repeat(freq_indices, 'f -> f s', s=2)
+ freq_indices = freq_indices * 2 + torch.arange(2)
+ freq_indices = rearrange(freq_indices, 'f s -> (f s)')
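+        # Stereo is later folded into the frequency axis ('b s f t c -> b (f s) t c'),
+        # so each mel-covered bin index is expanded here to its interleaved L/R slots.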
+
+ # Register buffers
+ self.register_buffer('freq_indices', freq_indices, persistent=False)
+ self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)
+
+ # Calculate frequency statistics
+ num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
+ num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')
+
+ self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
+ self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)
+
+ def _build_transformer_layers(self):
+ """Build transformer layers with rotary embeddings"""
+ self.layers = nn.ModuleList([])
+
+ time_rotary_embed = RotaryEmbedding(dim=self.time_transformer_kwargs['dim_head'])
+ freq_rotary_embed = RotaryEmbedding(dim=self.freq_transformer_kwargs['dim_head'])
+
+ for _ in range(self.depth):
+ tran_modules = []
+ tran_modules.append(
+ Transformer(rotary_embed=time_rotary_embed, **self.time_transformer_kwargs)
+ )
+ tran_modules.append(
+ Transformer(rotary_embed=freq_rotary_embed, **self.freq_transformer_kwargs)
+ )
+ self.layers.append(nn.ModuleList(tran_modules))
+
+ def _build_band_split_and_mask_estimators(self):
+ """Build band split module and mask estimators"""
+ # Calculate input dimensions for each band
+ freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in self.num_freqs_per_band.tolist())
+
+ # Band split module
+ self.band_split = BandSplit(
+ dim=self.dim,
+ dim_inputs=freqs_per_bands_with_complex
+ )
+
+ # Mask estimators
+ self.mask_estimators = nn.ModuleList([])
+ for _ in range(1):
+ mask_estimator = MaskEstimator(
+ dim=self.dim,
+ dim_inputs=freqs_per_bands_with_complex,
+ depth=2,
+ mlp_expansion_factor=4,
+ )
+ self.mask_estimators.append(mask_estimator)
+
+ def _build_sink_tokens(self):
+ """Build learnable sink tokens for efficient attention"""
+ if self.n_wsa_sinks > 0:
+ self.sink_tokens = nn.Parameter(
+ torch.randn(self.n_wsa_sinks, self.num_bands, self.dim)
+ )
+ print(f"Using {self.n_wsa_sinks} sink tokens for attention")
+ else:
+ self.sink_tokens = None
+
+
+ def forward(self, raw_audio):
+ """
+ Main forward pass for audio source separation
+
+ einops notation:
+ b - batch
+ f - freq
+ t - time
+ s - audio channel (2 for stereo)
+ n - number of 'stems'
+ c - complex (2)
+ d - feature dimension
+ """
+ # Preprocess input audio
+ raw_audio, batch_info = self._preprocess_audio(raw_audio)
+
+ # Convert to STFT representation
+ stft_repr = self._audio_to_stft(raw_audio, batch_info)
+
+ # Extract features and apply band splitting
+ features = self._extract_features(stft_repr, batch_info)
+
+ # Apply transformer layers with sink tokens
+ processed_features = self._apply_transformer_layers(features)
+
+ # Generate masks for each stem
+ masks = self._generate_masks(processed_features)
+
+ # Apply masks and reconstruct audio
+ recon_audio = self._reconstruct_audio(stft_repr, masks, batch_info)
+
+ return recon_audio
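+
+    # Usage sketch (illustrative; requires a PyTorch build that supports the flex
+    # attention path in modules.Attend): a stereo batch of shape (B, 2, T) in
+    # yields a (B, 2, T) estimate of the single trained stem.
+    #   model = MelBandRoformerWSA()
+    #   mix = torch.randn(1, 2, 44100 * 4)
+    #   with torch.no_grad():
+    #       stem = model(mix)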
+
+ def _preprocess_audio(self, raw_audio):
+ """Preprocess input audio and validate dimensions"""
+ device = raw_audio.device
+
+ if raw_audio.ndim == 2:
+ raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
+
+ batch, channels, raw_audio_length = raw_audio.shape
+
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
+
+ batch_info = {
+ 'batch_size': batch,
+ 'channels': channels,
+ 'device': device,
+ 'packed_shape': batch_audio_channel_packed_shape
+ }
+
+ return raw_audio, batch_info
+
+ def _audio_to_stft(self, raw_audio, batch_info):
+ """Convert audio to STFT representation"""
+ stft_window = self.stft_window_fn(device=batch_info['device'])
+
+ stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
+ stft_repr = torch.view_as_real(stft_repr)
+
+ stft_repr = unpack_one(stft_repr, batch_info['packed_shape'], '* f t c')
+
+ # Merge stereo/mono into frequency dimension for band splitting
+ stft_repr = rearrange(stft_repr, 'b s f t c -> b (f s) t c')
+
+ return stft_repr
+
+ def _extract_features(self, stft_repr, batch_info):
+ """Extract features using band splitting"""
+ batch_arange = torch.arange(batch_info['batch_size'], device=batch_info['device'])[..., None]
+
+ # Index frequencies for band splitting
+ x = stft_repr[batch_arange, self.freq_indices]
+
+ # Fold complex dimensions into frequency dimension
+ x = rearrange(x, 'b f t c -> b t (f c)')
+
+ # Apply band split
+ x = self.band_split(x)
+
+ return x
+
+ def _apply_transformer_layers(self, features):
+ """Apply transformer layers with sink tokens"""
+ x = features
+
+ # Add sink tokens at the beginning of the time dimension
+ if self.sink_tokens is not None:
+ batch_size = x.shape[0]
+ sinks = repeat(self.sink_tokens, 'n f d -> b n f d', b=batch_size)
+ x = torch.cat([sinks, x], dim=1)
+
+ # Apply axial/hierarchical attention
+ for transformer_block in self.layers:
+ time_transformer, freq_transformer = transformer_block
+
+ # Time transformer
+ x = rearrange(x, 'b t f d -> b f t d')
+ x, ps = pack([x], '* t d')
+ x = time_transformer(x)
+ x, = unpack(x, ps, '* t d')
+ x = rearrange(x, 'b f t d -> b t f d')
+
+ # Frequency transformer
+ x, ps = pack([x], '* f d')
+ x = freq_transformer(x)
+ x, = unpack(x, ps, '* f d')
+
+ # Remove sink tokens before mask estimation
+ if self.sink_tokens is not None:
+ x = x[:, self.n_wsa_sinks:, :, :]
+
+ return x
+
+ def _generate_masks(self, processed_features):
+ """Generate masks for each stem"""
+ masks = torch.stack([fn(processed_features) for fn in self.mask_estimators], dim=1)
+ masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
+ return masks
+
+ def _reconstruct_audio(self, stft_repr, masks, batch_info):
+ """Apply masks and reconstruct audio"""
+ batch = batch_info['batch_size']
+ channels = batch_info['channels']
+ device = batch_info['device']
+ num_stems = len(self.mask_estimators)
+
+ # Prepare STFT for modulation
+ stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
+
+ # Convert to complex representation
+ stft_repr = torch.view_as_complex(stft_repr)
+ masks = torch.view_as_complex(masks)
+ masks = masks.type(stft_repr.dtype)
+
+ # Average masks for overlapping frequencies
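+        # (num_bands_per_freq counts how many mel bands cover each bin; scatter_add
+        # sums the per-band masks and the division below turns that sum into a mean.)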
+ scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
+ stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
+ masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
+
+ denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
+ masks_averaged = masks_summed / denom.clamp(min=1e-8)
+
+ # Apply masks to STFT
+ stft_repr = stft_repr * masks_averaged
+
+ # Convert back to real representation for ISTFT
+ stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
+
+ # Zero out DC component
+ stft_repr = stft_repr.index_fill(1, torch.tensor(0, device=device), 0.)
+
+ # ISTFT reconstruction
+ stft_window = self.stft_window_fn(device=device)
+ recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=None)
+
+ # Reshape output
+ recon_audio = rearrange(recon_audio, '(b s) t -> b s t', b=batch, s=self.audio_channels)
+
+ return recon_audio
\ No newline at end of file
diff --git a/mvsepless/models/windowed_roformer/modules.py b/mvsepless/models/windowed_roformer/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..c872af4b98e2bb410ff0f1c385ea709996f1b89e
--- /dev/null
+++ b/mvsepless/models/windowed_roformer/modules.py
@@ -0,0 +1,365 @@
+from functools import partial
+
+import torch
+from torch import nn
+from torch.nn import Module, ModuleList
+import torch.nn.functional as F
+
+from einops import rearrange, pack, unpack
+from typing import Tuple
+from functools import wraps
+from packaging import version
+from collections import namedtuple
+import os
+from .flex_attention_utils import (
+ FlexAttention,
+ generate_sliding_window_with_sinks,
+)
+
+def pack_one(t, pattern):
+ return pack([t], pattern)
+
+
+def unpack_one(t, ps, pattern):
+ return unpack(t, ps, pattern)[0]
+
+class RMSNorm(Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.scale = dim ** 0.5
+ self.gamma = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
+
+class FeedForward(Module):
+ def __init__(
+ self,
+ dim,
+ mult=4,
+ dropout=0.
+ ):
+ super().__init__()
+ dim_inner = int(dim * mult)
+ self.net = nn.Sequential(
+ RMSNorm(dim),
+ nn.Linear(dim, dim_inner),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(dim_inner, dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+class Attention(Module):
+ def __init__(
+ self,
+ dim,
+ heads=8,
+ dim_head=64,
+ dropout=0.,
+ rotary_embed=None,
+ flash=True,
+ wsa_window_len=None,
+ n_wsa_sinks=None
+ ):
+ super().__init__()
+ self.heads = heads
+ self.scale = dim_head ** -0.5
+ dim_inner = heads * dim_head
+
+ self.rotary_embed = rotary_embed
+
+
+ self.attend = Attend(flash=flash, dropout=dropout, wsa_window_len=wsa_window_len, n_wsa_sinks=n_wsa_sinks)
+ self.norm = RMSNorm(dim)
+ self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
+
+ self.to_gates = nn.Linear(dim, heads)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(dim_inner, dim, bias=False),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, return_attn=False):
+ x = self.norm(x)
+
+ q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
+
+ if self.rotary_embed is not None:
+ q = self.rotary_embed.rotate_queries_or_keys(q)
+ k = self.rotary_embed.rotate_queries_or_keys(k)
+
+ if return_attn:
+ out, attn = self.attend(q, k, v, return_attn=True)
+ else:
+ out = self.attend(q, k, v)
+
+ gates = self.to_gates(x)
+ out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
+
+ out = rearrange(out, 'b h n d -> b n (h d)')
+ result = self.to_out(out)
+
+ if return_attn:
+ return result, attn
+ return result
+
+
+class Transformer(Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ depth:int = 1,
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.,
+ ff_dropout=0.,
+ ff_mult=4,
+ norm_output=True,
+ rotary_embed=None,
+ use_flash=True,
+ wsa_window_len=None,
+ n_wsa_sinks=None
+ ):
+ super().__init__()
+ self.layers = ModuleList([])
+
+ for _ in range(depth):
+
+ attn = Attention(
+ dim=dim,
+ dim_head=dim_head,
+ heads=heads,
+ dropout=attn_dropout,
+ rotary_embed=rotary_embed,
+ flash=use_flash,
+ wsa_window_len=wsa_window_len,
+ n_wsa_sinks=n_wsa_sinks
+ )
+
+ self.layers.append(ModuleList([
+ attn,
+ FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
+ ]))
+
+ self.norm = RMSNorm(dim) if norm_output else nn.Identity()
+
+ def forward(self, x):
+
+ for attn, ff in self.layers:
+ x = attn(x) + x
+ x = ff(x) + x
+ return self.norm(x)
+
+class BandSplit(Module):
+ def __init__(
+ self,
+ dim: int,
+ dim_inputs: Tuple[int, ...]
+ ):
+ super().__init__()
+ self.dim_inputs = dim_inputs
+ self.to_features = ModuleList([])
+
+ for dim_in in dim_inputs:
+ net = nn.Sequential(
+ RMSNorm(dim_in),
+ nn.Linear(dim_in, dim)
+ )
+
+ self.to_features.append(net)
+
+ def forward(self, x):
+ x = x.split(self.dim_inputs, dim=-1)
+
+ outs = []
+ for split_input, to_feature in zip(x, self.to_features):
+ split_output = to_feature(split_input)
+ outs.append(split_output)
+
+ return torch.stack(outs, dim=-2)
+
+
+def MLP(
+ dim_in,
+ dim_out,
+ dim_hidden=None,
+ depth=1,
+ activation=nn.Tanh
+):
+ dim_hidden = dim_hidden if dim_hidden is not None else dim_in
+
+ net = []
+ dims = (dim_in, *((dim_hidden,) * depth), dim_out)
+
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
+ is_last = ind == (len(dims) - 2)
+
+ net.append(nn.Linear(layer_dim_in, layer_dim_out))
+
+ if is_last:
+ continue
+
+ net.append(activation())
+
+ return nn.Sequential(*net)
+
+
+class MaskEstimator(Module):
+ def __init__(
+ self,
+ dim,
+ dim_inputs: Tuple[int, ...],
+ depth,
+ mlp_expansion_factor=4
+ ):
+ super().__init__()
+ self.dim_inputs = dim_inputs
+ self.to_freqs = ModuleList([])
+ dim_hidden = dim * mlp_expansion_factor
+
+ for dim_in in dim_inputs:
+ mlp = nn.Sequential(
+ MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
+ nn.GLU(dim=-1)
+ )
+
+ self.to_freqs.append(mlp)
+
+ def forward(self, x):
+ x = x.unbind(dim=-2)
+
+ outs = []
+
+ for band_features, mlp in zip(x, self.to_freqs):
+ freq_out = mlp(band_features)
+ outs.append(freq_out)
+
+ return torch.cat(outs, dim=-1)
+
+
+
+# constants
+FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
+
+def once(fn):
+ called = False
+ @wraps(fn)
+ def inner(x):
+ nonlocal called
+ if called:
+ return
+ called = True
+ return fn(x)
+ return inner
+
+print_once = once(print)
+
+class Attend(nn.Module):
+ def __init__(
+ self,
+ dropout = 0.,
+ flash = False,
+ scale = None,
+ wsa_window_len = None,
+ n_wsa_sinks = None
+ ):
+ super().__init__()
+ self.scale = scale
+ self.dropout = dropout
+ self.attn_dropout = nn.Dropout(dropout)
+ self.wsa_window_len = wsa_window_len
+ self.n_wsa_sinks = n_wsa_sinks
+ self.use_flash = flash
+
+ # Initialize FlexAttention module if enabled
+        if wsa_window_len is not None and n_wsa_sinks is not None and n_wsa_sinks > 0:
+ assert not (version.parse(torch.__version__) < version.parse('2.5.0')), \
+ 'in order to use flex attention, you must be using pytorch 2.5 or above'
+ # Create the appropriate mask function
+ mask_mod = generate_sliding_window_with_sinks(wsa_window_len, n_wsa_sinks)
+
+ # Create FlexAttention module with compilation enabled
+ self.flex_attn = FlexAttention(
+ mask_mod=mask_mod,
+ dropout=dropout,
+ scale=scale,
+ compile=True,
+ )
+ else:
+ self.flex_attn = None
+
+ # Setup for standard flash attention
+ assert not (self.use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), \
+ 'in order to use flash attention, you must be using pytorch 2.0 or above'
+
+ # determine efficient attention configs for cuda and cpu
+ self.cpu_config = FlashAttentionConfig(True, True, True)
+ self.cuda_config = None
+
+ if not torch.cuda.is_available() or not self.use_flash:
+ return
+
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
+ device_version = version.parse(f'{device_properties.major}.{device_properties.minor}')
+
+ if device_version >= version.parse('8.0'):
+ if os.name == 'nt':
+ print_once('Windows OS detected, using math or mem efficient attention if input tensor is on cuda')
+ self.cuda_config = FlashAttentionConfig(False, True, True)
+ else:
+ print_once('GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda')
+ self.cuda_config = FlashAttentionConfig(True, False, False)
+ else:
+ print_once('GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda')
+ self.cuda_config = FlashAttentionConfig(False, True, True)
+
+ def flash_attn(self, q, k, v):
+ _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
+
+ if self.scale is not None:
+ default_scale = q.shape[-1] ** -0.5
+ q = q * (self.scale / default_scale)
+
+ # Check if there is a compatible device for flash attention
+ config = self.cuda_config if is_cuda else self.cpu_config
+
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
+ with torch.backends.cuda.sdp_kernel(**config._asdict()):
+ out = F.scaled_dot_product_attention(
+ q, k, v,
+ dropout_p = self.dropout if self.training else 0.
+ )
+
+ return out
+
+ def forward(self, q, k, v, return_attn=False):
+
+ # Use FlexAttention module if enabled
+ if self.flex_attn is not None:
+ return self.flex_attn(q, k, v)
+
+ # Standard flash attention path
+ if self.use_flash:
+ return self.flash_attn(q, k, v)
+
+ # Fallback: Manual attention computation
+ q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
+ scale = self.scale if self.scale is not None else q.shape[-1] ** -0.5
+
+ # similarity
+        sim = torch.einsum("b h i d, b h j d -> b h i j", q, k) * scale
+
+ # attention
+ attn = sim.softmax(dim=-1)
+ attn = self.attn_dropout(attn)
+
+ # aggregate values
+        out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)
+
+ return out
\ No newline at end of file
diff --git a/mvsepless/namer.py b/mvsepless/namer.py
new file mode 100644
index 0000000000000000000000000000000000000000..08ba77f0f2621122c8a6c5b6e69e7148c3f1c495
--- /dev/null
+++ b/mvsepless/namer.py
@@ -0,0 +1,112 @@
+import os
+import re
+
+class Namer:
+    def __init__(self, max_length: int = 255, offset: int = 10):
+        # Не допускаем слишком маленький лимит и оставляем запас под суффиксы вида " (1)"
+        self.max_length = max(max_length, 40)
+        if offset < self.max_length:
+            self.safe_max_length = self.max_length - offset
+        else:
+            self.safe_max_length = self.max_length
+
+ def sanitize(self, name: str):
+ """
+ Очищает имя файла от запрещенных символов
+ """
+ sanitized = re.sub(r'[<>:"/\\|?*]', '_', name)
+ sanitized = re.sub(r'_+', '_', sanitized)
+ sanitized = sanitized.strip('_. ')
+ return sanitized
+
+ def short(self, name: str, length: int = None):
+ """
+ Укорачивание имени файла до безопасной максимальной длины
+ """
+ if length:
+ if len(name) > length:
+ return f"{name[:int(length // 2)]}...{name[-int(length // 2.5):]}"
+ else:
+ return name
+ else:
+ if len(name) > self.safe_max_length:
+ return f"{name[:int(self.safe_max_length // 4)]}...{name[-int(self.safe_max_length // 4):]}"
+ else:
+ return name
+
+ def iter(self, filepath: str):
+ """
+ Позволяет избежать перезаписи
+ """
+ if not os.path.exists(filepath):
+ return filepath
+
+ directory, filename = os.path.split(filepath)
+ name, ext = os.path.splitext(filename)
+
+ counter = 1
+ while True:
+ new_filename = f"{name} ({counter}){ext}"
+ new_filepath = os.path.join(directory, new_filename)
+ if not os.path.exists(new_filepath):
+ return new_filepath
+ counter += 1
+
+ def template(self, template: str, **kwargs):
+ """
+        Заменяет ключи на значения (ключи указываются в именованных аргументах типа KEY="value")
+ """
+ if kwargs:
+ for key in kwargs:
+ template = template.replace(str(key), str(kwargs[key]))
+ return template
+
+ def dedup_template(self, template: str, keys: list = []):
+ """
+ Убирает дубликаты ключей в шаблоне
+ """
+ seen = set()
+ pattern = r"({})".format("|".join(re.escape(key) for key in keys))
+
+ def replace(match):
+ key = match.group(1)
+ if key in seen:
+ return ""
+ seen.add(key)
+ return key
+
+ result = re.sub(pattern, replace, template)
+ return result
+
+ def short_input_name_template(self, template: str, **kwargs):
+ """
+ Укорачивает имя входного файла для шаблона
+ """
+ if kwargs:
+ input_file_name = kwargs.get("NAME", None)
+ if input_file_name:
+ merged_keys_value = ""
+ no_keys_template = template
+ for key in kwargs:
+ if key != "NAME":
+ merged_keys_value += str(kwargs[key])
+ for key in kwargs:
+ no_keys_template = no_keys_template.replace(str(key), "")
+ len_merged_keys = len(merged_keys_value)
+ len_no_keys = len(no_keys_template)
+ free_length = self.safe_max_length - (len_merged_keys + len_no_keys)
+ len_file_name = len(input_file_name)
+ start_index = free_length // 2
+ end_index = free_length // 2.5
+ if len_file_name > free_length:
+ return f"{input_file_name[:int(start_index)]}...{input_file_name[-int(end_index):]}"
+ else:
+ return input_file_name
+ else:
+ print("Не был введен ключ NAME")
+ return ""
+ else:
+ print("Сначала введите ключи")
+ return ""
diff --git a/mvsepless/plugins/chainless.py b/mvsepless/plugins/chainless.py
new file mode 100644
index 0000000000000000000000000000000000000000..1df3eb9b76c37d1465b83dfb786a45f30ae91851
--- /dev/null
+++ b/mvsepless/plugins/chainless.py
@@ -0,0 +1,443 @@
+import gradio as gr
+import os, sys, subprocess
+import pandas as pd
+from datetime import datetime
+import tempfile
+import json
+if not __package__:
+ from __init__ import Separator
+ from downloader import dw_yt_dlp
+else:
+ from .. import Separator
+ from ..downloader import dw_yt_dlp
+
+class Plugin(Separator):
+ def __init__(self):
+ self.name = "Авто-цепочка разделений"
+ self.requirements = []
+ self.install_requirements(self.requirements)
+
+ def install_requirements(self, requirements: list):
+        if requirements:
+            cmd = [sys.executable, "-m", "pip", "install", *requirements]
+            subprocess.run(cmd, text=True, capture_output=True)
+
+
+ class ModelManager(Separator):
+ def __init__(self):
+ self.data = []
+            self.dir_presets = os.path.join(tempfile.gettempdir(), "presets")
+ os.makedirs(self.dir_presets, exist_ok=True)
+
+ def save(self, name):
+ if not name:
+ name = "chainless_preset"
+ filepath = os.path.join(self.dir_presets, f"{self.namer.short(self.namer.sanitize(name), length=50)}.json")
+ with open(filepath, "w") as f:
+ json.dump(self.data, f, indent=4, ensure_ascii=False)
+ return filepath
+
+ def load(self, filepath):
+ with open(filepath, "r") as f:
+ self.data = json.load(f)
+
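+        # Each chain entry is a tuple of
+        # (model_type, model_name, selected_stems, output_stems, intermediate_stem).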
+ def add(self, mt, mn, s_stem, out_stem, int_stem):
+ if int_stem:
+ self.data.append((mt, mn, s_stem, out_stem, int_stem))
+
+ def replace(self, mt, mn, s_stem, out_stem, int_stem, index=1):
+ if self.data:
+ len_data = len(self.data)
+ if index >= 1:
+ if index <= len_data:
+ self.data[index - 1] = (mt, mn, s_stem, out_stem, int_stem)
+ elif index == 0:
+ self.data[0] = (mt, mn, s_stem, out_stem, int_stem)
+
+ def remove(self, index=1):
+ if self.data:
+ len_data = len(self.data)
+ if index >= 1:
+ if index <= len_data:
+ del self.data[index - 1]
+ elif index == 0:
+ del self.data[0]
+
+ def clear(self):
+ self.data = []
+
+ def get_df(self):
+ if not self.data:
+ columns = ["#", "Имя модели", "Выбранные стемы", "Остаток", "Промежуточный стем"]
+ return pd.DataFrame(columns=columns)
+
+ data = []
+ for i, model in enumerate(self.data):
+ data.append(
+ [
+ f"{i+1}",
+ model[1],
+ str(model[2]),
+ str(model[3]),
+ model[4],
+ ]
+ )
+ columns = ["#", "Имя модели", "Выбранные стемы", "Остаток", "Промежуточный стем"]
+ return pd.DataFrame(data, columns=columns)
+
+ def UI(self):
+ def get_output_stems(mt, mn, s_stem):
+ output_stems = []
+ stems = self.model_manager.get_stems(mt, mn)
+ if not s_stem:
+ for stem in stems:
+ output_stems.append(stem)
+ if set(stems) == {"bass", "drums", "vocals", "other"} or set(stems) == {"bass", "drums", "vocals", "other", "guitar", "piano"}:
+ output_stems.append("instrumental +")
+ output_stems.append("instrumental -")
+ elif s_stem:
+ if len(stems) > 2:
+ for stem in stems:
+ if stem in s_stem:
+ output_stems.append(stem)
+ output_stems.append("inverted -")
+ if len(stems) - len(s_stem) > 1:
+ output_stems.append("inverted +")
+
+ return output_stems
+
+ def get_invert_output_stems(mt, mn, s_stem, out_stem):
+ output_stems = []
+ stems = get_output_stems(mt, mn, s_stem)
+ for stem in stems:
+ if stem not in out_stem:
+ output_stems.append(stem)
+ return output_stems
+
+ default_model = {
+ "mt": self.model_manager.get_mt(),
+ "mn": self.model_manager.get_mn(self.model_manager.get_mt()[0]),
+ "stems": self.model_manager.get_stems(
+ self.model_manager.get_mt()[0],
+ self.model_manager.get_mn(self.model_manager.get_mt()[0])[0],
+ ),
+ "output_stems": get_output_stems(
+ self.model_manager.get_mt()[0],
+ self.model_manager.get_mn(self.model_manager.get_mt()[0])[0],
+ []
+ ),
+ "int_stems": get_invert_output_stems(
+ self.model_manager.get_mt()[0],
+ self.model_manager.get_mn(self.model_manager.get_mt()[0])[0],
+ [],
+ "vocals",
+ ),
+ }
+
+ chain_manager = self.ModelManager()
+
+ gr.Markdown("Пресет
")
+ with gr.Group():
+ with gr.Row(equal_height=True):
+ export_preset_name = gr.Textbox(
+ label="Имя пресета",
+ interactive=True,
+ value="chainless_preset", scale=9
+ )
+ export_btn = gr.DownloadButton("Экспорт", variant="secondary", scale=3, interactive=True)
+ import_btn = gr.UploadButton(
+ "Импорт", file_types=[".json"], file_count="single", scale=3, interactive=True
+ )
+ gr.Markdown("Цепочка разделений
")
+ with gr.Row():
+ with gr.Column(scale=3): # логика добавления моделей
+ model_type = gr.Dropdown(label="Тип модели", choices=default_model["mt"], value=default_model["mt"][0], interactive=True, filterable=False)
+ model_name = gr.Dropdown(label="Имя модели", choices=default_model["mn"], value=default_model["mn"][0], interactive=True, filterable=False)
+ selected_stems = gr.Dropdown(label="Выбранные стемы", choices=default_model["stems"], value=[], multiselect=True, interactive=False, filterable=False)
+ output_stems = gr.Dropdown(label="Остаток", choices=default_model["output_stems"], value=[default_model["output_stems"][0]], multiselect=True, interactive=True, filterable=False)
+ intermediate_stem = gr.Dropdown(label="Промежуточный стем", choices=default_model["int_stems"], value=default_model["int_stems"][0], interactive=True, filterable=False)
+ @model_type.change(
+ inputs=[model_type],
+ outputs=[model_name]
+ )
+ def update_model_names(mt):
+ model_names = self.model_manager.get_mn(mt)
+ new_mn = model_names[0] if model_names else ""
+
+ return gr.update(choices=model_names, value=new_mn)
+ @model_name.change(
+ inputs=[model_type, model_name],
+ outputs=[selected_stems, output_stems, intermediate_stem]
+ )
+ def update_stems_after_model_change(mt, mn):
+ stems = self.model_manager.get_stems(mt, mn)
+ _output_stems = get_output_stems(mt, mn, [])
+ invert_stems = get_invert_output_stems(mt, mn, [], [_output_stems[0]])
+
+ new_out_stem = [_output_stems[0]]
+ new_inv_out_stem = invert_stems[0]
+
+ return (
+ gr.update(choices=stems, value=[], interactive=False if len(stems) <= 2 else True),
+ gr.update(choices=_output_stems, value=new_out_stem, max_choices=len(_output_stems) - 1),
+ gr.update(choices=invert_stems, value=new_inv_out_stem)
+ )
+
+ @selected_stems.change(
+ inputs=[model_type, model_name, selected_stems],
+ outputs=[output_stems, intermediate_stem]
+ )
+ def update_invert_stems(mt, mn, s_stem):
+ stems = get_output_stems(mt, mn, s_stem)
+ new_i_stem = [stems[0]]
+ invert_stems = get_invert_output_stems(mt, mn, s_stem, new_i_stem)
+ return gr.update(choices=stems, value=new_i_stem), gr.update(choices=invert_stems, value=invert_stems[0])
+
+ @output_stems.change(
+ inputs=[model_type, model_name, selected_stems, output_stems],
+ outputs=[intermediate_stem]
+ )
+ def update_invert2_stems(mt, mn, s_stem, out_stem):
+ invert_stems = get_invert_output_stems(mt, mn, s_stem, out_stem)
+ return gr.update(choices=invert_stems, value=invert_stems[0])
+
+ model_add_button = gr.Button("Добавить", interactive=True)
+
+
+ with gr.Column(scale=10):
+ df = gr.DataFrame(
+ value=chain_manager.get_df(),
+ headers=["#", "Имя модели", "Выбранные стемы", "Остаток", "Промежуточный стем"],
+ datatype=["number", "str", "str", "str", "str"],
+ interactive=False
+ )
+
+ with gr.Group():
+ with gr.Row(equal_height=True):
+ with gr.Column():
+ model_index = gr.Number(label="Индекс модели", value=1, interactive=True)
+ model_clear_btn = gr.Button("Очистить", variant="stop", interactive=True)
+ with gr.Column():
+ model_replace_btn = gr.Button("Заменить", variant="primary", interactive=True)
+ model_delete_btn = gr.Button("Удалить", variant="stop", interactive=True)
+
+ @model_add_button.click(
+ inputs=[model_type, model_name, selected_stems, output_stems, intermediate_stem],
+ outputs=df
+ )
+ def add_model_to_auto_ensemble(mt, mn, s_stem, out_stem, int_stem):
+ chain_manager.add(mt, mn, s_stem, out_stem, int_stem)
+ return chain_manager.get_df()
+
+ @model_replace_btn.click(
+ inputs=[model_type, model_name, selected_stems, output_stems, intermediate_stem, model_index],
+ outputs=df
+ )
+ def replace_model_to_auto_ensemble(mt, mn, s_stem, out_stem, int_stem, index):
+ chain_manager.replace(mt, mn, s_stem, out_stem, int_stem, index)
+ return chain_manager.get_df()
+
+ @model_delete_btn.click(
+ inputs=[model_index],
+ outputs=df
+ )
+ def delete_model_to_auto_ensemble(index):
+ chain_manager.remove(index)
+ return chain_manager.get_df()
+
+ @model_clear_btn.click(
+ outputs=df
+ )
+ def clear_model_to_auto_ensemble():
+ chain_manager.clear()
+ return chain_manager.get_df()
+
+ gr.on(fn=chain_manager.get_df, outputs=df)
+
+ df.change(
+ fn=chain_manager.save,
+ inputs=export_preset_name,
+ outputs=export_btn
+ )
+
+ export_preset_name.change(
+ fn=chain_manager.save,
+ inputs=export_preset_name,
+ outputs=export_btn
+ )
+
+ @import_btn.upload(
+ inputs=import_btn,
+ outputs=df
+ )
+ def load_ensemble_preset(filepath):
+ chain_manager.load(filepath)
+ return chain_manager.get_df()
+
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown("Входное аудио
")
+ with gr.Group():
+ with gr.Group(visible=False) as add_inputs:
+ input_path = gr.Textbox(label="Путь к входному файлу", interactive=True)
+ add_inputs_btn = gr.Button("Загрузить файл", variant="primary")
+ with gr.Group(visible=False) as add_inputs_from_url:
+ input_url = gr.Textbox(label="URL входного файла", interactive=True)
+ with gr.Row():
+ inputs_url_format = gr.Dropdown(label="Формат входного файла", interactive=True,
+ choices=self.audio.output_formats,
+ value="mp3", filterable=False)
+ inputs_url_bitrate = gr.Slider(label="Битрейт входного файла", minimum=64, maximum=512, step=32, value=320, interactive=True)
+ with gr.Row():
+ inputs_url_cookie = gr.UploadButton(label="Файл cookie (необязательно)", interactive=True, type="filepath", file_count="single", file_types=[".txt", ".cookies"], variant="secondary")
+ add_inputs_url_btn = gr.Button("Загрузить файл", variant="primary")
+ with gr.Row(visible=True) as add_buttons_row:
+ add_path_btn = gr.Button("Загрузить файл по пути", variant="secondary")
+ add_url_btn = gr.Button("Загрузить файл по URL", variant="secondary")
+ with gr.Group():
+ input_audio = gr.File(label="Входное аудио", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.audio.input_formats])
+ with gr.Column():
+ gr.Markdown("Настройки
")
+ with gr.Group():
+ save_only_last_intermediate_stem_check = gr.Checkbox(label="Сохранить только последний промежуточный стем", interactive=True, value=False)
+ output_format = gr.Dropdown(label="Формат выходного файла", interactive=True,
+ choices=self.audio.output_formats,
+ value="mp3", filterable=False)
+ run_btn = gr.Button("Создать цепочку разделений", variant="primary", interactive=True)
+
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown("Результаты
")
+ last_intermediate_stem = gr.Audio(label="Последний промежуточный стем", type="filepath", interactive=False, show_download_button=True)
+ with gr.Group():
+ invert_method = gr.Radio(
+ choices=["waveform", "spectrogram"],
+ label="Метод создания инверсии",
+ value="waveform",
+ )
+ invert_btn = gr.Button("Инвертировать")
+ output_inverted_audio = gr.Audio(label="Инверсия", type="filepath", interactive=False, show_download_button=True)
+ @invert_btn.click(inputs=[input_audio, last_intermediate_stem, invert_method, output_format], outputs=[output_inverted_audio])
+ def invert_result_ensemble(input_file, output_file, method, out_format):
+ if input_file and output_file:
+ o_dir = os.path.dirname(output_file)
+ basename = os.path.splitext(os.path.basename(input_file))[0]
+ output_path = os.path.join(o_dir, f"chainless_{self.namer.short(basename, length=50)}_{method}_invert.{out_format}")
+ inverted = self.inverter.process_audio(audio1_path=input_file, audio2_path=output_file, out_format=out_format, method=method, output_path=output_path)
+ return inverted
+ else:
+ return None
+
+ with gr.Column():
+ gr.Markdown("Исходники цепочки
")
+ output_source_files = gr.Files(type="filepath", interactive=False, show_label=False)
+ output_source_preview_check = gr.Checkbox(label="Показать плееры для исходников цепочки", interactive=True, value=False)
+ @gr.render(inputs=[output_source_preview_check, output_source_files])
+ def show_output_auto_ensemble_players(preview, audios):
+ if preview:
+ if audios:
+ with gr.Group():
+ for file in audios:
+ gr.Audio(label=os.path.splitext(os.path.basename(file))[0], value=file, interactive=False, show_download_button=False, type="filepath")
+
+
+ @run_btn.click(
+ inputs=[input_audio, output_format, save_only_last_intermediate_stem_check],
+ outputs=[last_intermediate_stem, output_source_files]
+ )
+ def chain(
+ input_audio,
+ out_format,
+ save_only_last_intermediate_stem,
+ progress=gr.Progress(track_tqdm=True)
+ ):
+
+ input_settings = chain_manager.data
+
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ o = tempfile.mkdtemp(prefix=f"chainless_outputs_{timestamp}_")
+ os.makedirs(o, exist_ok=True)
+
+ base_name = os.path.splitext(os.path.basename(input_audio))[0]
+ last_intermediate_stem = None
+ last_intermediate_stem_name = None
+ _output_stems = []
+
+ block_count = len(input_settings)
+
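+ # Цепочка выполняется последовательно: входом каждого шага служит промежуточный стем
+ # предыдущего шага (для первого шага — исходное аудио), выбранные стемы копятся в общий список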
+ for i, model in enumerate(input_settings, start=1):
+ input_model_type = model[0]
+ input_model_name = model[1]
+ selected_stems = model[2]
+ selected_output_stems = model[3]
+ intermediate_stem = model[4]
+ output_p = self.separate(
+ input=(
+ input_audio
+ if not last_intermediate_stem
+ else last_intermediate_stem
+ ),
+ output_dir=os.path.join(o, input_model_name),
+ model_type=input_model_type,
+ model_name=input_model_name,
+ ext_inst=True,
+ output_format=out_format,
+ template=f"NAME_{i}_{f'({i - 1}_{str(last_intermediate_stem_name)})_' if last_intermediate_stem_name else ''}MODEL_STEM",
+ selected_stems=selected_stems,
+ add_settings={"add_single_sep_text_progress": f"{i} из {block_count}"},
+ progress=progress
+ )
+ for stem, file in output_p:
+ if stem in selected_output_stems:
+ _output_stems.append(file)
+ elif stem == intermediate_stem:
+ last_intermediate_stem = file
+ last_intermediate_stem_name = stem
+ _output_stems.append(file)
+
+ return last_intermediate_stem, _output_stems if not save_only_last_intermediate_stem else []
+
+ @add_inputs_btn.click(
+ inputs=[input_path, input_audio],
+ outputs=[add_inputs, input_audio, add_buttons_row])
+ def add_inputs_fn(input_p, input_a):
+ if input_p and os.path.exists(input_p):
+ if input_a is None:
+ input_a = None
+ if self.audio.check(input_p):
+ input_a = input_p
+ return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+ return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+
+ @add_inputs_url_btn.click(
+ inputs=[input_url, input_audio, inputs_url_format, inputs_url_bitrate, inputs_url_cookie],
+ outputs=[add_inputs_from_url, input_audio, add_buttons_row])
+ def add_inputs_from_url_fn(input_u, input_a, fmt, br, cookie):
+ if input_u:
+ if input_a is None:
+ input_a = None
+ downloaded_file = dw_yt_dlp(
+ url=input_u,
+ output_format=fmt,
+ output_bitrate=str(int(br)),
+ cookie=cookie
+ )
+ if downloaded_file and os.path.exists(downloaded_file):
+ if self.audio.check(downloaded_file):
+ input_a = downloaded_file
+ return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+ return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
+
+ add_path_btn.click(
+ lambda: (gr.update(visible=True), gr.update(visible=False)),
+ outputs=[add_inputs, add_buttons_row])
+
+ add_url_btn.click(
+ lambda: (gr.update(visible=True), gr.update(visible=False)),
+ outputs=[add_inputs_from_url, add_buttons_row])
+
+ inputs_url_format.change(lambda x: gr.update(visible=False if x in ["wav", "flac", "aiff"] else True), inputs=inputs_url_format, outputs=inputs_url_bitrate)
+
+
diff --git a/mvsepless/plugins/matchering.py b/mvsepless/plugins/matchering.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a9d1cd4afd99236bbbfad7f2bab439b8b450088
--- /dev/null
+++ b/mvsepless/plugins/matchering.py
@@ -0,0 +1,93 @@
+import gradio as gr
+import os, sys, subprocess
+if not __package__:
+ from audio import Audio
+else:
+ from ..audio import Audio
+import tempfile
+
+class Plugin(Audio):
+ def __init__(self):
+ super().__init__()
+ self.name = "Matchering"
+ self.requirements = ["matchering"]
+ self.install_requirements(self.requirements)
+
+ def install_requirements(self, requirements: list):
+ if requirements:
+ cmd = [sys.executable, "-m", "pip", "install"]
+ for pkg in requirements:
+ cmd.append(pkg)
+ result = subprocess.run(cmd, text=True, capture_output=True)
+
+ def match(self, target_audio: str = None, reference_audio: str = None, output_path: str = "matchered.mp3", output_format: str = "mp3"):
+ from matchering.log import Code, ModuleError
+ from matchering import Config
+ from matchering.stages import main
+ from matchering.checker import check, check_equality
+ from matchering.dsp import channel_count, size
+ config = Config()
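+ # Читаем target и reference как стерео 44.1 кГц и прогоняем штатные проверки matchering (check / check_equality)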
+ target, target_sample_rate, _ = self.read(target_audio, mono=False, sr=44100)
+ target = target.T
+ target, target_sample_rate = check(target, target_sample_rate, config, "target")
+ reference, reference_sample_rate, _ = self.read(reference_audio, mono=False, sr=44100)
+ reference = reference.T
+ reference, reference_sample_rate = check(
+ reference, reference_sample_rate, config, "reference"
+ )
+
+ if not config.allow_equality:
+ check_equality(target, reference)
+
+ if (
+ not (target_sample_rate == reference_sample_rate == config.internal_sample_rate)
+ or not (channel_count(target) == channel_count(reference) == 2)
+ or not (size(target) > config.fft_size and size(reference) > config.fft_size)
+ ):
+ raise ModuleError(Code.ERROR_VALIDATION)
+
+ result, result_no_limiter, result_no_limiter_normalized = main(
+ target,
+ reference,
+ config,
+ need_default=True,
+ need_no_limiter=False,
+ need_no_limiter_normalized=False,
+ )
+
+ output_path = self.write(o=output_path, array=result, of=output_format, sr=config.internal_sample_rate)
+ return output_path
+
+ def UI(self):
+ with gr.Group():
+ with gr.Row():
+ target = gr.File(label="Целевое аудио", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.input_formats])
+ reference = gr.File(label="Референс", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.input_formats])
+ @gr.render(inputs=[target, reference])
+ def check_audio_files(tgt, rfr):
+ if tgt and rfr:
+ info_tgt = self.get_info(i=tgt)
+ info_rfr = self.get_info(i=rfr)
+ sr_tgt = info_tgt[0]["sample_rate"]
+ sr_rfr = info_rfr[0]["sample_rate"]
+ status = f"{sr_tgt} = {sr_rfr}" if sr_tgt == sr_rfr else f"{sr_tgt} != {sr_rfr}"
+ gr.Markdown(f"{status}
")
+
+ output_format = gr.Dropdown(label="Формат выходного файла", interactive=True,
+ choices=self.output_formats,
+ value="mp3", filterable=False)
+ matchered_audio = gr.Audio(label="Результат", type="filepath", show_download_button=True, interactive=False)
+ match_btn = gr.Button("Сопоставить")
+ @match_btn.click(
+ inputs=[target, reference, output_format],
+ outputs=[matchered_audio]
+ )
+ def match_audio(tgt, rfr, of):
+ if tgt and rfr:
+ o_dir = tempfile.mkdtemp(prefix="matchering")
+ basename = os.path.splitext(os.path.basename(tgt))[0]
+ output_path = os.path.join(o_dir, f"{self.short(basename, length=120)}_matchered.{of}")
+ matchered = self.match(target_audio=tgt, reference_audio=rfr, output_path=output_path, output_format=of)
+ return matchered
+ else:
+ return None
\ No newline at end of file
diff --git a/mvsepless/plugins/remove_center.py b/mvsepless/plugins/remove_center.py
new file mode 100644
index 0000000000000000000000000000000000000000..d48763f2bdec1ff6fd09b7157927a170c8baf2d3
--- /dev/null
+++ b/mvsepless/plugins/remove_center.py
@@ -0,0 +1,259 @@
+import gradio as gr
+import os, sys, subprocess
+import tempfile
+from scipy import signal
+import numpy as np
+from datetime import datetime
+if not __package__:
+ from audio import Audio
+else:
+ from ..audio import Audio
+
+class Plugin(Audio):
+ def __init__(self):
+ super().__init__()
+ self.name = "Вычитание фантомного центра"
+ self.requirements = []
+ self.install_requirements(self.requirements)
+ self.w_types = [
+ "boxcar", # Прямоугольное окно
+ "triang", # Треугольное окно
+ "blackman", # Окно Блэкмана
+ "hamming", # Окно Хэмминга
+ "hann", # Окно Ханна
+ "bartlett", # Окно Бартлетта
+ "flattop", # Окно с плоской вершиной
+ "parzen", # Окно Парзена
+ "bohman", # Окно Бохмана
+ "blackmanharris", # Окно Блэкмана-Харриса
+ "nuttall", # Окно Нуттала
+ "barthann", # Окно Бартлетта-Ханна
+ "cosine", # Косинусное окно
+ "exponential", # Экспоненциальное окно
+ "tukey", # Окно Туки
+ "taylor", # Окно Тейлора
+ "lanczos", # Окно Ланцоша
+ ]
+
+ def install_requirements(self, requirements: list):
+ if requirements:
+ cmd = [sys.executable, "-m", "pip", "install"]
+ for pkg in requirements:
+ cmd.append(pkg)
+ result = subprocess.run(cmd, text=True, capture_output=True)
+
+ def remove_center(
+ self,
+ input_file,
+ output_format="flac",
+ out_center="center.flac",
+ out_stereo_base="stereo_base.flac",
+ rdf=0.99999,
+ window_size=4096,
+ overlap=2,
+ window_type="blackman",
+ stereo_mode="stereo",
+ ):
+ output_file = out_stereo_base
+ output_center_file = out_center
+ data, samplerate, _ = self.read(i=input_file, mono=False, sr=None)
+
+ if data.ndim != 2 or data.shape[0] != 2:
+ raise ValueError("Требуется стереофайл (2 канала)")
+
+ left = data[0, :]
+ right = data[1, :]
+ mono = left * 0.5 + right * 0.5
+
+ nperseg = window_size # Размер окна
+ noverlap = nperseg // overlap # Перекрытие окон
+
+ f, t, Z_left = signal.stft(
+ left, fs=samplerate, nperseg=nperseg, noverlap=noverlap, window=window_type
+ )
+ f, t, Z_right = signal.stft(
+ right, fs=samplerate, nperseg=nperseg, noverlap=noverlap, window=window_type
+ )
+ f, t, Z_mono = signal.stft(
+ mono, fs=samplerate, nperseg=nperseg, noverlap=noverlap, window=window_type
+ )
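+ # Общий (центральный) сигнал оценивается как минимум модулей спектров L и R;
+ # фаза берётся из моно-суммы в режиме "mono" или из противоположного канала в режиме "stereo"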
+ if stereo_mode == "mono":
+ Z_common_left = np.minimum(np.abs(Z_left), np.abs(Z_right)) * np.exp(
+ 1j * np.angle(Z_mono)
+ )
+ Z_common_right = np.minimum(np.abs(Z_left), np.abs(Z_right)) * np.exp(
+ 1j * np.angle(Z_mono)
+ )
+ else:
+ Z_common_left = np.minimum(np.abs(Z_left), np.abs(Z_right)) * np.exp(
+ 1j * np.angle(Z_right)
+ )
+ Z_common_right = np.minimum(np.abs(Z_left), np.abs(Z_right)) * np.exp(
+ 1j * np.angle(Z_left)
+ )
+
+ reduction_factor = rdf
+
+ Z_new_left = Z_left - Z_common_left * reduction_factor
+ Z_new_right = Z_right - Z_common_right * reduction_factor
+
+ _, new_left = signal.istft(
+ Z_new_left,
+ fs=samplerate,
+ nperseg=nperseg,
+ noverlap=noverlap,
+ window=window_type,
+ )
+ _, new_right = signal.istft(
+ Z_new_right,
+ fs=samplerate,
+ nperseg=nperseg,
+ noverlap=noverlap,
+ window=window_type,
+ )
+
+ _, common_signal_left = signal.istft(
+ Z_common_left,
+ fs=samplerate,
+ nperseg=nperseg,
+ noverlap=noverlap,
+ window=window_type,
+ )
+ _, common_signal_right = signal.istft(
+ Z_common_right,
+ fs=samplerate,
+ nperseg=nperseg,
+ noverlap=noverlap,
+ window=window_type,
+ )
+
+ new_left = new_left[: len(left)]
+ new_right = new_right[: len(right)]
+ common_signal_left = common_signal_left[: len(left)]
+ common_signal_right = common_signal_right[: len(right)]
+
+ peak = np.max([np.abs(new_left).max(), np.abs(new_right).max()])
+ if peak > 1.0:
+ new_left = new_left / peak
+ new_right = new_right / peak
+
+ output_file = self.write(
+ o=output_file,
+ array=np.column_stack((new_left, new_right)),
+ sr=samplerate,
+ of=output_format,
+ br="320k",
+ )
+
+ inverted_center_left = -common_signal_left
+ inverted_center_right = -common_signal_right
+
+ mixed_left = left + inverted_center_left
+ mixed_right = right + inverted_center_right
+
+ peak_mixed = np.max([np.abs(mixed_left).max(), np.abs(mixed_right).max()])
+ if peak_mixed > 1.0:
+ mixed_left = mixed_left / peak_mixed
+ mixed_right = mixed_right / peak_mixed
+
+ output_center_file = self.write(
+ o=output_center_file,
+ array=np.column_stack((common_signal_left, common_signal_right)),
+ sr=samplerate,
+ of=output_format,
+ br="320k",
+ )
+
+ return (output_file, output_center_file)
+
+ def UI(self):
+ with gr.Row():
+ rmv_center_ui_input_audio = gr.File(label="Входное аудио", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.input_formats])
+ with gr.Group():
+ with gr.Row():
+ rmv_center_ui_reduction_f = gr.Slider(
+ 0.1,
+ 10,
+ value=1,
+ step=0.1,
+ label="Фактор подавления",
+ interactive=True,
+ visible=False,
+ )
+ rmv_center_ui_overlap = gr.Slider(
+ 2,
+ 30,
+ value=2,
+ step=1,
+ label="Перекрытие",
+ interactive=True,
+ visible=True,
+ )
+ rmv_center_ui_window_size = gr.Number(
+ label="Размер окна",
+ interactive=True,
+ visible=True,
+ minimum=32,
+ maximum=882000,
+ precision=1,
+ value=2048,
+ )
+ with gr.Row():
+ rmv_center_ui_format = gr.Dropdown(
+ self.output_formats,
+ value=self.output_formats[0],
+ filterable=False,
+ label="Формат выходного файла",
+ interactive=True
+ )
+ rmv_center_ui_window_types = gr.Dropdown(
+ self.w_types,
+ value=self.w_types[4],
+ filterable=False,
+ label="Тип окна",
+ interactive=True
+ )
+
+ rmv_center_ui_mono_mode = gr.Dropdown(
+ ["mono", "stereo"],
+ value="mono",
+ filterable=False,
+ label="Стерео-режим",
+ interactive=True
+ )
+ rmv_center_ui_extract_btn = gr.Button("Разделить")
+ with gr.Group():
+ with gr.Column():
+ with gr.Row():
+ rmv_center_ui_mid = gr.Audio(
+ type="filepath",
+ interactive=False,
+ label="Фантомный центр",
+ visible=True,
+ show_download_button=True,
+ )
+ rmv_center_ui_side = gr.Audio(
+ type="filepath",
+ interactive=False,
+ label="Стерео-база",
+ visible=True,
+ show_download_button=True,
+ )
+ @rmv_center_ui_extract_btn.click(
+ inputs=[
+ rmv_center_ui_input_audio,
+ rmv_center_ui_format,
+ rmv_center_ui_reduction_f,
+ rmv_center_ui_window_size,
+ rmv_center_ui_overlap,
+ rmv_center_ui_window_types,
+ rmv_center_ui_mono_mode,
+ ],
+ outputs=[rmv_center_ui_side, rmv_center_ui_mid],
+ )
+ def wrap_remove_center(input_audio, output_format, rf, ws, ovlp, wt, mono_mode):
+ if input_audio:
+ temp_dir = tempfile.mkdtemp(prefix="remove_center_")
+ basename = self.short(os.path.splitext(os.path.basename(input_audio))[0], length=80)
+ side, mid = self.remove_center(input_file=input_audio, output_format=output_format, out_center=os.path.join(temp_dir, f"{basename}_center.{output_format}"), out_stereo_base=os.path.join(temp_dir, f"{basename}_stereo_base.{output_format}"), overlap=ovlp, rdf=rf, stereo_mode=mono_mode, window_size=ws, window_type=wt)
+ return side, mid
\ No newline at end of file
diff --git a/mvsepless/plugins/test.py b/mvsepless/plugins/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d4e2de68a41308108f66458ce3e17896f62fc09
--- /dev/null
+++ b/mvsepless/plugins/test.py
@@ -0,0 +1,31 @@
+import gradio as gr
+import os, sys, subprocess
+if not __package__:
+ from __init__ import Separator
+else:
+ from .. import Separator
+
+class Plugin(Separator):
+ def __init__(self):
+ self.name = "Тестовый плагин"
+ self.requirements = []
+ self.install_requirements(self.requirements)
+
+ def install_requirements(self, requirements: list):
+ if requirements:
+ cmd = [sys.executable, "-m", "pip", "install"]
+ for pkg in requirements:
+ cmd.append(pkg)
+ result = subprocess.run(cmd, text=True, capture_output=True)
+
+ def test(self):
+ print("Тест")
+ print(self.model_manager.get_mt())
+
+ def UI(self):
+ with gr.Column():
+ gr.Markdown("Пример рабочего плагина
")
+ gr.Button("Показать все типы моделей", variant="primary").click(self.test)
+
+
+
\ No newline at end of file
diff --git a/mvsepless/plugins/whatbpm.py b/mvsepless/plugins/whatbpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9115af91c196209fc0ef11df2d892eb5237b8339
--- /dev/null
+++ b/mvsepless/plugins/whatbpm.py
@@ -0,0 +1,89 @@
+import gradio as gr
+import os, sys, subprocess
+import librosa
+if not __package__:
+ from audio import Audio
+else:
+ from ..audio import Audio
+
+class Plugin(Audio):
+ def __init__(self):
+ super().__init__()
+ self.name = "Узнать темп (через Librosa)"
+ self.requirements = []
+ self.install_requirements(self.requirements)
+
+ def install_requirements(self, requirements: list):
+ if requirements:
+ cmd = [sys.executable, "-m", "pip", "install"]
+ for pkg in requirements:
+ cmd.append(pkg)
+ result = subprocess.run(cmd, text=True, capture_output=True)
+
+ def get_tempo(self, input_audio) -> float:
+ y, sr, _ = self.read(i=input_audio, sr=44100, mono=True)
+ tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sr)
+ return float(tempo)
+
+ def UI(self):
+ with gr.Tab("Узнать темп из множества файлов"):
+ with gr.Row():
+ with gr.Column():
+ input_audio = gr.File(label="Аудио", interactive=True, type="filepath", file_count="multiple", file_types=[f".{of}" for of in self.input_formats])
+ input_preview_check = gr.Checkbox(label="Показать плееры для входных аудио", interactive=True, value=False)
+ @gr.render(inputs=[input_preview_check, input_audio])
+ def show_input_players(preview, audios):
+ if preview:
+ if audios:
+ with gr.Group():
+ for file in audios:
+ gr.Audio(label=os.path.splitext(os.path.basename(file))[0], value=file, interactive=False, show_download_button=False, type="filepath")
+
+ @gr.render(inputs=[input_audio])
+ def get_tempos_from_multiple_audios(audio_list):
+ status = f"Результаты анализа темпа\n---"
+ if audio_list:
+ for file in audio_list:
+ if self.check(file):
+ basename = os.path.splitext(os.path.basename(file))[0]
+ status += f"\n {self.short(basename, length=30)} - {self.get_tempo(file)}"
+ gr.Textbox(container=False, interactive=False, value=status, lines=len(status.split("\n")))
+
+ with gr.Tab("Сравнить темпы двух файлов"):
+ with gr.Row():
+ audio1 = gr.File(label="Аудио 1", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.input_formats])
+ audio2 = gr.File(label="Аудио 2", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.input_formats])
+ @gr.render(inputs=[audio1, audio2])
+ def compare_tempos(a1, a2):
+ if a1 and a2:
+ if self.check(a1) and self.check(a2):
+ tempo1 = self.get_tempo(a1)
+ tempo2 = self.get_tempo(a2)
+ with gr.Group():
+ with gr.Row(variant="compact"):
+ with gr.Column(scale=1, min_width=80):
+ gr.Markdown("Темп 1
")
+ t1 = gr.Number(container=False, value=tempo1, interactive=False)
+ with gr.Row():
+ t1_div = gr.Button("/2", interactive=True, variant="stop", min_width=20)
+ t1_multiplic = gr.Button("*2", interactive=True, min_width=20)
+ t1_div.click(lambda x: float(x) / 2, inputs=t1, outputs=t1)
+ t1_multiplic.click(lambda x: float(x) * 2, inputs=t1, outputs=t1)
+ with gr.Column(scale=1, min_width=80):
+ gr.Markdown("Темп 2
")
+ t2 = gr.Number(container=False, value=tempo2, interactive=False)
+ with gr.Row():
+ t2_div = gr.Button("/2", interactive=True, variant="stop", min_width=20)
+ t2_multiplic = gr.Button("*2", interactive=True, min_width=20)
+ t2_div.click(lambda x: x / 2, inputs=t2, outputs=t2)
+ t2_multiplic.click(lambda x: x * 2, inputs=t2, outputs=t2)
+ compare_result = gr.Textbox(container=False, interactive=False, lines=2)
+ compare_btn = gr.Button("Сравнить", variant="primary", interactive=True)
+ @compare_btn.click(
+ inputs=[t1, t2],
+ outputs=compare_result
+ )
+ def compare_tempo_fn(y1, y2):
+ if not y1 or not y2:
+ return "Темп не должен быть равен нулю"
+ result = f"""Темп 1 / Темп 2 = {y1 / y2}
+Темп 2 / Темп 1 = {y2 / y1}"""
+ return result
\ No newline at end of file
diff --git a/mvsepless/vbach_infer.py b/mvsepless/vbach_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..88cbdd38289ad02c845c704cf343d6170a68b21a
--- /dev/null
+++ b/mvsepless/vbach_infer.py
@@ -0,0 +1,1207 @@
+import os
+import gc
+import torch
+import torch.nn.functional as F
+import torchcrepe
+import faiss
+import librosa
+import numpy as np
+from scipy import signal
+import argparse
+script_dir = os.path.dirname(os.path.abspath(__file__))
+
+FILTER_ORDER = 5
+CUTOFF_FREQUENCY = 48
+SAMPLE_RATE = 16000
+bh, ah = signal.butter(N=FILTER_ORDER, Wn=CUTOFF_FREQUENCY, btype="high", fs=SAMPLE_RATE)
+
+from multiprocessing import cpu_count
+
+if not __package__:
+ from model_manager import VbachModelManager
+ from audio import Audio
+ from namer import Namer
+ from vbach_lib.fairseq import load_model_ensemble_and_task, load_checkpoint_to_cpu
+ from vbach_lib.algorithm.synthesizers import Synthesizer
+ from vbach_lib.predictors.FCPE import FCPEF0Predictor
+ from vbach_lib.predictors.RMVPE import RMVPE0Predictor
+else:
+ from .model_manager import VbachModelManager
+ from .audio import Audio
+ from .namer import Namer
+ from .vbach_lib.fairseq import load_model_ensemble_and_task, load_checkpoint_to_cpu
+ from .vbach_lib.algorithm.synthesizers import Synthesizer
+ from .vbach_lib.predictors.FCPE import FCPEF0Predictor
+ from .vbach_lib.predictors.RMVPE import RMVPE0Predictor
+
+
+model_manager = VbachModelManager()
+audio = Audio()
+namer = Namer()
+
+RMVPE_DIR = model_manager.rmvpe_path
+FCPE_DIR = model_manager.fcpe_path
+
+def remove_center(input_array, samplerate, rdf=0.99999, window_size=2048, overlap=2, window_type="blackman"):
+
+ left = input_array[0]
+ right = input_array[1]
+
+ nperseg = min(window_size, len(left))
+ if nperseg < 16:
+ nperseg = 16
+ if len(left) < 16:
+ import warnings
+ warnings.warn(f"Input too short ({len(left)} samples), returning original audio")
+ return left, right, left, right
+
+ noverlap = nperseg // overlap
+ if noverlap >= nperseg:
+ noverlap = nperseg - 1
+
+ f, t, Z_left = signal.stft(left, fs=samplerate, nperseg=nperseg, noverlap=noverlap, window=window_type)
+ f, t, Z_right = signal.stft(right, fs=samplerate, nperseg=nperseg, noverlap=noverlap, window=window_type)
+
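+ # Общий (центральный) сигнал: минимум модулей спектров каналов с фазой противоположного канала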
+ Z_common_left = np.minimum(np.abs(Z_left), np.abs(Z_right)) * np.exp(1j*np.angle(Z_right))
+ Z_common_right = np.minimum(np.abs(Z_left), np.abs(Z_right)) * np.exp(1j*np.angle(Z_left))
+
+ reduction_factor = rdf
+
+ Z_new_left = Z_left - Z_common_left * reduction_factor
+ Z_new_right = Z_right - Z_common_right * reduction_factor
+
+ _, new_left = signal.istft(Z_new_left, fs=samplerate, nperseg=nperseg, noverlap=noverlap, window=window_type)
+ _, new_right = signal.istft(Z_new_right, fs=samplerate, nperseg=nperseg, noverlap=noverlap, window=window_type)
+ _, common_signal_left = signal.istft(Z_common_left, fs=samplerate, nperseg=nperseg, noverlap=noverlap, window=window_type)
+ _, common_signal_right = signal.istft(Z_common_right, fs=samplerate, nperseg=nperseg, noverlap=noverlap, window=window_type)
+
+ new_left = new_left[:len(left)]
+ new_right = new_right[:len(right)]
+ common_signal_left = common_signal_left[:len(left)]
+ common_signal_right = common_signal_right[:len(right)]
+
+ peak = np.max([np.abs(new_left).max(), np.abs(new_right).max()])
+ if peak > 1.0:
+ new_left = new_left / peak
+ new_right = new_right / peak
+
+ inverted_center_left = -common_signal_left
+ inverted_center_right = -common_signal_right
+
+ mixed_left = left + inverted_center_left
+ mixed_right = right + inverted_center_right
+
+ peak_mixed = np.max([np.abs(mixed_left).max(), np.abs(mixed_right).max()])
+ if peak_mixed > 1.0:
+ mixed_left = mixed_left / peak_mixed
+ mixed_right = mixed_right / peak_mixed
+
+ return common_signal_left, common_signal_right, new_left, new_right
+
+class AudioProcessor:
+ @staticmethod
+ def change_rms(sourceaudio, source_rate, targetaudio, target_rate, rate):
+ """
+ Изменяет RMS (среднеквадратичное значение) аудио.
+ """
+ rms1 = librosa.feature.rms(
+ y=sourceaudio,
+ frame_length=source_rate // 2 * 2,
+ hop_length=source_rate // 2,
+ )
+ rms2 = librosa.feature.rms(
+ y=targetaudio,
+ frame_length=target_rate // 2 * 2,
+ hop_length=target_rate // 2,
+ )
+
+ rms1 = F.interpolate(
+ torch.from_numpy(rms1).float().unsqueeze(0),
+ size=targetaudio.shape[0],
+ mode="linear",
+ ).squeeze()
+ rms2 = F.interpolate(
+ torch.from_numpy(rms2).float().unsqueeze(0),
+ size=targetaudio.shape[0],
+ mode="linear",
+ ).squeeze()
+ rms2 = torch.maximum(rms2, torch.zeros_like(rms2) + 1e-6)
+
+ adjustedaudio = (
+ targetaudio * (torch.pow(rms1, 1 - rate) * torch.pow(rms2, rate - 1)).numpy()
+ )
+ return adjustedaudio
+
+
+# Класс для преобразования голоса
+class VC:
+ def __init__(self, tgt_sr, config):
+ """
+ Инициализация параметров для преобразования голоса.
+ """
+ self.x_pad = config.x_pad
+ self.x_query = config.x_query
+ self.x_center = config.x_center
+ self.x_max = config.x_max
+ self.is_half = config.is_half
+ self.sample_rate = 16000
+ self.window = 160
+ self.t_pad = self.sample_rate * self.x_pad
+ self.t_pad_tgt = tgt_sr * self.x_pad
+ self.t_pad2 = self.t_pad * 2
+ self.t_query = self.sample_rate * self.x_query
+ self.t_center = self.sample_rate * self.x_center
+ self.t_max = self.sample_rate * self.x_max
+ self.time_step = self.window / self.sample_rate * 1000
+ self.device = config.device
+
+ def get_f0_crepe(self, x, f0_min, f0_max, p_len, hop_length, model="full"):
+ """
+ Получает F0 с использованием модели crepe.
+ """
+ x = x.astype(np.float32)
+ x /= np.quantile(np.abs(x), 0.999)
+ audio = torch.from_numpy(x).to(self.device, copy=True).unsqueeze(0)
+ if audio.ndim == 2 and audio.shape[0] > 1:
+ audio = torch.mean(audio, dim=0, keepdim=True)
+
+ pitch = torchcrepe.predict(
+ audio,
+ self.sample_rate,
+ hop_length,
+ f0_min,
+ f0_max,
+ model,
+ batch_size=hop_length * 2,
+ device=self.device,
+ pad=True,
+ )
+
+ p_len = p_len or x.shape[0] // hop_length
+ source = np.array(pitch.squeeze(0).cpu().float().numpy())
+ source[source < 0.001] = np.nan
+ target = np.interp(
+ np.arange(0, len(source) * p_len, len(source)) / p_len,
+ np.arange(0, len(source)),
+ source,
+ )
+ f0 = np.nan_to_num(target)
+ return f0
+
+ def get_f0_rmvpe(self, x, f0_min=1, f0_max=40000, *args, **kwargs):
+ """
+ Получает F0 с использованием модели rmvpe.
+ """
+ if not hasattr(self, "model_rmvpe"):
+ self.model_rmvpe = RMVPE0Predictor(
+ RMVPE_DIR, is_half=self.is_half, device=self.device
+ )
+ f0 = self.model_rmvpe.infer_from_audio_with_pitch(
+ x, thred=0.03, f0_min=f0_min, f0_max=f0_max
+ )
+ return f0
+
+ def get_f0(
+ self,
+ inputaudio_path,
+ x,
+ p_len,
+ pitch,
+ f0_method,
+ filter_radius,
+ hop_length,
+ inp_f0=None,
+ f0_min=50,
+ f0_max=1100,
+ ):
+ """
+ Получает F0 с использованием выбранного метода.
+ """
+ global inputaudio_path2wav
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
+
+ if f0_method == "mangio-crepe":
+ f0 = self.get_f0_crepe(x, f0_min, f0_max, p_len, int(hop_length))
+
+ elif f0_method == "rmvpe+":
+ params = {
+ "x": x,
+ "p_len": p_len,
+ "pitch": pitch,
+ "f0_min": f0_min,
+ "f0_max": f0_max,
+ "time_step": self.time_step,
+ "filter_radius": filter_radius,
+ "crepe_hop_length": int(hop_length),
+ "model": "full",
+ }
+ f0 = self.get_f0_rmvpe(**params)
+
+ elif f0_method == "fcpe":
+ self.model_fcpe = FCPEF0Predictor(
+ FCPE_DIR,
+ f0_min=int(f0_min),
+ f0_max=int(f0_max),
+ dtype=torch.float32,
+ device=self.device,
+ sample_rate=self.sample_rate,
+ threshold=0.03,
+ )
+ f0 = self.model_fcpe.compute_f0(x, p_len=p_len)
+ del self.model_fcpe
+ gc.collect()
+
+ f0 *= pow(2, pitch / 12)
+ tf0 = self.sample_rate // self.window
+ if inp_f0 is not None:
+ delta_t = np.round(
+ (inp_f0[:, 0].max() - inp_f0[:, 0].min()) * tf0 + 1
+ ).astype("int16")
+ replace_f0 = np.interp(list(range(delta_t)), inp_f0[:, 0] * 100, inp_f0[:, 1])
+ shape = f0[self.x_pad * tf0 : self.x_pad * tf0 + len(replace_f0)].shape[0]
+ f0[self.x_pad * tf0 : self.x_pad * tf0 + len(replace_f0)] = replace_f0[:shape]
+
+ f0bak = f0.copy()
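+ # Квантование F0 в 255 «грубых» бинов по мел-шкале; исходная кривая сохраняется в f0bak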
+ f0_mel = 1127 * np.log(1 + f0 / 700)
+ f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (
+ f0_mel_max - f0_mel_min
+ ) + 1
+ f0_mel[f0_mel <= 1] = 1
+ f0_mel[f0_mel > 255] = 255
+ f0_coarse = np.rint(f0_mel).astype(int)
+ return f0_coarse, f0bak
+
+ def vc(
+ self,
+ model,
+ net_g,
+ sid,
+ audio0,
+ pitch,
+ pitchf,
+ index,
+ big_npy,
+ index_rate,
+ version,
+ protect,
+ ):
+ """
+ Преобразует аудио с использованием модели.
+ """
+ feats = torch.from_numpy(audio0)
+ feats = feats.half() if self.is_half else feats.float()
+ if feats.dim() == 2:
+ feats = feats.mean(-1)
+ assert feats.dim() == 1, feats.dim()
+ feats = feats.view(1, -1)
+ padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)
+
+ inputs = {
+ "source": feats.to(self.device),
+ "padding_mask": padding_mask,
+ "output_layer": 9 if version == "v1" else 12,
+ }
+
+ with torch.no_grad():
+ logits = model.extract_features(**inputs)
+ feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
+ if protect < 0.5 and pitch is not None and pitchf is not None:
+ feats0 = feats.clone()
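+ # Поиск ближайших векторов признаков в FAISS-индексе и смешивание с текущими признаками с весом index_rate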
+ if index is not None and big_npy is not None and index_rate != 0:
+ npy = feats[0].cpu().numpy()
+ npy = npy.astype("float32") if self.is_half else npy
+ score, ix = index.search(npy, k=8)
+ weight = np.square(1 / score)
+ weight /= weight.sum(axis=1, keepdims=True)
+ npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
+ npy = npy.astype("float16") if self.is_half else npy
+ feats = (
+ torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate
+ + (1 - index_rate) * feats
+ )
+
+ feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
+ if protect < 0.5 and pitch is not None and pitchf is not None:
+ feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(
+ 0, 2, 1
+ )
+ p_len = audio0.shape[0] // self.window
+ if feats.shape[1] < p_len:
+ p_len = feats.shape[1]
+ if pitch is not None and pitchf is not None:
+ pitch = pitch[:, :p_len]
+ pitchf = pitchf[:, :p_len]
+
+ if protect < 0.5 and pitch is not None and pitchf is not None:
+ pitchff = pitchf.clone()
+ pitchff[pitchf > 0] = 1
+ pitchff[pitchf < 1] = protect
+ pitchff = pitchff.unsqueeze(-1)
+ feats = feats * pitchff + feats0 * (1 - pitchff)
+ feats = feats.to(feats0.dtype)
+ p_len = torch.tensor([p_len], device=self.device).long()
+ with torch.no_grad():
+ if pitch is not None and pitchf is not None:
+ audio1 = (
+ (net_g.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0])
+ .data.cpu()
+ .float()
+ .numpy()
+ )
+ else:
+ audio1 = (
+ (net_g.infer(feats, p_len, sid)[0][0, 0]).data.cpu().float().numpy()
+ )
+ del feats, p_len, padding_mask
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ return audio1
+
+ def pipeline(
+ self,
+ model,
+ net_g,
+ sid,
+ audio,
+ inputaudio_path,
+ pitch,
+ f0_method,
+ file_index,
+ index_rate,
+ pitch_guidance,
+ filter_radius,
+ tgt_sr,
+ resample_sr,
+ volume_envelope,
+ version,
+ protect,
+ hop_length,
+ f0_file,
+ f0_min=50,
+ f0_max=1100,
+ ):
+ """
+ Основной конвейер для преобразования аудио.
+ """
+ if (
+ file_index is not None
+ and file_index != ""
+ and os.path.exists(file_index)
+ and index_rate != 0
+ ):
+ try:
+ index = faiss.read_index(file_index)
+ big_npy = index.reconstruct_n(0, index.ntotal)
+ except Exception as e:
+ print(f"Произошла ошибка при чтении индекса FAISS: {e}")
+ index = big_npy = None
+ else:
+ index = big_npy = None
+ audio = signal.filtfilt(bh, ah, audio)
+ audio_pad = np.pad(audio, (self.window // 2, self.window // 2), mode="reflect")
+ opt_ts = []
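+ # Длинные записи режутся на сегменты в точках минимальной энергии, чтобы обрабатывать их по частям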
+ if audio_pad.shape[0] > self.t_max:
+ audio_sum = np.zeros_like(audio)
+ for i in range(self.window):
+ audio_sum += audio_pad[i : i - self.window]
+ for t in range(self.t_center, audio.shape[0], self.t_center):
+ opt_ts.append(
+ t
+ - self.t_query
+ + np.where(
+ np.abs(audio_sum[t - self.t_query : t + self.t_query])
+ == np.abs(audio_sum[t - self.t_query : t + self.t_query]).min()
+ )[0][0]
+ )
+ s = 0
+ audio_opt = []
+ t = None
+ audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect")
+ p_len = audio_pad.shape[0] // self.window
+ inp_f0 = None
+ if f0_file and hasattr(f0_file, "name"):
+ try:
+ with open(f0_file.name, "r") as f:
+ lines = f.read().strip("\n").split("\n")
+ inp_f0 = np.array(
+ [[float(i) for i in line.split(",")] for line in lines],
+ dtype="float32",
+ )
+ except Exception as e:
+ print(f"Произошла ошибка при чтении файла F0: {e}")
+ sid = torch.tensor(sid, device=self.device).unsqueeze(0).long()
+ if pitch_guidance:
+ pitch, pitchf = self.get_f0(
+ inputaudio_path,
+ audio_pad,
+ p_len,
+ pitch,
+ f0_method,
+ filter_radius,
+ hop_length,
+ inp_f0,
+ f0_min,
+ f0_max,
+ )
+ pitch = pitch[:p_len]
+ pitchf = pitchf[:p_len]
+ if self.device == "mps":
+ pitchf = pitchf.astype(np.float32)
+ pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long()
+ pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float()
+ for t in opt_ts:
+ t = t // self.window * self.window
+ if pitch_guidance:
+ audio_opt.append(
+ self.vc(
+ model,
+ net_g,
+ sid,
+ audio_pad[s : t + self.t_pad2 + self.window],
+ pitch[:, s // self.window : (t + self.t_pad2) // self.window],
+ pitchf[:, s // self.window : (t + self.t_pad2) // self.window],
+ index,
+ big_npy,
+ index_rate,
+ version,
+ protect,
+ )[self.t_pad_tgt : -self.t_pad_tgt]
+ )
+ else:
+ audio_opt.append(
+ self.vc(
+ model,
+ net_g,
+ sid,
+ audio_pad[s : t + self.t_pad2 + self.window],
+ None,
+ None,
+ index,
+ big_npy,
+ index_rate,
+ version,
+ protect,
+ )[self.t_pad_tgt : -self.t_pad_tgt]
+ )
+ s = t
+ if pitch_guidance:
+ audio_opt.append(
+ self.vc(
+ model,
+ net_g,
+ sid,
+ audio_pad[t:],
+ pitch[:, t // self.window :] if t is not None else pitch,
+ pitchf[:, t // self.window :] if t is not None else pitchf,
+ index,
+ big_npy,
+ index_rate,
+ version,
+ protect,
+ )[self.t_pad_tgt : -self.t_pad_tgt]
+ )
+ else:
+ audio_opt.append(
+ self.vc(
+ model,
+ net_g,
+ sid,
+ audio_pad[t:],
+ None,
+ None,
+ index,
+ big_npy,
+ index_rate,
+ version,
+ protect,
+ )[self.t_pad_tgt : -self.t_pad_tgt]
+ )
+
+ audio_opt = np.concatenate(audio_opt)
+ if volume_envelope != 1:
+ audio_opt = AudioProcessor.change_rms(
+ audio, self.sample_rate, audio_opt, tgt_sr, volume_envelope
+ )
+ if resample_sr >= self.sample_rate and tgt_sr != resample_sr:
+ audio_opt = librosa.resample(audio_opt, orig_sr=tgt_sr, target_sr=resample_sr)
+
+ audio_max = np.abs(audio_opt).max() / 0.99
+ max_int16 = 32768
+ if audio_max > 1:
+ max_int16 /= audio_max
+ audio_opt = (audio_opt * max_int16).astype(np.int16)
+
+ del pitch, pitchf, sid
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return audio_opt
+
+def overlay_mono_on_stereo(monoaudio, stereoaudio, gain=0.5):
+ if monoaudio is None or stereoaudio is None:
+ raise ValueError("Input audio arrays cannot be None")
+
+ # Ensure float32 for processing
+ monoaudio = monoaudio.astype(np.float32)
+ stereoaudio = stereoaudio.astype(np.float32)
+
+ # Convert mono to stereo if needed
+ if monoaudio.ndim == 1:
+ monoaudio = np.vstack([monoaudio, monoaudio])
+ elif monoaudio.shape[0] == 1:
+ monoaudio = np.vstack([monoaudio[0], monoaudio[0]])
+
+ if monoaudio.shape[0] != 2 or stereoaudio.shape[0] != 2:
+ raise ValueError("Shapes must be (2, N)")
+
+ min_len = min(monoaudio.shape[1], stereoaudio.shape[1])
+ if min_len == 0:
+ raise ValueError("Audio arrays cannot be empty")
+
+ monoaudio = monoaudio[:, :min_len]
+ stereoaudio = stereoaudio[:, :min_len]
+
+ result = stereoaudio + monoaudio * gain
+
+ # Normalize to prevent clipping
+ max_amp = np.max(np.abs(result))
+ if max_amp > 0:
+ result /= max_amp
+
+ # Convert back to int16 for output (if needed)
+ result = (result * 32767).astype(np.int16)
+
+ return result
+
+def loadaudio(
+ file_path: str,
+ target_sr: int,
+ stereo_mode: str
+) -> np.ndarray:
+ """
+ Загружает аудиофайл с помощью librosa, обрабатывает и возвращает аудиосигнал
+
+ Параметры:
+ file_path: Путь к аудиофайлу
+ target_sr: Целевая частота дискретизации
+ stereo_mode: Режим каналов: "mono", "left/right" или "sim/dif"
+
+ Возвращает:
+ Кортеж (mid, left, right) из numpy-массивов; элементы, не используемые в выбранном режиме, равны None
+
+ Исключения:
+ RuntimeError: При ошибках загрузки или обработки аудио
+ """
+ try:
+ mid, left, right = None, None, None
+
+ if stereo_mode == "mono":
+ # Загрузка аудио с помощью librosa
+ midaudio, sr, _ = audio.read(
+ i=file_path,
+ sr=None,
+ mono=True
+ )
+ midaudio = librosa.resample(
+ midaudio, # Исправлено: было audio
+ orig_sr=sr,
+ target_sr=target_sr
+ )
+ mid = midaudio.flatten()
+
+ elif stereo_mode == "left/right" or stereo_mode == "sim/dif":
+ # Загрузка аудио с помощью librosa
+ stereoaudio, sr, _ = audio.read(
+ i=file_path,
+ sr=None,
+ mono=False
+ )
+
+ if stereo_mode == "left/right":
+ leftaudio = stereoaudio[0] # Исправлено: было [:, 0]
+ rightaudio = stereoaudio[1] # Исправлено: было [:, 1]
+ leftaudio = librosa.resample(
+ leftaudio,
+ orig_sr=sr,
+ target_sr=target_sr
+ )
+ rightaudio = librosa.resample(
+ rightaudio,
+ orig_sr=sr,
+ target_sr=target_sr
+ )
+
+ left = leftaudio.flatten()
+ right = rightaudio.flatten()
+
+ elif stereo_mode == "sim/dif":
+ mid_left, mid_right, dif_left, dif_right = remove_center(input_array=stereoaudio, samplerate=sr)
+ midaudio = (mid_left + mid_right) * 0.5
+
+ midaudio = librosa.resample(
+ midaudio,
+ orig_sr=sr,
+ target_sr=target_sr
+ )
+ dif_left = librosa.resample(
+ dif_left,
+ orig_sr=sr,
+ target_sr=target_sr
+ )
+ dif_right = librosa.resample(
+ dif_right,
+ orig_sr=sr,
+ target_sr=target_sr
+ )
+
+ mid = midaudio.flatten()
+ left = dif_left.flatten() # Исправлено: было leftaudio
+ right = dif_right.flatten() # Исправлено: было rightaudio
+
+ return mid, left, right
+
+ except Exception as e:
+ raise RuntimeError(f"Ошибка загрузки аудио '{file_path}': {str(e)}")
+
+class Config:
+ def __init__(self):
+ self.device = self.get_device()
+ self.is_half = self.device == "cpu"
+ self.n_cpu = cpu_count()
+ self.gpu_name = None
+ self.gpu_mem = None
+ self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
+
+ def get_device(self):
+ if torch.cuda.is_available():
+ return "cuda"
+ elif torch.backends.mps.is_available():
+ return "mps"
+ else:
+ return "cpu"
+
+ def device_config(self):
+ if torch.cuda.is_available():
+ print("Используется устройство CUDA")
+ self._configure_gpu()
+ elif torch.backends.mps.is_available():
+ print("Используется устройство MPS")
+ self.device = "mps"
+ else:
+ print("Используется CPU")
+ self.device = "cpu"
+ self.is_half = True
+
+ x_pad, x_query, x_center, x_max = (
+ (3, 10, 60, 65) if self.is_half else (1, 6, 38, 41)
+ )
+ if self.gpu_mem is not None and self.gpu_mem <= 4:
+ x_pad, x_query, x_center, x_max = (1, 5, 30, 32)
+
+ return x_pad, x_query, x_center, x_max
+
+ def _configure_gpu(self):
+ self.gpu_name = torch.cuda.get_device_name(self.device)
+ low_end_gpus = ["16", "P40", "P10", "1060", "1070", "1080"]
+ if (
+ any(gpu in self.gpu_name for gpu in low_end_gpus)
+ and "V100" not in self.gpu_name.upper()
+ ):
+ self.is_half = False
+ self.gpu_mem = int(
+ torch.cuda.get_device_properties(self.device).total_memory
+ / 1024
+ / 1024
+ / 1024
+ + 0.4
+ )
+
+# Загрузка модели Hubert
+def load_hubert(device, is_half, model_path):
+ models, saved_cfg, task = load_model_ensemble_and_task(
+ [model_path], suffix=""
+ )
+ hubert = models[0].to(device)
+ hubert = hubert.half() if is_half else hubert.float()
+ hubert.eval()
+ return hubert
+
+# Получение голосового преобразователя
+def get_vc(device, is_half, config, model_path):
+ cpt = torch.load(model_path, map_location="cpu", weights_only=False)
+ if "config" not in cpt or "weight" not in cpt:
+ raise ValueError(
+ f"Некорректный формат для {model_path}. "
+ "Используйте голосовую модель, обученную с использованием RVC v2."
+ )
+
+ tgt_sr = cpt["config"][-1]
+ cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
+ pitch_guidance = cpt.get("f0", 1)
+ version = cpt.get("version", "v1")
+ input_dim = 768 if version == "v2" else 256
+
+ net_g = Synthesizer(
+ *cpt["config"],
+ use_f0=pitch_guidance,
+ input_dim=input_dim,
+ is_half=is_half,
+ )
+
+ del net_g.enc_q
+ print(net_g.load_state_dict(cpt["weight"], strict=False))
+ net_g.eval().to(device)
+ net_g = net_g.half() if is_half else net_g.float()
+
+ vc = VC(tgt_sr, config)
+ return cpt, version, net_g, tgt_sr, vc
+
+def rvc_infer(
+ index_path,
+ index_rate,
+ input_path,
+ output_path,
+ pitch,
+ f0_method,
+ cpt,
+ version,
+ net_g,
+ filter_radius,
+ tgt_sr,
+ volume_envelope,
+ protect,
+ hop_length,
+ vc,
+ hubert_model,
+ f0_min=50,
+ f0_max=1100,
+ format_output="wav",
+ output_bitrate="320k",
+ stereo_mode="mono"
+) -> str:
+
+ mid, left, right = loadaudio(input_path, 16000, stereo_mode)
+ pitch_guidance = cpt.get("f0", 1)
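+ # Режимы каналов: "mono" — один проход по моно-сумме; "left/right" — отдельные проходы по каналам;
+ # "sim/dif" — отдельные проходы по фантомному центру и стереоразности с последующим сведением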
+
+ if stereo_mode == "mono":
+ if mid is None:
+ raise ValueError("Mono audio data is None")
+ audio_opt = vc.pipeline(
+ hubert_model,
+ net_g,
+ 0,
+ mid,
+ input_path,
+ pitch,
+ f0_method,
+ index_path,
+ index_rate,
+ pitch_guidance,
+ filter_radius,
+ tgt_sr,
+ 0,
+ volume_envelope,
+ version,
+ protect,
+ hop_length,
+ f0_file=None,
+ f0_min=f0_min,
+ f0_max=f0_max,
+ )
+
+ elif stereo_mode == "left/right":
+ if left is None or right is None:
+ raise ValueError("Left or right audio channel is None")
+
+ leftaudio_opt = vc.pipeline(
+ hubert_model,
+ net_g,
+ 0,
+ left,
+ input_path,
+ pitch,
+ f0_method,
+ index_path,
+ index_rate,
+ pitch_guidance,
+ filter_radius,
+ tgt_sr,
+ 0,
+ volume_envelope,
+ version,
+ protect,
+ hop_length,
+ f0_file=None,
+ f0_min=f0_min,
+ f0_max=f0_max,
+ )
+ rightaudio_opt = vc.pipeline(
+ hubert_model,
+ net_g,
+ 0,
+ right,
+ input_path,
+ pitch,
+ f0_method,
+ index_path,
+ index_rate,
+ pitch_guidance,
+ filter_radius,
+ tgt_sr,
+ 0,
+ volume_envelope,
+ version,
+ protect,
+ hop_length,
+ f0_file=None,
+ f0_min=f0_min,
+ f0_max=f0_max,
+ )
+
+ # Ensure both channels have the same length
+ min_len = min(len(leftaudio_opt), len(rightaudio_opt))
+ if min_len == 0:
+ raise ValueError("Processed audio is empty")
+
+ leftaudio_opt = leftaudio_opt[:min_len]
+ rightaudio_opt = rightaudio_opt[:min_len]
+
+ audio_opt = np.stack((leftaudio_opt, rightaudio_opt), axis=0)
+
+ elif stereo_mode == "sim/dif":
+ if mid is None or left is None or right is None:
+ raise ValueError("Mid, left or right audio channel is None")
+
+ midaudio_opt = vc.pipeline(
+ hubert_model,
+ net_g,
+ 0,
+ mid,
+ input_path,
+ pitch,
+ f0_method,
+ index_path,
+ index_rate,
+ pitch_guidance,
+ filter_radius,
+ tgt_sr,
+ 0,
+ volume_envelope,
+ version,
+ protect,
+ hop_length,
+ f0_file=None,
+ f0_min=f0_min,
+ f0_max=f0_max,
+ )
+ leftaudio_opt = vc.pipeline(
+ hubert_model,
+ net_g,
+ 0,
+ left,
+ input_path,
+ pitch,
+ f0_method,
+ index_path,
+ index_rate,
+ pitch_guidance,
+ filter_radius,
+ tgt_sr,
+ 0,
+ volume_envelope,
+ version,
+ protect,
+ hop_length,
+ f0_file=None,
+ f0_min=f0_min,
+ f0_max=f0_max,
+ )
+ rightaudio_opt = vc.pipeline(
+ hubert_model,
+ net_g,
+ 0,
+ right,
+ input_path,
+ pitch,
+ f0_method,
+ index_path,
+ index_rate,
+ pitch_guidance,
+ filter_radius,
+ tgt_sr,
+ 0,
+ volume_envelope,
+ version,
+ protect,
+ hop_length,
+ f0_file=None,
+ f0_min=f0_min,
+ f0_max=f0_max,
+ )
+
+ # Ensure all channels have the same length
+ min_len = min(len(midaudio_opt), len(leftaudio_opt), len(rightaudio_opt))
+ if min_len == 0:
+ raise ValueError("Processed audio is empty")
+
+ midaudio_opt = midaudio_opt[:min_len]
+ leftaudio_opt = leftaudio_opt[:min_len]
+ rightaudio_opt = rightaudio_opt[:min_len]
+
+ difaudio_opt = np.stack((leftaudio_opt, rightaudio_opt), axis=0)
+
+ audio_opt = overlay_mono_on_stereo(midaudio_opt, difaudio_opt)
+
+ output_path = audio.write(o=output_path, array=audio_opt, sr=tgt_sr, of=format_output, br=output_bitrate)
+ return output_path
+
+
+def load_rvc_model(voice_model):
+
+ if voice_model in model_manager.parse_voice_models():
+ rvc_model_path, rvc_index_path = model_manager.parse_pth_and_index(voice_model)
+
+ if not rvc_model_path:
+ raise ValueError(
+ f"[91mФайла для модели {voice_model} не существует. "
+ "Возможно, вы неправильно её установили.[0m"
+ )
+
+ else:
+ raise ValueError(
+ f"[91mМодели {voice_model} не существует. "
+ "Возможно, вы неправильно ввели имя.[0m"
+ )
+
+ return rvc_model_path, rvc_index_path
+
+def voice_conversion(
+ voice_model,
+ vocals_path,
+ output_path,
+ pitch,
+ f0_method,
+ index_rate,
+ filter_radius,
+ volume_envelope,
+ protect,
+ hop_length,
+ f0_min,
+ f0_max,
+ format_output,
+ output_bitrate,
+ stereo_mode,
+ hubert_path=None
+):
+ rvc_model_path, rvc_index_path = load_rvc_model(voice_model)
+
+ config = Config()
+ hubert_model = load_hubert(config.device, config.is_half, hubert_path if hubert_path else model_manager.hubert_path)
+ cpt, version, net_g, tgt_sr, vc = get_vc(
+ config.device, config.is_half, config, rvc_model_path
+ )
+
+ outputaudio = rvc_infer(
+ rvc_index_path,
+ index_rate,
+ vocals_path,
+ output_path,
+ pitch,
+ f0_method,
+ cpt,
+ version,
+ net_g,
+ filter_radius,
+ tgt_sr,
+ volume_envelope,
+ protect,
+ hop_length,
+ vc,
+ hubert_model,
+ f0_min,
+ f0_max,
+ format_output,
+ output_bitrate,
+ stereo_mode
+ )
+
+ del hubert_model, cpt, net_g, vc
+ gc.collect()
+ torch.cuda.empty_cache()
+ return outputaudio
+
+def vbach_inference(
+ input_file: str,
+ model_name: str,
+ output_dir: str,
+ output_name: str,
+ output_format: str,
+ output_bitrate: str | int,
+ pitch: int,
+ method_pitch: str,
+ format_name: bool = False,
+ add_params: dict = {
+ "index_rate": 0,
+ "filter_radius": 3,
+ "protect": 0.33,
+ "rms": 0.25,
+ "mangio_crepe_hop_length": 128,
+ "f0_min": 50,
+ "f0_max": 1100,
+ "stereo_mode": "mono"
+ }
+ ):
+ stereo_mode = add_params.get("stereo_mode", "mono")
+ index_rate = add_params.get("index_rate", 0)
+ filter_radius = add_params.get("filter_radius", 3)
+ protect = add_params.get("protect", 0.33)
+ rms = add_params.get("rms", 0.25)
+ mangio_crepe_hop_length = add_params.get("mangio_crepe_hop_length", 128)
+ f0_min = add_params.get("f0_min", 50)
+ f0_max = add_params.get("f0_max", 1100)
+ if not input_file:
+ raise ValueError("Входной файл не указан")
+ if not os.path.exists(input_file):
+ raise ValueError(f"Входной файл не найден: {input}")
+ if not audio.check(input_file):
+ raise ValueError("Входной файл не содержит аудио")
+ basename = os.path.splitext(os.path.basename(input_file))[0]
+
+ final_output_name = None
+
+ print("Инференс запущен")
+
+ if format_name:
+ cleaned_output_name_template = namer.sanitize(namer.dedup_template(output_name, keys=["NAME", "MODEL", "F0METHOD", "PITCH"]))
+ short_basename = namer.short_input_name_template(cleaned_output_name_template, MODEL=model_name, F0METHOD=method_pitch, PITCH=pitch, NAME=basename)
+ final_output_name = namer.template(cleaned_output_name_template, MODEL=model_name, F0METHOD=method_pitch, PITCH=pitch, NAME=short_basename)
+
+ else:
+ final_output_name = output_name
+
+ final_output_path = os.path.join(output_dir, f"{final_output_name}.{output_format}")
+ output_converted_voice = voice_conversion(voice_model=model_name, vocals_path=input_file, output_path=final_output_path, pitch=pitch, f0_method=method_pitch, index_rate=index_rate, filter_radius=filter_radius, volume_envelope=rms, protect=protect, hop_length=mangio_crepe_hop_length, f0_min=f0_min, f0_max=f0_max, format_output=output_format, hubert_path=None, output_bitrate=output_bitrate, stereo_mode=stereo_mode)
+ print(f"Инференс завершен\nПуть к выходному файлу: \"{output_converted_voice}\"")
+ return output_converted_voice
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Vbach - форк Polgen-RVC 1.2.0"
+ )
+ parser.add_argument("--input", type=str, help="Путь к входному файлу или папке")
+ parser.add_argument(
+ "--output_dir", type=str, required=True, help="Путь для сохранения результатов"
+ )
+ parser.add_argument(
+ "--output_format",
+ type=str,
+ default="wav",
+ choices=audio.output_formats,
+ help="Формат выходных файлов",
+ )
+ parser.add_argument(
+ "--output_bitrate", type=str, default="320k", help="Битрейт выходного файла"
+ )
+ parser.add_argument(
+ "--format_name",
+ action="store_true",
+ help="Форматировать имя выходного файла",
+ )
+ parser.add_argument(
+ "--output_name",
+ type=str,
+ default="NAME_STEM",
+ help="Имя выходного файла",
+ )
+ parser.add_argument(
+ "--model_name",
+ type=str,
+ default="model",
+ help="Имя голосовой модели",
+ )
+
+ parser.add_argument(
+ '--index_rate',
+ type=float,
+ default=0,
+ help='Интенсивность использования индексного файла (от 0.0 до 1.0)',
+ metavar='[0.0-1.0]'
+ )
+ parser.add_argument(
+ '--stereo_mode',
+ type=str,
+ default="mono",
+ choices=["mono", "left/right", "sim/dif"],
+ help='Режим каналов: моно или стерео'
+ )
+ parser.add_argument(
+ '--method_pitch',
+ type=str,
+ default="rmvpe+",
+ help='Метод извлечения pitch (тона)'
+ )
+ parser.add_argument(
+ '--pitch',
+ type=int,
+ default=0,
+ help='Корректировка тона в полутонах'
+ )
+ parser.add_argument(
+ '--hop_length',
+ type=int,
+ default=128,
+ help='Длина hop (в семплах) для обработки'
+ )
+ parser.add_argument(
+ '--filter_radius',
+ type=int,
+ default=3,
+ help='Радиус фильтра для сглаживания'
+ )
+ parser.add_argument(
+ '--rms',
+ type=float,
+ default=0.25,
+ help='Масштабирование огибающей громкости (RMS)'
+ )
+ parser.add_argument(
+ '--protect',
+ type=float,
+ default=0.33,
+ help='Защита для глухих согласных звуков'
+ )
+ parser.add_argument(
+ '--f0_min',
+ type=int,
+ default=50,
+ help='Минимальная частота pitch (F0) в Hz'
+ )
+ parser.add_argument(
+ '--f0_max',
+ type=int,
+ default=1100,
+ help='Максимальная частота pitch (F0) в Hz'
+ )
+
+ args = parser.parse_args()
+
+ if args.input:
+ if os.path.exists(args.input) and os.path.isfile(args.input):
+ if audio.check(args.input):
+                vbach_inference(
+                    input_file=args.input,
+                    model_name=args.model_name,
+                    output_dir=args.output_dir,
+                    output_name=args.output_name,
+                    output_bitrate=args.output_bitrate,
+                    output_format=args.output_format,
+                    pitch=args.pitch,
+                    method_pitch=args.method_pitch,
+                    format_name=args.format_name,
+                    add_params={
+                        "index_rate": args.index_rate,
+                        "filter_radius": args.filter_radius,
+                        "protect": args.protect,
+                        "rms": args.rms,
+                        "mangio_crepe_hop_length": args.hop_length,
+                        "f0_min": args.f0_min,
+                        "f0_max": args.f0_max,
+                        "stereo_mode": args.stereo_mode,
+                    },
+                )
+ elif os.path.exists(args.input) and os.path.isdir(args.input):
+ list_valid_files = []
+ for file in os.listdir(args.input):
+ if os.path.isfile(os.path.join(args.input, file)):
+ if audio.check(os.path.join(args.input, file)):
+ list_valid_files.append(os.path.join(args.input, file))
+ if list_valid_files:
+ for i, vocals_file in enumerate(list_valid_files, start=1):
+ print(f"Файл {i} из {len(list_valid_files)}: {vocals_file}")
+                    vbach_inference(
+                        input_file=vocals_file,
+                        model_name=args.model_name,
+                        output_dir=args.output_dir,
+                        output_name=args.output_name,
+                        output_bitrate=args.output_bitrate,
+                        output_format=args.output_format,
+                        pitch=args.pitch,
+                        method_pitch=args.method_pitch,
+                        format_name=True if len(list_valid_files) > 1 else args.format_name,
+                        add_params={
+                            "index_rate": args.index_rate,
+                            "filter_radius": args.filter_radius,
+                            "protect": args.protect,
+                            "rms": args.rms,
+                            "mangio_crepe_hop_length": args.hop_length,
+                            "f0_min": args.f0_min,
+                            "f0_max": args.f0_max,
+                            "stereo_mode": args.stereo_mode,
+                        },
+                    )
diff --git a/mvsepless/vbach_lib/algorithm/__init__.py b/mvsepless/vbach_lib/algorithm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0dac65653c4ac827e44864913497651c8434874
--- /dev/null
+++ b/mvsepless/vbach_lib/algorithm/__init__.py
@@ -0,0 +1,2 @@
+
+
\ No newline at end of file
diff --git a/mvsepless/vbach_lib/algorithm/attentions.py b/mvsepless/vbach_lib/algorithm/attentions.py
new file mode 100644
index 0000000000000000000000000000000000000000..2424e79fbba2a0b41f965f9bb4038ecda0fbc848
--- /dev/null
+++ b/mvsepless/vbach_lib/algorithm/attentions.py
@@ -0,0 +1,224 @@
+
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .commons import convert_pad_shape
+
+
+class MultiHeadAttention(nn.Module):
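+    """Multi-head attention with optional windowed relative positional
+    embeddings, proximal bias and block-local masking (VITS-style encoder)."""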
+ def __init__(
+ self,
+ channels,
+ out_channels,
+ n_heads,
+ p_dropout=0.0,
+ window_size=None,
+ heads_share=True,
+ block_length=None,
+ proximal_bias=False,
+ proximal_init=False,
+ ):
+ super().__init__()
+ assert channels % n_heads == 0
+
+ self.channels = channels
+ self.out_channels = out_channels
+ self.n_heads = n_heads
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+ self.heads_share = heads_share
+ self.block_length = block_length
+ self.proximal_bias = proximal_bias
+ self.proximal_init = proximal_init
+ self.attn = None
+
+ self.k_channels = channels // n_heads
+ self.conv_q = nn.Conv1d(channels, channels, 1)
+ self.conv_k = nn.Conv1d(channels, channels, 1)
+ self.conv_v = nn.Conv1d(channels, channels, 1)
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
+ self.drop = nn.Dropout(p_dropout)
+
+ if window_size is not None:
+ n_heads_rel = 1 if heads_share else n_heads
+ rel_stddev = self.k_channels**-0.5
+ self.emb_rel_k = nn.Parameter(
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+ * rel_stddev
+ )
+ self.emb_rel_v = nn.Parameter(
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+ * rel_stddev
+ )
+
+ nn.init.xavier_uniform_(self.conv_q.weight)
+ nn.init.xavier_uniform_(self.conv_k.weight)
+ nn.init.xavier_uniform_(self.conv_v.weight)
+ if proximal_init:
+ with torch.no_grad():
+ self.conv_k.weight.copy_(self.conv_q.weight)
+ self.conv_k.bias.copy_(self.conv_q.bias)
+
+ def forward(self, x, c, attn_mask=None):
+ q = self.conv_q(x)
+ k = self.conv_k(c)
+ v = self.conv_v(c)
+
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+ x = self.conv_o(x)
+ return x
+
+ def attention(self, query, key, value, mask=None):
+ b, d, t_s, t_t = (*key.size(), query.size(2))
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+ if self.window_size is not None:
+ assert t_s == t_t, "Relative attention is only available for self-attention."
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+ rel_logits = self._matmul_with_relative_keys(
+ query / math.sqrt(self.k_channels), key_relative_embeddings
+ )
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
+ scores = scores + scores_local
+ if self.proximal_bias:
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
+ scores = scores + self._attention_bias_proximal(t_s).to(
+ device=scores.device, dtype=scores.dtype
+ )
+ if mask is not None:
+ scores = scores.masked_fill(mask == 0, -1e4)
+ if self.block_length is not None:
+ assert t_s == t_t, "Local attention is only available for self-attention."
+ block_mask = (
+ torch.ones_like(scores)
+ .triu(-self.block_length)
+ .tril(self.block_length)
+ )
+ scores = scores.masked_fill(block_mask == 0, -1e4)
+ p_attn = F.softmax(scores, dim=-1)
+ p_attn = self.drop(p_attn)
+ output = torch.matmul(p_attn, value)
+ if self.window_size is not None:
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
+ value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
+ output = output + self._matmul_with_relative_values(
+ relative_weights, value_relative_embeddings
+ )
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t)
+ return output, p_attn
+
+ def _matmul_with_relative_values(self, x, y):
+ ret = torch.matmul(x, y.unsqueeze(0))
+ return ret
+
+ def _matmul_with_relative_keys(self, x, y):
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+ return ret
+
+ def _get_relative_embeddings(self, relative_embeddings, length):
+ pad_length = max(length - (self.window_size + 1), 0)
+ slice_start_position = max((self.window_size + 1) - length, 0)
+ slice_end_position = slice_start_position + 2 * length - 1
+ if pad_length > 0:
+ padded_relative_embeddings = F.pad(
+ relative_embeddings,
+ convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+ )
+ else:
+ padded_relative_embeddings = relative_embeddings
+ used_relative_embeddings = padded_relative_embeddings[
+ :, slice_start_position:slice_end_position
+ ]
+ return used_relative_embeddings
+
+ def _relative_position_to_absolute_position(self, x):
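+        # [b, h, l, 2*l-1] relative logits -> [b, h, l, l] absolute-position scores.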
+ batch, heads, length, _ = x.size()
+
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+ x_flat = x.view([batch, heads, length * 2 * length])
+ x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
+
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
+ :, :, :length, length - 1 :
+ ]
+ return x_final
+
+ def _absolute_position_to_relative_position(self, x):
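+        # [b, h, l, l] absolute weights -> [b, h, l, 2*l-1] relative weights.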
+ batch, heads, length, _ = x.size()
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
+ x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+ return x_final
+
+ def _attention_bias_proximal(self, length):
+ r = torch.arange(length, dtype=torch.float32)
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
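+    """Convolutional feed-forward block with same/causal padding; "gelu"
+    selects a sigmoid-based GELU approximation, otherwise ReLU is used."""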
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ filter_channels,
+ kernel_size,
+ p_dropout=0.0,
+ activation=None,
+ causal=False,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.activation = activation
+ self.causal = causal
+
+ if causal:
+ self.padding = self._causal_padding
+ else:
+ self.padding = self._same_padding
+
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
+ self.drop = nn.Dropout(p_dropout)
+
+ def forward(self, x, x_mask):
+ x = self.conv_1(self.padding(x * x_mask))
+ if self.activation == "gelu":
+ x = x * torch.sigmoid(1.702 * x)
+ else:
+ x = torch.relu(x)
+ x = self.drop(x)
+ x = self.conv_2(self.padding(x * x_mask))
+ return x * x_mask
+
+ def _causal_padding(self, x):
+ if self.kernel_size == 1:
+ return x
+ pad_l = self.kernel_size - 1
+ pad_r = 0
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+ x = F.pad(x, convert_pad_shape(padding))
+ return x
+
+ def _same_padding(self, x):
+ if self.kernel_size == 1:
+ return x
+ pad_l = (self.kernel_size - 1) // 2
+ pad_r = self.kernel_size // 2
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+ x = F.pad(x, convert_pad_shape(padding))
+ return x
+
+
\ No newline at end of file
diff --git a/mvsepless/vbach_lib/algorithm/commons.py b/mvsepless/vbach_lib/algorithm/commons.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ed8615d1ac464ba2d58f97e78f1af0cbe33965e
--- /dev/null
+++ b/mvsepless/vbach_lib/algorithm/commons.py
@@ -0,0 +1,114 @@
+
+import math
+import torch
+from torch.nn import functional as F
+from typing import List, Optional
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+def convert_pad_shape(pad_shape):
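+    # Reverse and flatten the per-dimension pads into the flat, last-dim-first
+    # list that F.pad expects.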
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def kl_divergence(m_p, logs_p, m_q, logs_q):
+ kl = (logs_q - logs_p) - 0.5
+ kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
+ return kl
+
+
+def slice_segments(
+ x: torch.Tensor, ids_str: torch.Tensor, segment_size: int = 4, dim: int = 2
+):
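+    # Cut a window of segment_size frames from each batch element, starting at ids_str[i].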
+ if dim == 2:
+ ret = torch.zeros_like(x[:, :segment_size])
+ elif dim == 3:
+ ret = torch.zeros_like(x[:, :, :segment_size])
+
+ for i in range(x.size(0)):
+ idx_str = ids_str[i].item()
+ idx_end = idx_str + segment_size
+ if dim == 2:
+ ret[i] = x[i, idx_str:idx_end]
+ else:
+ ret[i] = x[i, :, idx_str:idx_end]
+
+ return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
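+    # Draw a random valid start per batch element and slice segment_size frames.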
+ b, d, t = x.size()
+ if x_lengths is None:
+ x_lengths = t
+ ids_str_max = x_lengths - segment_size + 1
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+ ret = slice_segments(x, ids_str, segment_size, dim=3)
+ return ret, ids_str
+
+
+def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
+ position = torch.arange(length, dtype=torch.float)
+ num_timescales = channels // 2
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
+ num_timescales - 1
+ )
+ inv_timescales = min_timescale * torch.exp(
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
+ )
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
+ signal = signal.view(1, channels, length)
+ return signal
+
+
+def subsequent_mask(length):
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+ return mask
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
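+    # WaveNet gated activation: tanh over the first n_channels, sigmoid over the rest.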
+ n_channels_int = n_channels[0]
+ in_act = input_a + input_b
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+ return acts
+
+
+def sequence_mask(length: torch.Tensor, max_length: Optional[int] = None):
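+    # Boolean mask of shape [B, T_max], True for positions below each length.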
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def clip_grad_value(parameters, clip_value, norm_type=2):
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+    parameters = list(filter(lambda p: p.grad is not None, parameters))
+ norm_type = float(norm_type)
+ if clip_value is not None:
+ clip_value = float(clip_value)
+
+ total_norm = 0
+ for p in parameters:
+ param_norm = p.grad.data.norm(norm_type)
+ total_norm += param_norm.item() ** norm_type
+ if clip_value is not None:
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
+ total_norm = total_norm ** (1.0 / norm_type)
+ return total_norm
+
+
\ No newline at end of file
diff --git a/mvsepless/vbach_lib/algorithm/discriminators.py b/mvsepless/vbach_lib/algorithm/discriminators.py
new file mode 100644
index 0000000000000000000000000000000000000000..16714ef30199eb1b2bdadda63f8633e9145e84ce
--- /dev/null
+++ b/mvsepless/vbach_lib/algorithm/discriminators.py
@@ -0,0 +1,128 @@
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils.parametrizations import spectral_norm, weight_norm
+
+from .commons import get_padding
+from .residuals import LRELU_SLOPE
+
+
+PERIODS_V1 = [2, 3, 5, 7, 11, 17]
+PERIODS_V2 = [2, 3, 5, 7, 11, 17, 23, 37]
+IN_CHANNELS = [1, 32, 128, 512, 1024]
+OUT_CHANNELS = [32, 128, 512, 1024, 1024]
+
+
+class MultiPeriodDiscriminator(nn.Module):
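+    """HiFi-GAN-style ensemble: one scale discriminator plus one period
+    discriminator per value in PERIODS_V1."""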
+ def __init__(self, use_spectral_norm=False):
+ super(MultiPeriodDiscriminator, self).__init__()
+ self.discriminators = nn.ModuleList(
+ [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
+ + [DiscriminatorP(p, use_spectral_norm=use_spectral_norm) for p in PERIODS_V1]
+ )
+
+ def forward(self, y, y_hat):
+ y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
+ for d in self.discriminators:
+ y_d_r, fmap_r = d(y)
+ y_d_g, fmap_g = d(y_hat)
+ y_d_rs.append(y_d_r)
+ y_d_gs.append(y_d_g)
+ fmap_rs.append(fmap_r)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class MultiPeriodDiscriminatorV2(nn.Module):
+ def __init__(self, use_spectral_norm=False):
+ super(MultiPeriodDiscriminatorV2, self).__init__()
+ self.discriminators = nn.ModuleList(
+ [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
+ + [DiscriminatorP(p, use_spectral_norm=use_spectral_norm) for p in PERIODS_V2]
+ )
+
+ def forward(self, y, y_hat):
+ y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
+ for d in self.discriminators:
+ y_d_r, fmap_r = d(y)
+ y_d_g, fmap_g = d(y_hat)
+ y_d_rs.append(y_d_r)
+ y_d_gs.append(y_d_g)
+ fmap_rs.append(fmap_r)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorS(nn.Module):
+ def __init__(self, use_spectral_norm=False):
+ super(DiscriminatorS, self).__init__()
+ norm_f = spectral_norm if use_spectral_norm else weight_norm
+ self.convs = nn.ModuleList(
+ [
+ norm_f(nn.Conv1d(1, 16, 15, 1, padding=7)),
+ norm_f(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)),
+ norm_f(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)),
+ norm_f(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
+ norm_f(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+ norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
+ ]
+ )
+ self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
+ self.lrelu = nn.LeakyReLU(LRELU_SLOPE)
+
+ def forward(self, x):
+ fmap = []
+ for conv in self.convs:
+ x = self.lrelu(conv(x))
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+ return x, fmap
+
+
+class DiscriminatorP(nn.Module):
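+    """Period discriminator: reflect-pads the waveform to a multiple of the
+    period, folds it into a [T/period, period] grid and applies 2-D convolutions."""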
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+ super(DiscriminatorP, self).__init__()
+ self.period = period
+ norm_f = spectral_norm if use_spectral_norm else weight_norm
+
+ self.convs = nn.ModuleList(
+ [
+ norm_f(
+ nn.Conv2d(
+ in_ch,
+ out_ch,
+ (kernel_size, 1),
+ (stride, 1),
+ padding=(get_padding(kernel_size, 1), 0),
+ )
+ )
+ for in_ch, out_ch in zip(IN_CHANNELS, OUT_CHANNELS)
+ ]
+ )
+
+ self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+ self.lrelu = nn.LeakyReLU(LRELU_SLOPE)
+
+ def forward(self, x):
+ fmap = []
+ b, c, t = x.shape
+ if t % self.period != 0:
+ n_pad = self.period - (t % self.period)
+ x = F.pad(x, (0, n_pad), "reflect")
+ x = x.view(b, c, -1, self.period)
+
+ for conv in self.convs:
+ x = self.lrelu(conv(x))
+ fmap.append(x)
+
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+ return x, fmap
+
+
\ No newline at end of file
diff --git a/mvsepless/vbach_lib/algorithm/encoders.py b/mvsepless/vbach_lib/algorithm/encoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..61e273a6000aa15adf5890a6f2bdec2fd27fdfd5
--- /dev/null
+++ b/mvsepless/vbach_lib/algorithm/encoders.py
@@ -0,0 +1,183 @@
+
+import math
+import torch
+from torch import nn
+from torch.nn.utils.weight_norm import remove_weight_norm
+from typing import Optional
+
+from .attentions import FFN, MultiHeadAttention
+from .commons import sequence_mask
+from .modules import WaveNet
+from .normalization import LayerNorm
+
+
+class Encoder(nn.Module):
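+    """Stack of self-attention + FFN sub-layers with post-layer normalization."""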
+ def __init__(
+ self,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size=1,
+ p_dropout=0.0,
+ window_size=10,
+ **kwargs
+ ):
+ super().__init__()
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+
+ self.drop = nn.Dropout(p_dropout)
+ self.attn_layers = nn.ModuleList()
+ self.norm_layers_1 = nn.ModuleList()
+ self.ffn_layers = nn.ModuleList()
+ self.norm_layers_2 = nn.ModuleList()
+ for i in range(self.n_layers):
+ self.attn_layers.append(
+ MultiHeadAttention(
+ hidden_channels,
+ hidden_channels,
+ n_heads,
+ p_dropout=p_dropout,
+ window_size=window_size,
+ )
+ )
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
+ self.ffn_layers.append(
+ FFN(
+ hidden_channels,
+ hidden_channels,
+ filter_channels,
+ kernel_size,
+ p_dropout=p_dropout,
+ )
+ )
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+ def forward(self, x, x_mask):
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+ x = x * x_mask
+ for i in range(self.n_layers):
+ y = self.attn_layers[i](x, x, attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_1[i](x + y)
+
+ y = self.ffn_layers[i](x, x_mask)
+ y = self.drop(y)
+ x = self.norm_layers_2[i](x + y)
+ x = x * x_mask
+ return x
+
+
+class TextEncoder(nn.Module):
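+    """Encodes content embeddings (plus an optional quantized pitch embedding)
+    into the prior distribution parameters (m, logs)."""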
+ def __init__(
+ self,
+ out_channels,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout,
+ embedding_dim,
+ f0=True,
+ ):
+ super(TextEncoder, self).__init__()
+ self.out_channels = out_channels
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = float(p_dropout)
+ self.emb_phone = nn.Linear(embedding_dim, hidden_channels)
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
+ if f0:
+ self.emb_pitch = nn.Embedding(256, hidden_channels)
+ self.encoder = Encoder(
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ float(p_dropout),
+ )
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+ def forward(
+ self, phone: torch.Tensor, pitch: Optional[torch.Tensor], lengths: torch.Tensor
+ ):
+ if pitch is None:
+ x = self.emb_phone(phone)
+ else:
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
+ x = x * math.sqrt(self.hidden_channels)
+ x = self.lrelu(x)
+ x = torch.transpose(x, 1, -1)
+ x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype)
+ x = self.encoder(x * x_mask, x_mask)
+ stats = self.proj(x) * x_mask
+
+ m, logs = torch.split(stats, self.out_channels, dim=1)
+ return m, logs, x_mask
+
+
+class PosteriorEncoder(nn.Module):
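+    """Encodes the target spectrogram into the posterior latent z with a
+    WaveNet stack and reparameterized sampling."""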
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ gin_channels=0,
+ ):
+ super(PosteriorEncoder, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
+ self.enc = WaveNet(
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ gin_channels=gin_channels,
+ )
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+ def forward(
+ self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
+ ):
+ x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+ x = self.pre(x) * x_mask
+ x = self.enc(x, x_mask, g=g)
+ stats = self.proj(x) * x_mask
+ m, logs = torch.split(stats, self.out_channels, dim=1)
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
+ return z, m, logs, x_mask
+
+ def remove_weight_norm(self):
+ self.enc.remove_weight_norm()
+
+ def __prepare_scriptable__(self):
+ for hook in self.enc._forward_pre_hooks.values():
+ if (
+ hook.__module__ == "torch.nn.utils.parametrizations.weight_norm"
+ and hook.__class__.__name__ == "_WeightNorm"
+ ):
+ remove_weight_norm(self.enc)
+ return self
+
+
\ No newline at end of file
diff --git a/mvsepless/vbach_lib/algorithm/generators.py b/mvsepless/vbach_lib/algorithm/generators.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfedba9a1a9e5efa0fa85adab8e5899464333d84
--- /dev/null
+++ b/mvsepless/vbach_lib/algorithm/generators.py
@@ -0,0 +1,159 @@
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils.weight_norm import remove_weight_norm
+from torch.nn.utils.parametrizations import weight_norm
+from typing import Optional
+
+from .commons import init_weights
+from .residuals import LRELU_SLOPE, ResBlock1, ResBlock2
+
+
+class Generator(nn.Module):
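+    """HiFi-GAN-style decoder: transposed-convolution upsampling interleaved
+    with residual blocks, with optional global conditioning via gin_channels."""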
+ def __init__(
+ self,
+ initial_channel,
+ resblock,
+ resblock_kernel_sizes,
+ resblock_dilation_sizes,
+ upsample_rates,
+ upsample_initial_channel,
+ upsample_kernel_sizes,
+ gin_channels=0,
+ ):
+ super(Generator, self).__init__()
+ self.num_kernels = len(resblock_kernel_sizes)
+ self.num_upsamples = len(upsample_rates)
+ self.conv_pre = nn.Conv1d(
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
+ )
+ resblock = ResBlock1 if resblock == "1" else ResBlock2
+
+ self.ups_and_resblocks = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ self.ups_and_resblocks.append(
+ weight_norm(
+ nn.ConvTranspose1d(
+ upsample_initial_channel // (2**i),
+ upsample_initial_channel // (2 ** (i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ )
+ ch = upsample_initial_channel // (2 ** (i + 1))
+ for j, (k, d) in enumerate(
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
+ ):
+ self.ups_and_resblocks.append(resblock(ch, k, d))
+
+ self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+ self.ups_and_resblocks.apply(init_weights)
+
+ if gin_channels != 0:
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
+
+ def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None):
+ x = self.conv_pre(x)
+ if g is not None:
+ x = x + self.cond(g)
+
+ resblock_idx = 0
+ for _ in range(self.num_upsamples):
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ x = self.ups_and_resblocks[resblock_idx](x)
+ resblock_idx += 1
+ xs = 0
+ for _ in range(self.num_kernels):
+ xs += self.ups_and_resblocks[resblock_idx](x)
+ resblock_idx += 1
+ x = xs / self.num_kernels
+
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def __prepare_scriptable__(self):
+ for l in self.ups_and_resblocks:
+ for hook in l._forward_pre_hooks.values():
+ if (
+ hook.__module__ == "torch.nn.utils.parametrizations.weight_norm"
+ and hook.__class__.__name__ == "_WeightNorm"
+ ):
+ remove_weight_norm(l)
+ return self
+
+ def remove_weight_norm(self):
+ for l in self.ups_and_resblocks:
+ remove_weight_norm(l)
+
+
+class SineGen(nn.Module):
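+    """Generates harmonic sine waves plus noise from an F0 contour, together
+    with a voiced/unvoiced mask."""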
+ def __init__(
+ self,
+ samp_rate,
+ harmonic_num=0,
+ sine_amp=0.1,
+ noise_std=0.003,
+ voiced_threshold=0,
+ flag_for_pulse=False,
+ ):
+ super(SineGen, self).__init__()
+ self.sine_amp = sine_amp
+ self.noise_std = noise_std
+ self.harmonic_num = harmonic_num
+ self.dim = self.harmonic_num + 1
+ self.sample_rate = samp_rate
+ self.voiced_threshold = voiced_threshold
+
+ def _f02uv(self, f0):
+ uv = torch.ones_like(f0)
+ uv = uv * (f0 > self.voiced_threshold)
+ return uv
+
+ def forward(self, f0: torch.Tensor, upp: int):
+ with torch.no_grad():
+ f0 = f0[:, None].transpose(1, 2)
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
+ f0_buf[:, :, 0] = f0[:, :, 0]
+ f0_buf[:, :, 1:] = (
+ f0_buf[:, :, 0:1]
+ * torch.arange(2, self.harmonic_num + 2, device=f0.device)[None, None, :]
+ )
+ rad_values = (f0_buf / float(self.sample_rate)) % 1
+ rand_ini = torch.rand(f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device)
+ rand_ini[:, 0] = 0
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+ tmp_over_one = torch.cumsum(rad_values, 1)
+ tmp_over_one *= upp
+ tmp_over_one = F.interpolate(
+ tmp_over_one.transpose(2, 1),
+ scale_factor=float(upp),
+ mode="linear",
+ align_corners=True,
+ ).transpose(2, 1)
+ rad_values = F.interpolate(
+ rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest"
+ ).transpose(2, 1)
+ tmp_over_one %= 1
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
+ cumsum_shift = torch.zeros_like(rad_values)
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+ sine_waves = torch.sin(
+ torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi
+ )
+ sine_waves = sine_waves * self.sine_amp
+ uv = self._f02uv(f0)
+ uv = F.interpolate(
+ uv.transpose(2, 1), scale_factor=float(upp), mode="nearest"
+ ).transpose(2, 1)
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+ noise = noise_amp * torch.randn_like(sine_waves)
+ sine_waves = sine_waves * uv + noise
+ return sine_waves, uv, noise
+
+
\ No newline at end of file
diff --git a/mvsepless/vbach_lib/algorithm/modules.py b/mvsepless/vbach_lib/algorithm/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..4717fa4c0a97179cfcf8d4235811bed1650386eb
--- /dev/null
+++ b/mvsepless/vbach_lib/algorithm/modules.py
@@ -0,0 +1,95 @@
+
+import torch
+from torch import nn
+from torch.nn.utils.weight_norm import remove_weight_norm
+from torch.nn.utils.parametrizations import weight_norm
+
+from .commons import fused_add_tanh_sigmoid_multiply
+
+
+class WaveNet(nn.Module):
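+    """Non-causal WaveNet block: dilated convolutions with gated activations,
+    residual/skip connections and optional global conditioning."""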
+ def __init__(
+ self,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ gin_channels=0,
+ p_dropout=0,
+ ):
+ super(WaveNet, self).__init__()
+ assert kernel_size % 2 == 1
+ self.hidden_channels = hidden_channels
+ self.kernel_size = (kernel_size,)
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+ self.p_dropout = p_dropout
+
+ self.in_layers = nn.ModuleList()
+ self.res_skip_layers = nn.ModuleList()
+ self.drop = nn.Dropout(p_dropout)
+
+ if gin_channels != 0:
+ cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
+ self.cond_layer = weight_norm(cond_layer, name="weight")
+
+ dilations = [dilation_rate**i for i in range(n_layers)]
+ paddings = [(kernel_size * d - d) // 2 for d in dilations]
+
+ for i in range(n_layers):
+ in_layer = nn.Conv1d(
+ hidden_channels,
+ 2 * hidden_channels,
+ kernel_size,
+ dilation=dilations[i],
+ padding=paddings[i],
+ )
+ in_layer = weight_norm(in_layer, name="weight")
+ self.in_layers.append(in_layer)
+
+ res_skip_channels = (
+ hidden_channels if i == n_layers - 1 else 2 * hidden_channels
+ )
+
+ res_skip_layer = nn.Conv1d(hidden_channels, res_skip_channels, 1)
+ res_skip_layer = weight_norm(res_skip_layer, name="weight")
+ self.res_skip_layers.append(res_skip_layer)
+
+ def forward(self, x, x_mask, g=None, **kwargs):
+ output = torch.zeros_like(x)
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
+
+ if g is not None:
+ g = self.cond_layer(g)
+
+ for i in range(self.n_layers):
+ x_in = self.in_layers[i](x)
+ if g is not None:
+ cond_offset = i * 2 * self.hidden_channels
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
+ else:
+ g_l = torch.zeros_like(x_in)
+
+ acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
+
+ acts = self.drop(acts)
+
+ res_skip_acts = self.res_skip_layers[i](acts)
+ if i < self.n_layers - 1:
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
+ x = (x + res_acts) * x_mask
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
+ else:
+ output = output + res_skip_acts
+ return output * x_mask
+
+ def remove_weight_norm(self):
+ if self.gin_channels != 0:
+ remove_weight_norm(self.cond_layer)
+ for l in self.in_layers:
+ remove_weight_norm(l)
+ for l in self.res_skip_layers:
+ remove_weight_norm(l)
+
+
\ No newline at end of file
diff --git a/mvsepless/vbach_lib/algorithm/normalization.py b/mvsepless/vbach_lib/algorithm/normalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b22833a5a5d0b7b06b620c52f5ebf428d6560aa
--- /dev/null
+++ b/mvsepless/vbach_lib/algorithm/normalization.py
@@ -0,0 +1,18 @@
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class LayerNorm(nn.Module):
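+    """LayerNorm over the channel dimension for tensors shaped [B, C, T]."""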
+ def __init__(self, channels, eps=1e-5):
+ super().__init__()
+ self.eps = eps
+ self.gamma = nn.Parameter(torch.ones(channels))
+ self.beta = nn.Parameter(torch.zeros(channels))
+
+ def forward(self, x):
+ x = x.transpose(1, -1)
+ x = F.layer_norm(x, (x.size(-1),), self.gamma, self.beta, self.eps)
+ return x.transpose(1, -1)
+
\ No newline at end of file
diff --git a/mvsepless/vbach_lib/algorithm/nsf.py b/mvsepless/vbach_lib/algorithm/nsf.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a0bbdaff53e486465d494d5d42b932442f644e0
--- /dev/null
+++ b/mvsepless/vbach_lib/algorithm/nsf.py
@@ -0,0 +1,170 @@
+
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils.weight_norm import remove_weight_norm
+from torch.nn.utils.parametrizations import weight_norm
+from typing import Optional
+
+from .commons import init_weights
+from .generators import SineGen
+from .residuals import LRELU_SLOPE, ResBlock1, ResBlock2
+
+
+class SourceModuleHnNSF(nn.Module):
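+    """Harmonic source module: merges the SineGen harmonics into a single
+    excitation signal through a linear layer and tanh."""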
+ def __init__(
+ self,
+ sample_rate,
+ harmonic_num=0,
+ sine_amp=0.1,
+ add_noise_std=0.003,
+        voiced_threshold=0,
+ is_half=True,
+ ):
+ super(SourceModuleHnNSF, self).__init__()
+
+ self.sine_amp = sine_amp
+ self.noise_std = add_noise_std
+ self.is_half = is_half
+
+ self.l_sin_gen = SineGen(
+            sample_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshold
+ )
+ self.l_linear = nn.Linear(harmonic_num + 1, 1)
+ self.l_tanh = nn.Tanh()
+
+ def forward(self, x: torch.Tensor, upsample_factor: int = 1):
+ sine_wavs, uv, _ = self.l_sin_gen(x, upsample_factor)
+ sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype)
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+ return sine_merge, None, None
+
+
+class GeneratorNSF(nn.Module):
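+    """HiFi-GAN-style decoder with a neural source-filter excitation path: an
+    upsampled F0-driven harmonic source is injected at every upsampling stage."""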
+ def __init__(
+ self,
+ initial_channel,
+ resblock,
+ resblock_kernel_sizes,
+ resblock_dilation_sizes,
+ upsample_rates,
+ upsample_initial_channel,
+ upsample_kernel_sizes,
+ gin_channels,
+ sr,
+ is_half=False,
+ ):
+ super(GeneratorNSF, self).__init__()
+
+ self.num_kernels = len(resblock_kernel_sizes)
+ self.num_upsamples = len(upsample_rates)
+ self.f0_upsamp = nn.Upsample(scale_factor=math.prod(upsample_rates))
+ self.m_source = SourceModuleHnNSF(sample_rate=sr, harmonic_num=0, is_half=is_half)
+
+ self.conv_pre = nn.Conv1d(
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
+ )
+ resblock_cls = ResBlock1 if resblock == "1" else ResBlock2
+
+ self.ups = nn.ModuleList()
+ self.noise_convs = nn.ModuleList()
+
+ channels = [
+ upsample_initial_channel // (2 ** (i + 1)) for i in range(len(upsample_rates))
+ ]
+ stride_f0s = [
+ math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1
+ for i in range(len(upsample_rates))
+ ]
+
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ self.ups.append(
+ weight_norm(
+ nn.ConvTranspose1d(
+ upsample_initial_channel // (2**i),
+ channels[i],
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ )
+
+ self.noise_convs.append(
+ nn.Conv1d(
+ 1,
+ channels[i],
+ kernel_size=(stride_f0s[i] * 2 if stride_f0s[i] > 1 else 1),
+ stride=stride_f0s[i],
+ padding=(stride_f0s[i] // 2 if stride_f0s[i] > 1 else 0),
+ )
+ )
+
+ self.resblocks = nn.ModuleList(
+ [
+ resblock_cls(channels[i], k, d)
+ for i in range(len(self.ups))
+ for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes)
+ ]
+ )
+
+ self.conv_post = nn.Conv1d(channels[-1], 1, 7, 1, padding=3, bias=False)
+ self.ups.apply(init_weights)
+
+ if gin_channels != 0:
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
+
+ self.upp = math.prod(upsample_rates)
+ self.lrelu_slope = LRELU_SLOPE
+
+ def forward(self, x, f0, g: Optional[torch.Tensor] = None):
+ har_source, _, _ = self.m_source(f0, self.upp)
+ har_source = har_source.transpose(1, 2)
+ x = self.conv_pre(x)
+
+ if g is not None:
+ x = x + self.cond(g)
+
+ for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
+ x = F.leaky_relu(x, self.lrelu_slope)
+ x = ups(x)
+ x = x + noise_convs(har_source)
+
+ xs = sum(
+ [
+ resblock(x)
+ for j, resblock in enumerate(self.resblocks)
+ if j in range(i * self.num_kernels, (i + 1) * self.num_kernels)
+ ]
+ )
+ x = xs / self.num_kernels
+
+ x = F.leaky_relu(x)
+ x = torch.tanh(self.conv_post(x))
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+
+ def __prepare_scriptable__(self):
+ for l in self.ups:
+ for hook in l._forward_pre_hooks.values():
+ if (
+ hook.__module__ == "torch.nn.utils.parametrizations.weight_norm"
+ and hook.__class__.__name__ == "_WeightNorm"
+ ):
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ for hook in l._forward_pre_hooks.values():
+ if (
+ hook.__module__ == "torch.nn.utils.parametrizations.weight_norm"
+ and hook.__class__.__name__ == "_WeightNorm"
+ ):
+ remove_weight_norm(l)
+ return self
+
+
\ No newline at end of file
diff --git a/mvsepless/vbach_lib/algorithm/residuals.py b/mvsepless/vbach_lib/algorithm/residuals.py
new file mode 100644
index 0000000000000000000000000000000000000000..271a8c7b101fd49cd4bfd4092c647e50c8065e04
--- /dev/null
+++ b/mvsepless/vbach_lib/algorithm/residuals.py
@@ -0,0 +1,235 @@
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils.weight_norm import remove_weight_norm
+from torch.nn.utils.parametrizations import weight_norm
+from typing import Optional
+
+from .commons import get_padding, init_weights
+from .modules import WaveNet
+
+
+LRELU_SLOPE = 0.1
+
+
+def create_conv1d_layer(channels, kernel_size, dilation):
+ return weight_norm(
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation,
+ padding=get_padding(kernel_size, dilation),
+ )
+ )
+
+
+def apply_mask(tensor, mask):
+ return tensor * mask if mask is not None else tensor
+
+
+class ResBlockBase(nn.Module):
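+    """Residual block of paired dilated / dilation-1 convolutions with
+    leaky-ReLU activations (HiFi-GAN ResBlock)."""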
+ def __init__(self, channels, kernel_size, dilations):
+ super(ResBlockBase, self).__init__()
+ self.convs1 = nn.ModuleList(
+ [create_conv1d_layer(channels, kernel_size, d) for d in dilations]
+ )
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList(
+ [create_conv1d_layer(channels, kernel_size, 1) for _ in dilations]
+ )
+ self.convs2.apply(init_weights)
+
+ def forward(self, x, x_mask=None):
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = apply_mask(xt, x_mask)
+ xt = F.leaky_relu(c1(xt), LRELU_SLOPE)
+ xt = apply_mask(xt, x_mask)
+ xt = c2(xt)
+ x = xt + x
+ return apply_mask(x, x_mask)
+
+ def remove_weight_norm(self):
+ for conv in self.convs1 + self.convs2:
+ remove_weight_norm(conv)
+
+
+class ResBlock1(ResBlockBase):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super(ResBlock1, self).__init__(channels, kernel_size, dilation)
+
+
+class ResBlock2(ResBlockBase):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
+ super(ResBlock2, self).__init__(channels, kernel_size, dilation)
+
+
+class Log(nn.Module):
+ def forward(self, x, x_mask, reverse=False, **kwargs):
+ if not reverse:
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
+ logdet = torch.sum(-y, [1, 2])
+ return y, logdet
+ else:
+ x = torch.exp(x) * x_mask
+ return x
+
+
+class Flip(nn.Module):
+ def forward(self, x, *args, reverse=False, **kwargs):
+ x = torch.flip(x, [1])
+ if not reverse:
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
+ return x, logdet
+ else:
+ return x
+
+
+class ElementwiseAffine(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.channels = channels
+ self.m = nn.Parameter(torch.zeros(channels, 1))
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
+
+ def forward(self, x, x_mask, reverse=False, **kwargs):
+ if not reverse:
+ y = self.m + torch.exp(self.logs) * x
+ y = y * x_mask
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
+ return y, logdet
+ else:
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
+ return x
+
+
+class ResidualCouplingBlock(nn.Module):
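+    """Normalizing flow: n_flows residual coupling layers, each followed by a
+    channel Flip; invertible via reverse=True."""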
+ def __init__(
+ self,
+ channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ n_flows=4,
+ gin_channels=0,
+ ):
+ super(ResidualCouplingBlock, self).__init__()
+ self.channels = channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.n_flows = n_flows
+ self.gin_channels = gin_channels
+
+ self.flows = nn.ModuleList()
+ for i in range(n_flows):
+ self.flows.append(
+ ResidualCouplingLayer(
+ channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ gin_channels=gin_channels,
+ mean_only=True,
+ )
+ )
+ self.flows.append(Flip())
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ g: Optional[torch.Tensor] = None,
+ reverse: bool = False,
+ ):
+ if not reverse:
+ for flow in self.flows:
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
+ else:
+ for flow in reversed(self.flows):
+ x = flow.forward(x, x_mask, g=g, reverse=reverse)
+ return x
+
+ def remove_weight_norm(self):
+ for i in range(self.n_flows):
+ self.flows[i * 2].remove_weight_norm()
+
+ def __prepare_scriptable__(self):
+ for i in range(self.n_flows):
+ for hook in self.flows[i * 2]._forward_pre_hooks.values():
+ if (
+ hook.__module__ == "torch.nn.utils.parametrizations.weight_norm"
+ and hook.__class__.__name__ == "_WeightNorm"
+ ):
+ remove_weight_norm(self.flows[i * 2])
+
+ return self
+
+
+class ResidualCouplingLayer(nn.Module):
+ def __init__(
+ self,
+ channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ p_dropout=0,
+ gin_channels=0,
+ mean_only=False,
+ ):
+ assert channels % 2 == 0, "channels should be divisible by 2"
+ super().__init__()
+ self.channels = channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.half_channels = channels // 2
+ self.mean_only = mean_only
+
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+ self.enc = WaveNet(
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ p_dropout=p_dropout,
+ gin_channels=gin_channels,
+ )
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
+ self.post.weight.data.zero_()
+ self.post.bias.data.zero_()
+
+ def forward(self, x, x_mask, g=None, reverse=False):
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+ h = self.pre(x0) * x_mask
+ h = self.enc(h, x_mask, g=g)
+ stats = self.post(h) * x_mask
+ if not self.mean_only:
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
+ else:
+ m = stats
+ logs = torch.zeros_like(m)
+
+ if not reverse:
+ x1 = m + x1 * torch.exp(logs) * x_mask
+ x = torch.cat([x0, x1], 1)
+ logdet = torch.sum(logs, [1, 2])
+ return x, logdet
+ else:
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
+ x = torch.cat([x0, x1], 1)
+ return x
+
+ def remove_weight_norm(self):
+ self.enc.remove_weight_norm()
+
+
\ No newline at end of file
diff --git a/mvsepless/vbach_lib/algorithm/synthesizers.py b/mvsepless/vbach_lib/algorithm/synthesizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7981ed59cba25e14166e44800ba6f10804c82613
--- /dev/null
+++ b/mvsepless/vbach_lib/algorithm/synthesizers.py
@@ -0,0 +1,191 @@
+
+import torch
+from torch import nn
+from torch.nn.utils.weight_norm import remove_weight_norm
+from typing import Optional
+
+from .commons import slice_segments, rand_slice_segments
+from .encoders import TextEncoder, PosteriorEncoder
+from .generators import Generator
+from .nsf import GeneratorNSF
+from .residuals import ResidualCouplingBlock
+
+
+class Synthesizer(nn.Module):
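+    """VITS-style synthesizer used for RVC: content encoder (enc_p), posterior
+    encoder (enc_q), coupling flow and an (NSF-)HiFi-GAN decoder chosen by use_f0."""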
+ def __init__(
+ self,
+ spec_channels,
+ segment_size,
+ inter_channels,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout,
+ resblock,
+ resblock_kernel_sizes,
+ resblock_dilation_sizes,
+ upsample_rates,
+ upsample_initial_channel,
+ upsample_kernel_sizes,
+ spk_embed_dim,
+ gin_channels,
+ sr,
+ use_f0,
+ input_dim=768,
+ **kwargs
+ ):
+ super(Synthesizer, self).__init__()
+ self.spec_channels = spec_channels
+ self.inter_channels = inter_channels
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = float(p_dropout)
+ self.resblock = resblock
+ self.resblock_kernel_sizes = resblock_kernel_sizes
+ self.resblock_dilation_sizes = resblock_dilation_sizes
+ self.upsample_rates = upsample_rates
+ self.upsample_initial_channel = upsample_initial_channel
+ self.upsample_kernel_sizes = upsample_kernel_sizes
+ self.segment_size = segment_size
+ self.gin_channels = gin_channels
+ self.spk_embed_dim = spk_embed_dim
+ self.use_f0 = use_f0
+
+ self.enc_p = TextEncoder(
+ inter_channels,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ float(p_dropout),
+ input_dim,
+ f0=use_f0,
+ )
+
+ if use_f0:
+ self.dec = GeneratorNSF(
+ inter_channels,
+ resblock,
+ resblock_kernel_sizes,
+ resblock_dilation_sizes,
+ upsample_rates,
+ upsample_initial_channel,
+ upsample_kernel_sizes,
+ gin_channels=gin_channels,
+ sr=sr,
+ is_half=kwargs["is_half"],
+ )
+ else:
+ self.dec = Generator(
+ inter_channels,
+ resblock,
+ resblock_kernel_sizes,
+ resblock_dilation_sizes,
+ upsample_rates,
+ upsample_initial_channel,
+ upsample_kernel_sizes,
+ gin_channels=gin_channels,
+ )
+
+ self.enc_q = PosteriorEncoder(
+ spec_channels,
+ inter_channels,
+ hidden_channels,
+ 5,
+ 1,
+ 16,
+ gin_channels=gin_channels,
+ )
+ self.flow = ResidualCouplingBlock(
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
+ )
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
+
+ def remove_weight_norm(self):
+ self.dec.remove_weight_norm()
+ self.flow.remove_weight_norm()
+ self.enc_q.remove_weight_norm()
+
+ def __prepare_scriptable__(self):
+ for hook in self.dec._forward_pre_hooks.values():
+ if (
+ hook.__module__ == "torch.nn.utils.parametrizations.weight_norm"
+ and hook.__class__.__name__ == "_WeightNorm"
+ ):
+ remove_weight_norm(self.dec)
+ for hook in self.flow._forward_pre_hooks.values():
+ if (
+ hook.__module__ == "torch.nn.utils.parametrizations.weight_norm"
+ and hook.__class__.__name__ == "_WeightNorm"
+ ):
+ remove_weight_norm(self.flow)
+ if hasattr(self, "enc_q"):
+ for hook in self.enc_q._forward_pre_hooks.values():
+ if (
+ hook.__module__ == "torch.nn.utils.parametrizations.weight_norm"
+ and hook.__class__.__name__ == "_WeightNorm"
+ ):
+ remove_weight_norm(self.enc_q)
+ return self
+
+ @torch.jit.ignore
+ def forward(
+ self,
+ phone: torch.Tensor,
+ phone_lengths: torch.Tensor,
+ pitch: Optional[torch.Tensor] = None,
+ pitchf: Optional[torch.Tensor] = None,
+ y: torch.Tensor = None,
+ y_lengths: torch.Tensor = None,
+ ds: Optional[torch.Tensor] = None,
+ ):
+ g = self.emb_g(ds).unsqueeze(-1)
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
+ if y is not None:
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
+ z_p = self.flow(z, y_mask, g=g)
+ z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size)
+ if self.use_f0:
+ pitchf = slice_segments(pitchf, ids_slice, self.segment_size, 2)
+ o = self.dec(z_slice, pitchf, g=g)
+ else:
+ o = self.dec(z_slice, g=g)
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
+ else:
+ return None, None, x_mask, None, (None, None, m_p, logs_p, None, None)
+
+ @torch.jit.export
+ def infer(
+ self,
+ phone: torch.Tensor,
+ phone_lengths: torch.Tensor,
+ pitch: Optional[torch.Tensor] = None,
+ nsff0: Optional[torch.Tensor] = None,
+ sid: torch.Tensor = None,
+ rate: Optional[torch.Tensor] = None,
+ ):
+ g = self.emb_g(sid).unsqueeze(-1)
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
+ if rate is not None:
+ assert isinstance(rate, torch.Tensor)
+ head = int(z_p.shape[2] * (1.0 - rate.item()))
+ z_p = z_p[:, :, head:]
+ x_mask = x_mask[:, :, head:]
+ if self.use_f0:
+ nsff0 = nsff0[:, head:]
+ if self.use_f0:
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
+ o = self.dec(z * x_mask, nsff0, g=g)
+ else:
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
+ o = self.dec(z * x_mask, g=g)
+ return o, x_mask, (z, z_p, m_p, logs_p)
+
+
\ No newline at end of file
diff --git a/mvsepless/vbach_lib/fairseq.py b/mvsepless/vbach_lib/fairseq.py
new file mode 100644
index 0000000000000000000000000000000000000000..ada43e3318cb7c183bf4e9fb187cde42e22e5ab6
--- /dev/null
+++ b/mvsepless/vbach_lib/fairseq.py
@@ -0,0 +1,2030 @@
+import contextlib
+import math
+import re
+import sys
+import types
+import uuid
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from omegaconf import DictConfig, open_dict
+from torch import nn
+
+
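+# Minimal stand-in for the fairseq package: a stub Dictionary class and fake
+# "fairseq", "fairseq.data" and "fairseq.data.dictionary" modules are put into
+# sys.modules so torch.load() can unpickle HuBERT checkpoints that reference
+# fairseq without the real dependency being installed.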
+class Dictionary:
+ def __init__(self, *args, **kwargs):
+ pass
+
+
+fairseq = types.ModuleType("fairseq")
+fairseq_data = types.ModuleType("fairseq.data")
+fairseq_data_dictionary = types.ModuleType("fairseq.data.dictionary")
+fairseq_data_dictionary.Dictionary = Dictionary
+fairseq.data = fairseq_data
+fairseq_data.dictionary = fairseq_data_dictionary
+sys.modules["fairseq"] = fairseq
+sys.modules["fairseq.data"] = fairseq_data
+sys.modules["fairseq.data.dictionary"] = fairseq_data_dictionary
+
+
+def load_model(filename):
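+    # Rebuild the HuBERT model from the checkpoint's stored config; num_classes
+    # is inferred from the label embedding table, and strict=False tolerates
+    # state-dict keys that are not defined here.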
+ state = torch.load(filename, map_location="cpu", weights_only=False)
+
+ model = HubertModel(HubertConfig(**state["cfg"]["model"]), num_classes=int(state["model"]["label_embs_concat"].shape[0]))
+ model.load_state_dict(state["model"], strict=False)
+
+ return model
+
+
+def softmax(x, dim, onnx_trace=False):
+ return F.softmax(x.float(), dim=dim) if onnx_trace else F.softmax(x, dim=dim, dtype=torch.float32)
+
+
+def log_softmax(x, dim, onnx_trace=False):
+ return F.log_softmax(x.float(), dim=dim) if onnx_trace else F.log_softmax(x, dim=dim, dtype=torch.float32)
+
+
+def eval_str_dict(x, type=dict):
+ if x is None:
+ return None
+ if isinstance(x, str):
+ x = eval(x)
+ return x
+
+
+def with_incremental_state(cls):
+ cls.__bases__ = (FairseqIncrementalState,) + tuple(b for b in cls.__bases__ if b != FairseqIncrementalState)
+ return cls
+
+
+def quant_noise(module, p, block_size):
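+    # Structured weight dropout ("quant-noise"): during training a forward
+    # pre-hook randomly zeroes whole blocks of the weight matrix; the module is
+    # returned unchanged when p <= 0.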
+ if p <= 0:
+ return module
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
+ is_conv = module.weight.ndim == 4
+ if not is_conv:
+ assert module.weight.size(1) % block_size == 0
+ elif module.kernel_size == (1, 1):
+ assert module.in_channels % block_size == 0
+ else:
+ k = module.kernel_size[0] * module.kernel_size[1]
+ assert k % block_size == 0
+
+ def _forward_pre_hook(mod, input):
+ if mod.training:
+ if not is_conv:
+ weight = mod.weight
+ in_features = weight.size(1)
+ out_features = weight.size(0)
+ mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
+ mask.bernoulli_(p)
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
+ else:
+ weight = mod.weight
+ in_channels = mod.in_channels
+ out_channels = mod.out_channels
+
+ if mod.kernel_size == (1, 1):
+ mask = torch.zeros(int(in_channels // block_size * out_channels), device=weight.device)
+ mask.bernoulli_(p)
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
+ else:
+ mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
+ mask.bernoulli_(p)
+ mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
+
+ mask = mask.to(torch.bool)
+ s = 1 / (1 - p)
+ mod.weight.data = s * weight.masked_fill(mask, 0)
+
+ module.register_forward_pre_hook(_forward_pre_hook)
+ return module
+
+
+class FairseqDropout(nn.Module):
+ def __init__(self, p, module_name=None):
+ super().__init__()
+ self.p = p
+ self.module_name = module_name
+ self.apply_during_inference = False
+
+ def forward(self, x, inplace=False):
+ return (
+ F.dropout(x, p=self.p, training=True, inplace=inplace) if self.p > 0 and (self.training or self.apply_during_inference) else x
+ )
+
+ def make_generation_fast_(self, name, retain_dropout=False, retain_dropout_modules=None, **kwargs):
+ if retain_dropout:
+ if retain_dropout_modules is None or self.module_name in retain_dropout_modules:
+ self.apply_during_inference = True
+
+
+class FairseqIncrementalState:
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.init_incremental_state()
+
+ def init_incremental_state(self):
+ self._incremental_state_id = str(uuid.uuid4())
+
+ def _get_full_incremental_state_key(self, key):
+ return f"{self._incremental_state_id}.{key}"
+
+ def get_incremental_state(self, incremental_state, key):
+ full_key = self._get_full_incremental_state_key(key)
+ if incremental_state is None or full_key not in incremental_state:
+ return None
+ return incremental_state[full_key]
+
+ def set_incremental_state(self, incremental_state, key, value):
+ if incremental_state is not None:
+ incremental_state[self._get_full_incremental_state_key(key)] = value
+ return incremental_state
+
+
+class FairseqDecoder(nn.Module):
+ def __init__(self, dictionary):
+ super().__init__()
+ self.dictionary = dictionary
+ self.onnx_trace = False
+ self.adaptive_softmax = None
+
+ def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
+ x, extra = self.extract_features(prev_output_tokens, encoder_out=encoder_out, **kwargs)
+ return self.output_layer(x), extra
+
+ def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
+ pass
+
+ def output_layer(self, features, **kwargs):
+ pass
+
+ def get_normalized_probs(self, net_output, log_probs, sample=None):
+ return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
+
+ def get_normalized_probs_scriptable(self, net_output, log_probs, sample=None):
+ if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
+ if sample is not None:
+ assert "target" in sample
+ target = sample["target"]
+ else:
+ target = None
+ out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
+ return out.exp_() if not log_probs else out
+
+ logits = net_output[0]
+ return log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace) if log_probs else softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
+
+ def max_positions(self):
+ return 1e6
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ return state_dict
+
+ def prepare_for_onnx_export_(self):
+ self.onnx_trace = True
+
+
+@with_incremental_state
+class FairseqIncrementalDecoder(FairseqDecoder):
+ def __init__(self, dictionary):
+ super().__init__(dictionary)
+
+ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
+ pass
+
+ def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
+ pass
+
+ def reorder_incremental_state(self, incremental_state, new_order):
+ pass
+
+ def reorder_incremental_state_scripting(self, incremental_state, new_order):
+ for module in self.modules():
+ if hasattr(module, "reorder_incremental_state"):
+ result = module.reorder_incremental_state(incremental_state, new_order)
+ if result is not None:
+ incremental_state = result
+
+ def set_beam_size(self, beam_size):
+ if getattr(self, "_beam_size", -1) != beam_size:
+ seen = set()
+
+ def apply_set_beam_size(module):
+ if module != self and hasattr(module, "set_beam_size") and module not in seen:
+ seen.add(module)
+ module.set_beam_size(beam_size)
+
+ self.apply(apply_set_beam_size)
+ self._beam_size = beam_size
+
+
+class MultiheadAttention(FairseqIncrementalDecoder):
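+    """Multi-head attention in the fairseq style, with quant-noise on the
+    projections and incremental decoding state; the xformers path is not
+    supported here."""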
+ def __init__(
+ self,
+ embed_dim,
+ num_heads,
+ kdim=None,
+ vdim=None,
+ dropout=0.0,
+ bias=True,
+ add_bias_kv=False,
+ add_zero_attn=False,
+ self_attention=False,
+ encoder_decoder_attention=False,
+ dictionary=None,
+ q_noise=0.0,
+ qn_block_size=8,
+ xformers_att_config=None,
+ xformers_blocksparse_layout=None,
+ xformers_blocksparse_blocksize=16,
+ ):
+ super().__init__(dictionary)
+ xformers_att_config = eval_str_dict(xformers_att_config)
+ self.use_xformers = xformers_att_config is not None
+ if self.use_xformers:
+ raise ImportError
+ self.embed_dim = embed_dim
+ self.kdim = kdim if kdim is not None else embed_dim
+ self.vdim = vdim if vdim is not None else embed_dim
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+ self.num_heads = num_heads
+ self.dropout_module = FairseqDropout(dropout, module_name=self.__class__.__name__)
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim
+ self.scaling = self.head_dim**-0.5
+ self.self_attention = self_attention
+ self.encoder_decoder_attention = encoder_decoder_attention
+ assert not self.self_attention or self.qkv_same_dim
+ self.k_proj = quant_noise(nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size)
+ self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
+ self.q_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
+ self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
+ if add_bias_kv:
+ self.bias_k, self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim)), nn.Parameter(torch.Tensor(1, 1, embed_dim))
+ else:
+ self.bias_k = self.bias_v = None
+ self.add_zero_attn = add_zero_attn
+ self.beam_size = 1
+ self.reset_parameters()
+ self.onnx_trace = False
+ self.skip_embed_dim_check = False
+ self.init_incremental_state()
+
+ def prepare_for_onnx_export_(self):
+ self.onnx_trace = True
+
+ def reset_parameters(self):
+ if self.qkv_same_dim:
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
+ else:
+ nn.init.xavier_uniform_(self.k_proj.weight)
+ nn.init.xavier_uniform_(self.v_proj.weight)
+ nn.init.xavier_uniform_(self.q_proj.weight)
+
+ nn.init.xavier_uniform_(self.out_proj.weight)
+ if self.out_proj.bias is not None:
+ nn.init.constant_(self.out_proj.bias, 0.0)
+ if self.bias_k is not None:
+ nn.init.xavier_normal_(self.bias_k)
+ if self.bias_v is not None:
+ nn.init.xavier_normal_(self.bias_v)
+
+ def _get_reserve_head_index(self, num_heads_to_keep: int):
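+        # Rank heads by the L1 norm of their Q/K/V projection slices and keep the top `num_heads_to_keep`.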
+ k_proj_heads_norm, q_proj_heads_norm, v_proj_heads_norm = [], [], []
+ for i in range(self.num_heads):
+ start_idx = i * self.head_dim
+ end_idx = (i + 1) * self.head_dim
+ k_proj_heads_norm.append(
+ torch.sum(torch.abs(self.k_proj.weight[start_idx:end_idx])).tolist()
+ + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist(),
+ )
+ q_proj_heads_norm.append(
+ torch.sum(torch.abs(self.q_proj.weight[start_idx:end_idx])).tolist()
+ + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist(),
+ )
+ v_proj_heads_norm.append(
+ torch.sum(torch.abs(self.v_proj.weight[start_idx:end_idx])).tolist()
+ + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist(),
+ )
+
+ heads_norm = []
+ for i in range(self.num_heads):
+ heads_norm.append(k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i])
+
+ sorted_head_index = sorted(range(self.num_heads), key=lambda k: heads_norm[k], reverse=True)
+ reserve_head_index = []
+ for i in range(num_heads_to_keep):
+ reserve_head_index.append((sorted_head_index[i] * self.head_dim, (sorted_head_index[i] + 1) * self.head_dim))
+ return reserve_head_index
+
+ def _adaptive_prune_heads(self, reserve_head_index):
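+        # Rebuild the Q/K/V/output projections, keeping only the weight slices of the retained heads.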
+ new_q_weight, new_q_bias, new_k_weight, new_k_bias, new_v_weight, new_v_bias, new_out_proj_weight = [], [], [], [], [], [], []
+ for ele in reserve_head_index:
+ start_idx, end_idx = ele
+ new_q_weight.append(self.q_proj.weight[start_idx:end_idx])
+ new_q_bias.append(self.q_proj.bias[start_idx:end_idx])
+ new_k_weight.append(self.k_proj.weight[start_idx:end_idx])
+ new_k_bias.append(self.k_proj.bias[start_idx:end_idx])
+ new_v_weight.append(self.v_proj.weight[start_idx:end_idx])
+ new_v_bias.append(self.v_proj.bias[start_idx:end_idx])
+ new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx])
+ new_q_weight = torch.cat(new_q_weight).detach()
+ new_k_weight = torch.cat(new_k_weight).detach()
+ new_v_weight = torch.cat(new_v_weight).detach()
+ new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach()
+ new_q_weight.requires_grad = True
+ new_k_weight.requires_grad = True
+ new_v_weight.requires_grad = True
+ new_out_proj_weight.requires_grad = True
+ new_q_bias = torch.cat(new_q_bias).detach()
+ new_q_bias.requires_grad = True
+ new_k_bias = torch.cat(new_k_bias).detach()
+ new_k_bias.requires_grad = True
+ new_v_bias = torch.cat(new_v_bias).detach()
+ new_v_bias.requires_grad = True
+ self.q_proj.weight = nn.Parameter(new_q_weight)
+ self.q_proj.bias = nn.Parameter(new_q_bias)
+ self.k_proj.weight = nn.Parameter(new_k_weight)
+ self.k_proj.bias = nn.Parameter(new_k_bias)
+ self.v_proj.weight = nn.Parameter(new_v_weight)
+ self.v_proj.bias = nn.Parameter(new_v_bias)
+ self.out_proj.weight = nn.Parameter(new_out_proj_weight)
+ self.num_heads = len(reserve_head_index)
+ self.embed_dim = self.head_dim * self.num_heads
+ self.q_proj.out_features = self.embed_dim
+ self.k_proj.out_features = self.embed_dim
+ self.v_proj.out_features = self.embed_dim
+
+ def _set_skip_embed_dim_check(self):
+ self.skip_embed_dim_check = True
+
+ def _pad_masks(self, key_padding_mask, attn_mask):
+ if attn_mask is not None:
+ shape = attn_mask.size()[:-1] + torch.Size([1])
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1)
+
+ if key_padding_mask is not None:
+ shape = key_padding_mask.size()[:-1] + torch.Size([1])
+ key_padding_mask = torch.cat([key_padding_mask, key_padding_mask.new_zeros(shape)], dim=-1)
+
+ return key_padding_mask, attn_mask
+
+ def _add_bias(self, k, v, key_padding_mask, attn_mask, bsz):
+ assert self.bias_k is not None or self.bias_v is not None
+ key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
+ return torch.cat([k, self.bias_k.repeat(1, bsz, 1)]), torch.cat([v, self.bias_v.repeat(1, bsz, 1)]), key_padding_mask, attn_mask
+
+ def _append_zero_attn(self, k, v, key_padding_mask, attn_mask):
+ zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:]
+ key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
+ return (
+ torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=-2),
+ torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2),
+ key_padding_mask,
+ attn_mask,
+ )
+
+ def forward(
+ self,
+ query,
+ key,
+ value,
+ key_padding_mask=None,
+ incremental_state=None,
+ need_weights=True,
+ static_kv=False,
+ attn_mask=None,
+ before_softmax=False,
+ need_head_weights=False,
+ ):
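+        """Input shape: time x batch x channel; returns (attn_output, attn_weights or None)."""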
+ if need_head_weights:
+ need_weights = True
+ is_tpu = query.device.type == "xla"
+ tgt_len, bsz, embed_dim = query.size()
+ src_len = tgt_len
+ if not self.skip_embed_dim_check:
+ assert embed_dim == self.embed_dim
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+ if key is not None:
+ src_len, key_bsz, _ = key.size()
+ if not torch.jit.is_scripting():
+ assert value is not None
+                assert (src_len, key_bsz) == value.shape[:2]
+
+ if (
+ not self.onnx_trace
+ and not is_tpu
+ and incremental_state is None
+ and not static_kv
+ and not torch.jit.is_scripting()
+ and not self.skip_embed_dim_check
+ ):
+ assert key is not None and value is not None
+ return F.multi_head_attention_forward(
+ query,
+ key,
+ value,
+ self.embed_dim,
+ self.num_heads,
+ torch.empty([0]),
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
+ self.bias_k,
+ self.bias_v,
+ self.add_zero_attn,
+ self.dropout_module.p,
+ self.out_proj.weight,
+ self.out_proj.bias,
+ self.training or self.dropout_module.apply_during_inference,
+ key_padding_mask.bool() if key_padding_mask is not None else None,
+ need_weights,
+ attn_mask,
+ use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ )
+
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if saved_state is not None and "prev_key" in saved_state:
+ if static_kv:
+ assert self.encoder_decoder_attention and not self.self_attention
+ key = value = None
+ else:
+ saved_state = None
+
+ if self.self_attention:
+ q = self.q_proj(query)
+ k = self.k_proj(query)
+ v = self.v_proj(query)
+ elif self.encoder_decoder_attention:
+ q = self.q_proj(query)
+ if key is None:
+ assert value is None
+ k = v = None
+ else:
+ if self.beam_size > 1 and bsz == key.size(1):
+ key = key.view(key.size(0), -1, self.beam_size, key.size(2))[:, :, 0, :]
+ if key_padding_mask is not None:
+ key_padding_mask = key_padding_mask.view(-1, self.beam_size, key_padding_mask.size(1))[:, 0, :]
+ k = self.k_proj(key)
+ v = self.v_proj(key)
+ else:
+ assert key is not None and value is not None
+ q = self.q_proj(query)
+ k = self.k_proj(key)
+ v = self.v_proj(value)
+
+ q *= self.scaling
+ if self.bias_k is not None:
+ assert self.bias_v is not None
+            k, v, key_padding_mask, attn_mask = self._add_bias(k, v, key_padding_mask, attn_mask, bsz)
+
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ kv_bsz = bsz
+ if k is not None:
+ kv_bsz = k.size(1)
+ k = k.contiguous().view(-1, kv_bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+ if v is not None:
+ v = v.contiguous().view(-1, kv_bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if saved_state is not None:
+ if "prev_key" in saved_state:
+ _prev_key = saved_state["prev_key"]
+ assert _prev_key is not None
+
+ kv_bsz = _prev_key.size(0)
+ prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim)
+
+ if static_kv:
+ k = prev_key
+ else:
+ assert k is not None
+ k = torch.cat([prev_key, k], dim=1)
+ src_len = k.size(1)
+
+ if "prev_value" in saved_state:
+ _prev_value = saved_state["prev_value"]
+                assert _prev_value is not None and kv_bsz == _prev_value.size(0)
+ prev_value = _prev_value.view(kv_bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ v = prev_value
+ else:
+ assert v is not None
+ v = torch.cat([prev_value, v], dim=1)
+
+ prev_key_padding_mask = None
+ if "prev_key_padding_mask" in saved_state:
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
+ assert k is not None and v is not None
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
+ key_padding_mask=key_padding_mask,
+ prev_key_padding_mask=prev_key_padding_mask,
+ batch_size=kv_bsz,
+ src_len=k.size(1),
+ static_kv=static_kv,
+ )
+ saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim)
+ saved_state["prev_value"] = v.view(kv_bsz, self.num_heads, -1, self.head_dim)
+ saved_state["prev_key_padding_mask"] = key_padding_mask
+ assert incremental_state is not None
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
+
+ assert k is not None
+ assert k.size(1) == src_len
+
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
+ key_padding_mask = None
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == kv_bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if self.add_zero_attn:
+ assert v is not None
+ src_len += 1
+ k, v, key_padding_mask, attn_mask = self._append_zero_attn(k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
+
+ if self.encoder_decoder_attention and bsz != kv_bsz:
+ attn_weights = torch.einsum(
+ "bxhtd,bhsd->bxhts",
+ q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]),
+ k.view((kv_bsz, self.num_heads) + k.size()[1:]),
+ )
+ attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:])
+ else:
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
+
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ attn_mask = attn_mask.unsqueeze(0)
+ if self.onnx_trace:
+ attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
+ attn_weights += attn_mask
+
+ if key_padding_mask is not None:
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = (
+ attn_weights.view(kv_bsz, -1, self.num_heads, tgt_len, src_len).masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(torch.bool),
+ float("-inf"),
+ )
+ if not is_tpu
+ else attn_weights.transpose(0, 2).masked_fill(key_padding_mask, float("-inf")).transpose(0, 2)
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if before_softmax:
+ return attn_weights, v
+ attn_weights_float = softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
+ attn_weights = attn_weights_float.type_as(attn_weights)
+ attn_probs = self.dropout_module(attn_weights)
+ assert v is not None
+ attn = None
+
+ if self.encoder_decoder_attention and bsz != kv_bsz:
+ attn = torch.einsum(
+ "bxhts,bhsd->bxhtd",
+ attn_probs.view((kv_bsz, -1, self.num_heads) + attn_probs.size()[1:]),
+ v.view((kv_bsz, self.num_heads) + v.size()[1:]),
+ )
+ attn = attn.reshape((-1,) + attn.size()[-2:])
+ else:
+ attn = torch.bmm(attn_probs, v)
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+
+ attn = (
+ attn.contiguous().view(tgt_len, bsz, self.embed_dim)
+ if self.onnx_trace and attn.size(1) == 1
+ else attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
+ )
+ attn = self.out_proj(attn)
+ attn_weights = None
+
+ if need_weights:
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+ if not need_head_weights:
+ attn_weights = attn_weights.mean(dim=0)
+
+ return attn, attn_weights
+
+ @staticmethod
+ def _append_prev_key_padding_mask(key_padding_mask, prev_key_padding_mask, batch_size, src_len, static_kv):
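+        # Merge the key-padding mask cached from previous decoding steps with the current one.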
+ if prev_key_padding_mask is not None and static_kv:
+ new_key_padding_mask = prev_key_padding_mask
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
+ new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
+ elif prev_key_padding_mask is not None:
+ if src_len > prev_key_padding_mask.size(1):
+ filler = torch.zeros((batch_size, src_len - prev_key_padding_mask.size(1)), device=prev_key_padding_mask.device)
+ new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
+ else:
+ new_key_padding_mask = prev_key_padding_mask.float()
+ elif key_padding_mask is not None:
+ if src_len > key_padding_mask.size(1):
+ filler = torch.zeros((batch_size, src_len - key_padding_mask.size(1)), device=key_padding_mask.device)
+ new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
+ else:
+ new_key_padding_mask = key_padding_mask.float()
+ else:
+ new_key_padding_mask = prev_key_padding_mask
+ return new_key_padding_mask
+
+ @torch.jit.export
+ def reorder_incremental_state(self, incremental_state, new_order):
+ input_buffer = self._get_input_buffer(incremental_state)
+ if input_buffer is not None:
+ for k in input_buffer.keys():
+ input_buffer_k = input_buffer[k]
+ if input_buffer_k is not None:
+ if self.encoder_decoder_attention:
+ if input_buffer_k.size(0) * self.beam_size == new_order.size(0):
+ return incremental_state
+ if self.beam_size > 1:
+ input_buffer[k] = input_buffer_k.index_select(0, new_order.reshape(-1, self.beam_size)[:, 0] // self.beam_size)
+ else:
+ input_buffer[k] = input_buffer_k.index_select(0, new_order)
+ else:
+ input_buffer[k] = input_buffer_k.index_select(0, new_order)
+ incremental_state = self._set_input_buffer(incremental_state, input_buffer)
+ return incremental_state
+
+ def set_beam_size(self, beam_size):
+ self.beam_size = beam_size
+
+ def _get_input_buffer(self, incremental_state):
+ result = self.get_incremental_state(incremental_state, "attn_state")
+ return result if result is not None else {}
+
+ def _set_input_buffer(self, incremental_state, buffer):
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ prefix = name + "." if name != "" else ""
+ items_to_add, keys_to_remove = {}, []
+ for k in state_dict.keys():
+ if k.endswith(prefix + "in_proj_weight"):
+ dim = int(state_dict[k].shape[0] / 3)
+ items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
+ items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
+ items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
+ keys_to_remove.append(k)
+ k_bias = prefix + "in_proj_bias"
+ if k_bias in state_dict.keys():
+ dim = int(state_dict[k].shape[0] / 3)
+ items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
+ items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim]
+ items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
+ keys_to_remove.append(prefix + "in_proj_bias")
+
+ for k in keys_to_remove:
+ del state_dict[k]
+
+ for key, value in items_to_add.items():
+ state_dict[key] = value
+
+
+def init_bert_params(module):
+ def normal_(data):
+ data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
+
+ if isinstance(module, nn.Linear):
+ normal_(module.weight.data)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ if isinstance(module, nn.Embedding):
+ normal_(module.weight.data)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ if isinstance(module, MultiheadAttention):
+ normal_(module.q_proj.weight.data)
+ normal_(module.k_proj.weight.data)
+ normal_(module.v_proj.weight.data)
+
+
+def make_conv_pos(e, k, g):
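+    # Convolutional positional encoding (wav2vec 2.0 style): grouped Conv1d with weight norm, SamePad trim, GELU.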
+ pos_conv = nn.Conv1d(e, e, kernel_size=k, padding=k // 2, groups=g)
+ dropout = 0
+ nn.init.normal_(pos_conv.weight, mean=0, std=math.sqrt((4 * (1.0 - dropout)) / (k * e)))
+ nn.init.constant_(pos_conv.bias, 0)
+ return nn.Sequential(nn.utils.parametrizations.weight_norm(pos_conv, name="weight", dim=2), SamePad(k), nn.GELU())
+
+
+def is_xla_tensor(tensor):
+ return torch.is_tensor(tensor) and tensor.device.type == "xla"
+
+
+def index_put(tensor, indices, value):
+ if is_xla_tensor(tensor):
+ for _ in range(indices.dim(), tensor.dim()):
+ indices = indices.unsqueeze(-1)
+
+ if indices.size(-1) < tensor.size(-1):
+ indices = indices.expand_as(tensor)
+ tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
+ else:
+ tensor[indices] = value
+
+ return tensor
+
+
+def pad_to_multiple(x, multiple, dim=-1, value=0):
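+    # Right-pad `x` along `dim` so its size becomes a multiple of `multiple`; returns (padded_x, pad_length).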
+ if x is None:
+ return None, 0
+ tsz = x.size(dim)
+ m = tsz / multiple
+ remainder = math.ceil(m) * multiple - tsz
+ if m.is_integer():
+ return x, 0
+ return F.pad(x, (*((0,) * (-1 - dim) * 2), 0, remainder), value=value), remainder
+
+
+def compute_mask_indices(
+ shape,
+ padding_mask,
+ mask_prob,
+ mask_length,
+ mask_type="static",
+ mask_other=0.0,
+ min_masks=0,
+ no_overlap=False,
+ min_space=0,
+ require_same_masks=True,
+ mask_dropout=0.0,
+ add_masks=False,
+ seed=None,
+ epoch=None,
+ indices=None,
+ idc_select_ver=1,
+ num_mask_ver=2,
+):
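+    # Sample random span masks per batch element (SpecAugment-style), as used for wav2vec 2.0 / HuBERT masking.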
+ bsz, all_sz = shape
+ mask = np.full((bsz, all_sz), False)
+ if num_mask_ver == 1:
+ all_num_mask = max(min_masks, int(mask_prob * all_sz / float(mask_length) + np.random.rand()))
+ mask_idcs = []
+
+ for i in range(bsz):
+ seed_i = (
+ int(hash((seed, epoch, indices[i].item())) % 1e6) if seed is not None and epoch is not None and indices is not None else None
+ )
+ rng = np.random.default_rng(seed_i)
+
+ if padding_mask is not None:
+ sz = all_sz - padding_mask[i].long().sum().item()
+ assert sz >= 0, sz
+ else:
+ sz = all_sz
+
+ if num_mask_ver == 1:
+ num_mask = (
+ max(min_masks, int(mask_prob * sz / float(mask_length) + np.random.rand())) if padding_mask is not None else all_num_mask
+ )
+ elif num_mask_ver == 2:
+ num_mask = max(min_masks, int(mask_prob * sz / float(mask_length) + rng.random()))
+ else:
+ raise ValueError
+
+ if mask_type == "static":
+ lengths = np.full(num_mask, mask_length)
+ elif mask_type == "uniform":
+            lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)
+ elif mask_type == "normal":
+ lengths = [max(1, int(round(x))) for x in rng.normal(mask_length, mask_other, size=num_mask)]
+ elif mask_type == "poisson":
+ lengths = [int(round(x)) for x in rng.poisson(mask_length, size=num_mask)]
+ else:
+            raise Exception(f"unknown mask selection: {mask_type}")
+
+ if sum(lengths) == 0:
+ if mask_type == "static":
+ raise ValueError
+ lengths = [min(mask_length, sz - 1)]
+
+ if no_overlap:
+ mask_idc = []
+
+ def arrange(s, e, length, keep_length):
+                span_start = rng.integers(s, e - length)
+ mask_idc.extend(span_start + i for i in range(length))
+ new_parts = []
+ if span_start - s - min_space >= keep_length:
+ new_parts.append((s, span_start - min_space + 1))
+ if e - span_start - length - min_space > keep_length:
+ new_parts.append((span_start + length + min_space, e))
+ return new_parts
+
+ parts = [(0, sz)]
+ min_length = min(lengths)
+ for length in sorted(lengths, reverse=True):
+ lens = np.fromiter((e - s if e - s >= length + min_space else 0 for s, e in parts), np.int32)
+ l_sum = np.sum(lens)
+ if l_sum == 0:
+ break
+ s, e = parts.pop(rng.choice(len(parts), p=lens / np.sum(lens)))
+ parts.extend(arrange(s, e, length, min_length))
+ mask_idc = np.asarray(mask_idc)
+ else:
+ if idc_select_ver == 1:
+ min_len = min(lengths)
+ if sz - min_len <= num_mask:
+ min_len = sz - num_mask - 1
+ mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
+ elif idc_select_ver == 2:
+ mask_idc = rng.choice(sz, num_mask, replace=False)
+ else:
+ raise ValueError
+
+ mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
+
+ mask_idc = np.unique(mask_idc[mask_idc < sz])
+ if len(mask_idc) >= sz:
+ raise ValueError
+ mask_idcs.append(mask_idc)
+
+ target_len = None
+ if require_same_masks:
+ target_len = max([len(m) for m in mask_idcs]) if add_masks else min([len(m) for m in mask_idcs])
+
+ for i, mask_idc in enumerate(mask_idcs):
+ if target_len is not None and len(mask_idc) > target_len:
+ mask_idc = rng.choice(mask_idc, target_len, replace=False)
+ mask[i, mask_idc] = True
+
+ if target_len is not None and len(mask_idc) < target_len:
+ to_mask = rng.choice(np.flatnonzero(~mask[i]), target_len - len(mask_idc), replace=False)
+ mask[i, to_mask] = True
+
+ if mask_dropout > 0:
+ masked = np.flatnonzero(mask[i])
+ mask[i, rng.choice(masked, np.rint(len(masked) * mask_dropout).astype(int), replace=False)] = False
+
+ return mask
+
+
+def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
+    # `export` is accepted for fairseq API compatibility and ignored; a plain nn.LayerNorm is always returned.
+    return nn.LayerNorm(normalized_shape, eps, elementwise_affine)
+
+
+def prune_state_dict(state_dict, model_cfg):
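+    # Drop and renumber encoder/decoder layers according to *_layers_to_keep so a pruned model can load the checkpoint.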
+ arch = None
+ if model_cfg is not None:
+ arch = model_cfg._name if isinstance(model_cfg, DictConfig) else getattr(model_cfg, "arch", None)
+ if not model_cfg or arch is None or arch == "ptt_transformer":
+ return state_dict
+ encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
+ decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
+ if not encoder_layers_to_keep and not decoder_layers_to_keep:
+ return state_dict
+
+ def create_pruning_pass(layers_to_keep, layer_name):
+ keep_layers = sorted(int(layer_string) for layer_string in layers_to_keep.split(","))
+ mapping_dict = {}
+ for i in range(len(keep_layers)):
+ mapping_dict[str(keep_layers[i])] = str(i)
+
+ return {"substitution_regex": re.compile(rf"^{layer_name}.*\.layers\.(\d+)"), "mapping_dict": mapping_dict}
+
+ pruning_passes, new_state_dict = [], {}
+ if encoder_layers_to_keep:
+ pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
+ if decoder_layers_to_keep:
+ pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))
+
+ for layer_name in state_dict.keys():
+ match = re.search(r"\.layers\.(\d+)\.", layer_name)
+ if not match:
+ new_state_dict[layer_name] = state_dict[layer_name]
+ continue
+
+ original_layer_number = match.group(1)
+ for pruning_pass in pruning_passes:
+ if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass["substitution_regex"].search(layer_name):
+ substitution_match = pruning_pass["substitution_regex"].search(layer_name)
+ new_state_dict[
+ (
+ layer_name[: substitution_match.start(1)]
+ + pruning_pass["mapping_dict"][original_layer_number]
+ + layer_name[substitution_match.end(1) :]
+ )
+ ] = state_dict[layer_name]
+
+ with open_dict(model_cfg) if isinstance(model_cfg, DictConfig) else contextlib.ExitStack():
+ if hasattr(model_cfg, "encoder_layers_to_keep"):
+ model_cfg.encoder_layers_to_keep = None
+ if hasattr(model_cfg, "decoder_layers_to_keep"):
+ model_cfg.decoder_layers_to_keep = None
+
+ return new_state_dict
+
+
+def relu_squared(x):
+ return F.relu(x).pow(2)
+
+
+def get_activation_fn(activation):
+ def gelu(x):
+ return nn.functional.gelu(x.float()).type_as(x)
+
+ def gelu_accurate(x):
+ if not hasattr(gelu_accurate, "_a"):
+ gelu_accurate._a = math.sqrt(2 / math.pi)
+ return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
+
+ if activation == "relu":
+ return F.relu
+ if activation == "relu_squared":
+ return relu_squared
+ if activation == "gelu":
+ return gelu
+ if activation == "gelu_fast" or activation == "gelu_accurate":
+ return gelu_accurate
+ if activation == "tanh":
+ return torch.tanh
+ if activation == "linear":
+ return lambda x: x
+ if activation == "swish":
+ return nn.SiLU
+    raise RuntimeError(f"activation function {activation} is not supported")
+
+
+class SamePad(nn.Module):
+ def __init__(self, kernel_size, causal=False):
+ super().__init__()
+ if causal:
+ self.remove = kernel_size - 1
+ else:
+ self.remove = 1 if kernel_size % 2 == 0 else 0
+
+ def forward(self, x):
+ if self.remove > 0:
+ x = x[:, :, : -self.remove]
+ return x
+
+
+class TransformerSentenceEncoderLayer(nn.Module):
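+    """Standard Transformer encoder block (self-attention + FFN) with pre- or post-LayerNorm ordering."""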
+ def __init__(
+ self,
+ embedding_dim=768,
+ ffn_embedding_dim=3072,
+ num_attention_heads=8,
+ dropout=0.1,
+ attention_dropout=0.1,
+ activation_dropout=0.1,
+ activation_fn="relu",
+ layer_norm_first=False,
+ ):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.dropout = dropout
+ self.activation_dropout = activation_dropout
+ self.activation_fn = get_activation_fn(activation_fn)
+ self.self_attn = MultiheadAttention(self.embedding_dim, num_attention_heads, dropout=attention_dropout, self_attention=True)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(self.activation_dropout)
+ self.dropout3 = nn.Dropout(dropout)
+ self.layer_norm_first = layer_norm_first
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
+
+ def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None):
+ residual = x
+ if self.layer_norm_first:
+ x = self.self_attn_layer_norm(x)
+ x, attn = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ attn_mask=self_attn_mask,
+ need_weights=False,
+ )
+ x = residual + self.dropout1(x)
+ residual = x
+ x = self.fc2(self.dropout2(self.activation_fn(self.fc1(self.final_layer_norm(x)))))
+ layer_result = x
+ x = residual + self.dropout3(x)
+ else:
+ x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, need_weights=False)
+ x = self.self_attn_layer_norm(residual + self.dropout1(x))
+ residual = x
+ x = self.fc2(self.dropout2(self.activation_fn(self.fc1(x))))
+ layer_result = x
+ x = self.final_layer_norm(residual + self.dropout3(x))
+
+ return x, (attn, layer_result)
+
+
+class AdapterFast(nn.Module):
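+    """Bank of per-corpus adapters: LayerNorm -> down-projection -> activation -> up-projection, selected by adapter_id."""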
+ def __init__(self, adapter_num, input_dim, hidden_dim, act_fn):
+ super().__init__()
+ self.adapter_num = adapter_num
+ self.input_dim = input_dim
+ self.hidden_dim = hidden_dim
+ self.W_a = nn.Parameter(torch.empty(adapter_num, hidden_dim, input_dim))
+ self.W_b = nn.Parameter(torch.empty(adapter_num, input_dim, hidden_dim))
+ self.b_a = nn.Parameter(torch.empty(adapter_num, hidden_dim))
+ self.b_b = nn.Parameter(torch.empty(adapter_num, input_dim))
+ self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim))
+ self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim))
+ self.act_fn = nn.Identity()
+ if act_fn == "relu":
+ self.act_fn = nn.ReLU()
+ elif act_fn == "gelu":
+ self.act_fn = nn.GELU()
+ elif act_fn == "selu":
+ self.act_fn = nn.SELU()
+ else:
+            raise ValueError(f"unsupported adapter activation: {act_fn}")
+ self.input_dim = input_dim
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ for ii in range(self.adapter_num):
+ nn.init.kaiming_uniform_(self.W_a[ii], a=math.sqrt(5))
+ nn.init.kaiming_uniform_(self.W_b[ii], a=math.sqrt(5))
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_a[ii])
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+ nn.init.uniform_(self.b_a[ii], -bound, bound)
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[ii])
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+ nn.init.uniform_(self.b_b[ii], -bound, bound)
+
+ nn.init.ones_(self.ln_W)
+ nn.init.zeros_(self.ln_b)
+
+ def forward(self, x, adapter_id):
+ ii = adapter_id
+ return F.linear(
+ self.act_fn(F.linear(F.layer_norm(x, (self.input_dim,), self.ln_W[ii], self.ln_b[ii]), self.W_a[ii], self.b_a[ii])),
+ self.W_b[ii],
+ self.b_b[ii],
+ )
+
+ def extra_repr(self):
+ return f"adapter={self.adapter_num}, input_dim={self.input_dim}, hidden_dim={self.hidden_dim}"
+
+
+class FeedForwardModule(nn.Module):
+ def __init__(self, input_feat, hidden_units, dropout1, dropout2, activation_fn="swish", bias=True):
+ super(FeedForwardModule, self).__init__()
+ self.layer_norm = LayerNorm(input_feat)
+ self.w_1 = nn.Linear(input_feat, hidden_units, bias=bias)
+ self.w_2 = nn.Linear(hidden_units, input_feat, bias=bias)
+ self.dropout1 = nn.Dropout(dropout1)
+ self.dropout2 = nn.Dropout(dropout2)
+ self.activation = get_activation_fn(activation_fn)(hidden_units)
+
+ def forward(self, x):
+ return self.dropout2(self.w_2(self.dropout1(self.activation(self.w_1(self.layer_norm(x))))))
+
+
+class ConvolutionModule(nn.Module):
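+    """Conformer convolution module: pointwise conv + GLU, depthwise conv, BatchNorm, activation, pointwise conv, dropout."""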
+ def __init__(self, embed_dim, channels, depthwise_kernel_size, dropout, activation_fn="swish", bias=False, export=False):
+ super(ConvolutionModule, self).__init__()
+ assert (depthwise_kernel_size - 1) % 2 == 0
+ self.layer_norm = LayerNorm(embed_dim, export=export)
+ self.pointwise_conv1 = nn.Conv1d(embed_dim, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias)
+ self.glu = nn.GLU(dim=1)
+ self.depthwise_conv = nn.Conv1d(
+ channels,
+ channels,
+ depthwise_kernel_size,
+ stride=1,
+ padding=(depthwise_kernel_size - 1) // 2,
+ groups=channels,
+ bias=bias,
+ )
+ self.batch_norm = nn.BatchNorm1d(channels)
+ self.activation = get_activation_fn(activation_fn)(channels)
+ self.pointwise_conv2 = nn.Conv1d(channels, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ return self.dropout(
+ self.pointwise_conv2(
+ self.activation(self.batch_norm(self.depthwise_conv(self.glu(self.pointwise_conv1(self.layer_norm(x).transpose(1, 2)))))),
+ ),
+ ).transpose(1, 2)
+
+
+def rotate_half(x):
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=x1.ndim - 1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
+ cos, sin = (cos[offset : q.shape[0] + offset, ...], sin[offset : q.shape[0] + offset, ...])
+ return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+
+
+class RotaryPositionalEmbedding(nn.Module):
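+    """Rotary positional embedding (RoPE): caches cos/sin tables up to the longest sequence length seen so far."""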
+ def __init__(self, dim, base=10000, precision=torch.half):
+ super().__init__()
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+ self.seq_len_cached = 0
+ self.cos_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
+ self.sin_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
+ self.precision = precision
+
+ def forward(self, x, seq_len=0):
+ if seq_len > self.seq_len_cached:
+ self.seq_len_cached = seq_len
+ freqs = torch.einsum("i,j->ij", torch.arange(seq_len, device=x.device).type_as(self.inv_freq), self.inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+ self.cos_cached = emb.cos().view(emb.size(0), 1, 1, emb.size(1))
+ self.sin_cached = emb.sin().view(emb.size(0), 1, 1, emb.size(1))
+ return self.cos_cached, self.sin_cached
+
+
+class ESPNETMultiHeadedAttention(nn.Module):
+ def __init__(self, n_feat, n_head, dropout):
+ super(ESPNETMultiHeadedAttention, self).__init__()
+ assert n_feat % n_head == 0
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ self.linear_k = nn.Linear(n_feat, n_feat)
+ self.linear_v = nn.Linear(n_feat, n_feat)
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.attn = None
+ self.dropout = nn.Dropout(p=dropout)
+
+ def forward_qkv(self, query, key, value, **kwargs):
+ n_batch = query.size(0)
+ return (
+ self.linear_q(query).view(n_batch, -1, self.h, self.d_k).transpose(1, 2),
+ self.linear_k(key).view(n_batch, -1, self.h, self.d_k).transpose(1, 2),
+ self.linear_v(value).view(n_batch, -1, self.h, self.d_k).transpose(1, 2),
+ )
+
+ def forward_attention(self, value, scores, mask):
+ n_batch = value.size(0)
+ if mask is not None:
+ scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(2).to(bool), float("-inf"))
+ self.attn = torch.softmax(scores, dim=-1)
+ else:
+ self.attn = torch.softmax(scores, dim=-1)
+
+ return self.linear_out(
+ torch.matmul(self.dropout(self.attn), value).transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k),
+ )
+
+ def forward(self, query, key, value, key_padding_mask=None, **kwargs):
+ q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
+ return self.forward_attention(v, torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None
+
+
+class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
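+    """Transformer-XL style multi-head attention with relative positional encoding and learned content/position biases."""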
+ def __init__(self, n_feat, n_head, dropout, zero_triu=False):
+ super().__init__(n_feat, n_head, dropout)
+ self.zero_triu = zero_triu
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+ self.pos_bias_u = nn.Parameter(torch.zeros(self.h, self.d_k))
+ self.pos_bias_v = nn.Parameter(torch.zeros(self.h, self.d_k))
+ nn.init.xavier_uniform_(self.pos_bias_u)
+ nn.init.xavier_uniform_(self.pos_bias_v)
+
+ def rel_shift(self, x):
+ x = (
+ torch.cat([torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype), x], dim=-1)
+ .view(*x.size()[:2], x.size(3) + 1, x.size(2))[:, :, 1:]
+ .view_as(x)[:, :, :, : x.size(-1) // 2 + 1]
+ )
+ if self.zero_triu:
+ x = x * torch.tril(torch.ones((x.size(2), x.size(3)), device=x.device), x.size(3) - x.size(2))[None, None, :, :]
+ return x
+
+ def forward(self, query, key, value, pos_emb, key_padding_mask=None, **kwargs):
+ pos_emb = pos_emb.transpose(0, 1)
+ q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
+ q = q.transpose(1, 2)
+
+ return (
+ self.forward_attention(
+ v,
+ (
+ torch.matmul((q + self.pos_bias_u).transpose(1, 2), k.transpose(-2, -1))
+ + self.rel_shift(
+ torch.matmul(
+ (q + self.pos_bias_v).transpose(1, 2),
+ self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.h, self.d_k).transpose(1, 2).transpose(-2, -1),
+ ),
+ )
+ )
+ / math.sqrt(self.d_k),
+ key_padding_mask,
+ ).transpose(0, 1),
+ None,
+ )
+
+
+class RotaryPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
+ def __init__(self, n_feat, n_head, dropout, precision, rotary_emd_base=10000):
+ super().__init__(n_feat, n_head, dropout)
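+        # NOTE: precision is overridden to float32 below, so the "fp16" branch is effectively unreachable.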
+ precision = torch.float
+ self.rotary_ndims = self.d_k
+ if precision == "fp16":
+ precision = torch.half
+ self.rotary_emb = RotaryPositionalEmbedding(self.rotary_ndims, base=rotary_emd_base, precision=precision)
+
+ def forward(self, query, key, value, key_padding_mask=None, **kwargs):
+ T, B, C = value.size()
+ query = query.view(T, B, self.h, self.d_k)
+ key = key.view(T, B, self.h, self.d_k)
+ value = value.view(T, B, self.h, self.d_k)
+ cos, sin = self.rotary_emb(value, seq_len=T)
+ query, key = apply_rotary_pos_emb(query, key, cos, sin, offset=0)
+ query = query.view(T, B, self.h * self.d_k)
+ key = key.view(T, B, self.h * self.d_k)
+ value = value.view(T, B, self.h * self.d_k)
+ q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
+ return self.forward_attention(v, torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None
+
+
+class ConformerEncoderLayer(nn.Module):
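+    """Conformer block: half-step FFN, self-attention, convolution module, half-step FFN, final LayerNorm."""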
+ def __init__(
+ self,
+ embed_dim,
+ ffn_embed_dim,
+ attention_heads,
+ dropout,
+ use_fp16,
+ depthwise_conv_kernel_size=31,
+ activation_fn="swish",
+ attn_type=None,
+ pos_enc_type="abs",
+ ):
+ self.pos_enc_type = pos_enc_type
+ super(ConformerEncoderLayer, self).__init__()
+ self.ffn1 = FeedForwardModule(embed_dim, ffn_embed_dim, dropout, dropout)
+ self.self_attn_layer_norm = LayerNorm(embed_dim, export=False)
+ self.self_attn_dropout = nn.Dropout(dropout)
+ if attn_type == "espnet":
+ if self.pos_enc_type == "rel_pos":
+ self.self_attn = RelPositionMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout)
+ elif self.pos_enc_type == "rope":
+ self.self_attn = RotaryPositionMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout, precision=use_fp16)
+ elif self.pos_enc_type == "abs":
+ self.self_attn = ESPNETMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout)
+ else:
+                raise Exception(f"unsupported positional encoding type: {self.pos_enc_type}")
+ else:
+ self.self_attn = MultiheadAttention(embed_dim, attention_heads, dropout=dropout)
+ self.conv_module = ConvolutionModule(
+ embed_dim=embed_dim,
+ channels=embed_dim,
+ depthwise_kernel_size=depthwise_conv_kernel_size,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ )
+ self.ffn2 = FeedForwardModule(embed_dim, ffn_embed_dim, dropout, dropout, activation_fn=activation_fn)
+ self.final_layer_norm = LayerNorm(embed_dim, export=False)
+
+ def forward(self, x, encoder_padding_mask, position_emb=None):
+ residual = x
+ x = self.ffn1(x) * 0.5 + residual
+ residual = x
+ x = self.self_attn_layer_norm(x)
+ if self.pos_enc_type == "rel_pos":
+ x, attn = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=encoder_padding_mask,
+ pos_emb=position_emb,
+ need_weights=False,
+ )
+ else:
+ x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=False)
+ x = self.self_attn_dropout(x)
+ x = x + residual
+ residual = x
+ x = residual + self.conv_module(x.transpose(0, 1)).transpose(0, 1)
+ residual = x
+ x = self.ffn2(x)
+ layer_result = x
+ x = self.final_layer_norm(x * 0.5 + residual)
+ return x, (attn, layer_result)
+
+
+class ConformerWav2Vec2EncoderLayer(ConformerEncoderLayer):
+ def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None, position_emb=None):
+ return super().forward(x, self_attn_padding_mask, position_emb)
+
+
+class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer):
+ def __init__(
+ self,
+ embedding_dim=768,
+ ffn_embedding_dim=3072,
+ num_attention_heads=8,
+ dropout=0.1,
+ attention_dropout=0.1,
+ activation_dropout=0.1,
+ activation_fn="relu",
+ layer_norm_first=False,
+ adapter_num=201,
+ adapter_dim=64,
+ adapter_act_fn="relu",
+ ):
+ super().__init__(
+ embedding_dim=embedding_dim,
+ ffn_embedding_dim=ffn_embedding_dim,
+ num_attention_heads=num_attention_heads,
+ dropout=dropout,
+ attention_dropout=attention_dropout,
+ activation_dropout=activation_dropout,
+ activation_fn=activation_fn,
+ layer_norm_first=layer_norm_first,
+ )
+ self.adapter_num = adapter_num
+ self.adapter_dim = adapter_dim
+ self.adapter_layer = AdapterFast(adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn)
+
+ def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None, corpus_key=None):
+ x, (attn, layer_result) = super().forward(
+ x=x,
+ self_attn_mask=self_attn_mask,
+ self_attn_padding_mask=self_attn_padding_mask,
+ need_weights=need_weights,
+ att_args=att_args,
+ )
+ assert corpus_key is not None
+ assert len(set(corpus_key)) == 1
+ return x + self.adapter_layer(x, corpus_key[0]), (attn, layer_result)
+
+
+class TransposeLast(nn.Module):
+ def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
+ super().__init__()
+ self.deconstruct_idx = deconstruct_idx
+ self.tranpose_dim = tranpose_dim
+
+ def forward(self, x):
+ if self.deconstruct_idx is not None:
+ x = x[self.deconstruct_idx]
+ return x.transpose(self.tranpose_dim, -1)
+
+
+class TransformerEncoder(nn.Module):
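+    """Context network over the CNN features: convolutional positional encoding followed by a stack of Transformer/Conformer layers."""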
+ def build_encoder_layer(self, args, **kwargs):
+ if args.layer_type == "transformer":
+ layer = TransformerSentenceEncoderLayer(
+ embedding_dim=self.embedding_dim,
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
+ num_attention_heads=args.encoder_attention_heads,
+ dropout=self.dropout,
+ attention_dropout=args.attention_dropout,
+ activation_dropout=args.activation_dropout,
+ activation_fn=args.activation_fn,
+ layer_norm_first=args.layer_norm_first,
+ )
+ elif args.layer_type == "conformer":
+ layer = ConformerWav2Vec2EncoderLayer(
+ embed_dim=self.embedding_dim,
+ ffn_embed_dim=args.encoder_ffn_embed_dim,
+ attention_heads=args.encoder_attention_heads,
+ dropout=args.dropout,
+ depthwise_conv_kernel_size=args.depthwise_conv_kernel_size,
+ activation_fn="swish",
+ attn_type=args.attn_type,
+ use_fp16=args.fp16,
+ pos_enc_type="abs",
+ )
+ elif args.layer_type == "trf_adp":
+ use_adp = False
+ if args.adp_trf_idx == "all" or kwargs.get("layer_idx") in list(range(*[int(g) for g in args.adp_trf_idx.split(":")])):
+ use_adp = True
+
+ layer = (
+ TransformerSentenceEncoderWithAdapterLayer(
+ embedding_dim=self.embedding_dim,
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
+ num_attention_heads=args.encoder_attention_heads,
+ dropout=self.dropout,
+ attention_dropout=args.attention_dropout,
+ activation_dropout=args.activation_dropout,
+ activation_fn=args.activation_fn,
+ layer_norm_first=args.layer_norm_first,
+ adapter_num=args.adp_num,
+ adapter_dim=args.adp_dim,
+ adapter_act_fn=args.adp_act_fn,
+ )
+ if use_adp
+ else TransformerSentenceEncoderLayer(
+ embedding_dim=self.embedding_dim,
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
+ num_attention_heads=args.encoder_attention_heads,
+ dropout=self.dropout,
+ attention_dropout=args.attention_dropout,
+ activation_dropout=args.activation_dropout,
+ activation_fn=args.activation_fn,
+ layer_norm_first=args.layer_norm_first,
+ )
+ )
+
+ return layer
+
+ def __init__(self, args):
+ super().__init__()
+ self.dropout = args.dropout
+ self.embedding_dim = args.encoder_embed_dim
+ self.required_seq_len_multiple = args.required_seq_len_multiple
+ pos_conv_depth = getattr(args, "pos_conv_depth", 1)
+ if pos_conv_depth > 1:
+ num_layers = args.pos_conv_depth
+ k = max(3, args.conv_pos // num_layers)
+
+ def make_conv_block(e, k, g, l):
+ return nn.Sequential(
+ *[
+ nn.Sequential(
+ nn.Conv1d(e, e, kernel_size=k, padding=k // 2, groups=g),
+ SamePad(k),
+ TransposeLast(),
+ LayerNorm(e, elementwise_affine=False),
+ TransposeLast(),
+ nn.GELU(),
+ )
+ for _ in range(l)
+ ],
+ )
+
+ self.pos_conv = make_conv_block(self.embedding_dim, k, args.conv_pos_groups, num_layers)
+ else:
+ self.pos_conv = make_conv_pos(self.embedding_dim, args.conv_pos, args.conv_pos_groups)
+
+ self.layers = nn.ModuleList([self.build_encoder_layer(args, layer_idx=ii) for ii in range(args.encoder_layers)])
+ self.layer_norm_first = args.layer_norm_first
+ self.layer_norm = LayerNorm(self.embedding_dim)
+ self.layerdrop = args.encoder_layerdrop
+ self.apply(init_bert_params)
+
+ def forward(self, x, padding_mask=None, layer=None, corpus_key=None):
+ x, layer_results = self.extract_features(x, padding_mask, layer, corpus_key=corpus_key)
+ if self.layer_norm_first and layer is None:
+ x = self.layer_norm(x)
+ return x, layer_results
+
+ def extract_features(self, x, padding_mask=None, tgt_layer=None, min_layer=0, corpus_key=None):
+ if padding_mask is not None:
+ x = index_put(x, padding_mask, 0)
+ x = x + self.pos_conv(x.transpose(1, 2)).transpose(1, 2)
+ if not self.layer_norm_first:
+ x = self.layer_norm(x)
+ x, pad_length = pad_to_multiple(x, self.required_seq_len_multiple, dim=-2, value=0)
+ if pad_length > 0 and padding_mask is None:
+ padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
+ padding_mask[:, -pad_length:] = True
+ else:
+ padding_mask, _ = pad_to_multiple(padding_mask, self.required_seq_len_multiple, dim=-1, value=True)
+ x = F.dropout(x, p=self.dropout, training=self.training).transpose(0, 1)
+ layer_results = []
+ r = None
+
+ for i, layer in enumerate(self.layers):
+ dropout_probability = np.random.random() if self.layerdrop > 0 else 1
+ if not self.training or (dropout_probability > self.layerdrop):
+ layer_check = layer
+ if (corpus_key is None) or (not isinstance(layer_check, (TransformerSentenceEncoderWithAdapterLayer))):
+ x, (z, lr) = layer(x, self_attn_padding_mask=padding_mask, need_weights=False)
+ else:
+ x, (z, lr) = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, corpus_key=corpus_key)
+ if i >= min_layer:
+ layer_results.append((x, z, lr))
+ if i == tgt_layer:
+ r = x
+ break
+
+ if r is not None:
+ x = r
+ x = x.transpose(0, 1)
+
+ if pad_length > 0:
+ x = x[:, :-pad_length]
+
+ def undo_pad(a, b, c):
+ return (a[:-pad_length], b[:-pad_length] if b is not None else b, c[:-pad_length])
+
+ layer_results = [undo_pad(*u) for u in layer_results]
+
+ return x, layer_results
+
+ def max_positions(self):
+ return self.args.max_positions
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ return state_dict
+
+
+class Fp32GroupNorm(nn.GroupNorm):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, input):
+ output = F.group_norm(
+ input.float(),
+ self.num_groups,
+ self.weight.float() if self.weight is not None else None,
+ self.bias.float() if self.bias is not None else None,
+ self.eps,
+ )
+ return output.type_as(input)
+
+
+class Fp32LayerNorm(nn.LayerNorm):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, input):
+ output = F.layer_norm(
+ input.float(),
+ self.normalized_shape,
+ self.weight.float() if self.weight is not None else None,
+ self.bias.float() if self.bias is not None else None,
+ self.eps,
+ )
+ return output.type_as(input)
+
+
+class ConvFeatureExtractionModel(nn.Module):
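+    """1-D convolutional waveform feature extractor (the wav2vec 2.0 / HuBERT front-end)."""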
+ def __init__(self, conv_layers, dropout=0.0, mode="default", conv_bias=False):
+ super().__init__()
+ assert mode in {"default", "layer_norm"}
+
+ def block(n_in, n_out, k, stride, is_layer_norm=False, is_group_norm=False, conv_bias=False):
+ def make_conv():
+ conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
+ nn.init.kaiming_normal_(conv.weight)
+ return conv
+
+            assert not (is_layer_norm and is_group_norm)
+
+ if is_layer_norm:
+ return nn.Sequential(
+ make_conv(),
+ nn.Dropout(p=dropout),
+ nn.Sequential(TransposeLast(), Fp32LayerNorm(dim, elementwise_affine=True), TransposeLast()),
+ nn.GELU(),
+ )
+ if is_group_norm:
+ return nn.Sequential(make_conv(), nn.Dropout(p=dropout), Fp32GroupNorm(dim, dim, affine=True), nn.GELU())
+ return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
+
+ in_d = 1
+ self.conv_layers = nn.ModuleList()
+ for i, cl in enumerate(conv_layers):
+ assert len(cl) == 3
+ (dim, k, stride) = cl
+ self.conv_layers.append(
+ block(
+ in_d,
+ dim,
+ k,
+ stride,
+ is_layer_norm=mode == "layer_norm",
+ is_group_norm=mode == "default" and i == 0,
+ conv_bias=conv_bias,
+ ),
+ )
+ in_d = dim
+
+ def forward(self, x):
+ x = x.unsqueeze(1)
+ for conv in self.conv_layers:
+ x = conv(x)
+
+ return x
+
+
+class GradMultiply(torch.autograd.Function):
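+    """Identity in the forward pass; scales gradients by `scale` in the backward pass."""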
+ @staticmethod
+ def forward(ctx, x, scale):
+ ctx.scale = scale
+ res = x.new(x)
+ return res
+
+ @staticmethod
+ def backward(ctx, grad):
+ return grad * ctx.scale, None
+
+
+class BaseFairseqModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self._is_generation_fast = False
+
+ def get_targets(self, sample, net_output):
+ return sample["target"]
+
+ def extract_features(self, *args, **kwargs):
+ return self(*args, **kwargs)
+
+ def load_state_dict(self, state_dict, strict=True, model_cfg=None, args=None):
+ self.upgrade_state_dict(state_dict)
+ new_state_dict = prune_state_dict(state_dict, model_cfg)
+ return super().load_state_dict(new_state_dict, strict)
+
+ def upgrade_state_dict(self, state_dict):
+ self.upgrade_state_dict_named(state_dict, "")
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ assert state_dict is not None
+
+ def do_upgrade(m, prefix):
+ if len(prefix) > 0:
+ prefix += "."
+ for n, c in m.named_children():
+ name = prefix + n
+ if hasattr(c, "upgrade_state_dict_named"):
+ c.upgrade_state_dict_named(state_dict, name)
+ elif hasattr(c, "upgrade_state_dict"):
+ c.upgrade_state_dict(state_dict)
+ do_upgrade(c, name)
+
+ do_upgrade(self, name)
+
+ def make_generation_fast_(self, **kwargs):
+ if self._is_generation_fast:
+ return
+ self._is_generation_fast = True
+
+ def apply_remove_weight_norm(module):
+ try:
+ nn.utils.remove_weight_norm(module)
+ except (AttributeError, ValueError):
+ return
+
+ self.apply(apply_remove_weight_norm)
+
+ def apply_make_generation_fast_(module, prefix):
+ if len(prefix) > 0:
+ prefix += "."
+
+ base_func = BaseFairseqModel.make_generation_fast_
+ for n, m in module.named_modules():
+ if m != self and hasattr(m, "make_generation_fast_") and m.make_generation_fast_.__func__ is not base_func:
+ m.make_generation_fast_(name=prefix + n, **kwargs)
+
+ apply_make_generation_fast_(self, "")
+ self.eval()
+
+
+class HubertConfig:
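+    """Plain-Python container for HuBERT hyperparameters (a stand-in for fairseq's dataclass config)."""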
+ def __init__(
+ self,
+ _name=None,
+ label_rate=50,
+ encoder_layers_1=3,
+ logit_temp_ctr=0.1,
+ num_negatives=100,
+ cross_sample_negatives=0,
+ ctr_layers=[-6],
+ crop_seq_to_multiple=1,
+ extractor_mode="default",
+ encoder_layers=12,
+ encoder_embed_dim=768,
+ encoder_ffn_embed_dim=3072,
+ encoder_attention_heads=12,
+ activation_fn="gelu",
+ layer_type="transformer",
+ dropout=0.1,
+ attention_dropout=0.1,
+ activation_dropout=0.0,
+ encoder_layerdrop=0.0,
+ dropout_input=0.0,
+ dropout_features=0.0,
+ final_dim=0,
+ untie_final_proj=False,
+ layer_norm_first=False,
+ conv_feature_layers="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
+ conv_bias=False,
+ logit_temp=0.1,
+ target_glu=False,
+ feature_grad_mult=1.0,
+ mask_length=10,
+ mask_prob=0.65,
+ mask_selection="static",
+ mask_other=0.0,
+ no_mask_overlap=False,
+ mask_min_space=1,
+ mask_channel_length=10,
+ mask_channel_prob=0.0,
+ mask_channel_selection="static",
+ mask_channel_other=0.0,
+ no_mask_channel_overlap=False,
+ mask_channel_min_space=1,
+ conv_pos=128,
+ conv_pos_groups=16,
+ conv_pos_batch_norm=False,
+ latent_temp=(2, 0.5, 0.999995),
+ skip_masked=False,
+ skip_nomask=False,
+ checkpoint_activations=False,
+ required_seq_len_multiple=2,
+ depthwise_conv_kernel_size=31,
+ attn_type="",
+ pos_enc_type="abs",
+ fp16=False,
+ ):
+ self._name = _name
+ self.label_rate = label_rate
+ self.encoder_layers_1 = encoder_layers_1
+ self.logit_temp_ctr = logit_temp_ctr
+ self.num_negatives = num_negatives
+ self.cross_sample_negatives = cross_sample_negatives
+ self.ctr_layers = ctr_layers
+ self.crop_seq_to_multiple = crop_seq_to_multiple
+ self.extractor_mode = extractor_mode
+ self.encoder_layers = encoder_layers
+ self.encoder_embed_dim = encoder_embed_dim
+ self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
+ self.encoder_attention_heads = encoder_attention_heads
+ self.activation_fn = activation_fn
+ self.layer_type = layer_type
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.encoder_layerdrop = encoder_layerdrop
+ self.dropout_input = dropout_input
+ self.dropout_features = dropout_features
+ self.final_dim = final_dim
+ self.untie_final_proj = untie_final_proj
+ self.layer_norm_first = layer_norm_first
+ self.conv_feature_layers = conv_feature_layers
+ self.conv_bias = conv_bias
+ self.logit_temp = logit_temp
+ self.target_glu = target_glu
+ self.feature_grad_mult = feature_grad_mult
+ self.mask_length = mask_length
+ self.mask_prob = mask_prob
+ self.mask_selection = mask_selection
+ self.mask_other = mask_other
+ self.no_mask_overlap = no_mask_overlap
+ self.mask_min_space = mask_min_space
+ self.mask_channel_length = mask_channel_length
+ self.mask_channel_prob = mask_channel_prob
+ self.mask_channel_selection = mask_channel_selection
+ self.mask_channel_other = mask_channel_other
+ self.no_mask_channel_overlap = no_mask_channel_overlap
+ self.mask_channel_min_space = mask_channel_min_space
+ self.conv_pos = conv_pos
+ self.conv_pos_groups = conv_pos_groups
+ self.conv_pos_batch_norm = conv_pos_batch_norm
+ self.latent_temp = latent_temp
+ self.skip_masked = skip_masked
+ self.skip_nomask = skip_nomask
+ self.checkpoint_activations = checkpoint_activations
+ self.required_seq_len_multiple = required_seq_len_multiple
+ self.depthwise_conv_kernel_size = depthwise_conv_kernel_size
+ self.attn_type = attn_type
+ self.pos_enc_type = pos_enc_type
+ self.fp16 = fp16
+
+
+class HubertModel(BaseFairseqModel):
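+    """HuBERT: convolutional feature extractor plus Transformer encoder, trained by masked prediction over discrete targets."""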
+ def __init__(self, cfg, num_classes):
+ super().__init__()
+ feature_enc_layers = eval(cfg.conv_feature_layers)
+ self.embed = feature_enc_layers[-1][0]
+ self.feature_extractor = ConvFeatureExtractionModel(
+ conv_layers=feature_enc_layers,
+ dropout=0.0,
+ mode=cfg.extractor_mode,
+ conv_bias=cfg.conv_bias,
+ )
+ feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
+ self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / 16000
+ self.post_extract_proj = nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None
+ self.mask_prob = cfg.mask_prob
+ self.mask_selection = cfg.mask_selection
+ self.mask_other = cfg.mask_other
+ self.mask_length = cfg.mask_length
+ self.no_mask_overlap = cfg.no_mask_overlap
+ self.mask_min_space = cfg.mask_min_space
+ self.mask_channel_prob = cfg.mask_channel_prob
+ self.mask_channel_selection = cfg.mask_channel_selection
+ self.mask_channel_other = cfg.mask_channel_other
+ self.mask_channel_length = cfg.mask_channel_length
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
+ self.mask_channel_min_space = cfg.mask_channel_min_space
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
+ self.feature_grad_mult = cfg.feature_grad_mult
+ self.logit_temp = cfg.logit_temp
+ self.skip_masked = cfg.skip_masked
+ self.skip_nomask = cfg.skip_nomask
+ final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
+ self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_())
+ self.encoder = TransformerEncoder(cfg)
+ self.layer_norm = LayerNorm(self.embed)
+ self.target_glu = None
+ if cfg.target_glu:
+ self.target_glu = nn.Sequential(nn.Linear(final_dim, final_dim * 2), nn.GLU())
+ self.untie_final_proj = cfg.untie_final_proj
+ self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
+ self.num_classes = [num_classes]
+ self.label_embs_concat = nn.Parameter(torch.FloatTensor(sum(self.num_classes), final_dim))
+ nn.init.uniform_(self.label_embs_concat)
+
+ def upgrade_state_dict_named(self, state_dict, name):
+ super().upgrade_state_dict_named(state_dict, name)
+ return state_dict
+
+ def apply_mask(self, x, padding_mask, target_list):
+ B, T, C = x.shape
+ if self.mask_prob > 0:
+ mask_indices = torch.from_numpy(
+ compute_mask_indices(
+ (B, T),
+ padding_mask,
+ self.mask_prob,
+ self.mask_length,
+ self.mask_selection,
+ self.mask_other,
+ min_masks=2,
+ no_overlap=self.no_mask_overlap,
+ min_space=self.mask_min_space,
+ ),
+ ).to(x.device)
+ x[mask_indices] = self.mask_emb
+ else:
+ mask_indices = None
+
+ if self.mask_channel_prob > 0:
+ x[
+ (
+ torch.from_numpy(
+ compute_mask_indices(
+ (B, C),
+ None,
+ self.mask_channel_prob,
+ self.mask_channel_length,
+ self.mask_channel_selection,
+ self.mask_channel_other,
+ no_overlap=self.no_mask_channel_overlap,
+ min_space=self.mask_channel_min_space,
+ ),
+ )
+ .to(x.device)
+ .unsqueeze(1)
+ .expand(-1, T, -1)
+ )
+ ] = 0
+ return x, mask_indices
+
+ def compute_nce(self, x, pos, negs):
+ neg_is_pos = (pos == negs).all(-1)
+ logits = torch.cosine_similarity(x.float(), torch.cat([pos.unsqueeze(0), negs], dim=0).float(), dim=-1).type_as(x)
+ logits /= self.logit_temp
+ if neg_is_pos.any():
+ logits[1:][neg_is_pos] = float("-inf")
+ return logits.transpose(0, 1)
+
+ def forward_features(self, source):
+ if self.feature_grad_mult > 0:
+ features = self.feature_extractor(source)
+ if self.feature_grad_mult != 1.0:
+ features = GradMultiply.apply(features, self.feature_grad_mult)
+ else:
+ with torch.no_grad():
+ features = self.feature_extractor(source)
+ return features
+
+ def forward_targets(self, features, target_list):
+ feat_tsz = features.size(2)
+ targ_tsz = min([t.size(1) for t in target_list])
+ if self.feat2tar_ratio * feat_tsz > targ_tsz:
+ feat_tsz = int(targ_tsz / self.feat2tar_ratio)
+ features = features[..., :feat_tsz]
+
+ return features, [t[:, (torch.arange(feat_tsz).float() * self.feat2tar_ratio).long()] for t in target_list]
+
+ def forward_padding_mask(self, features, padding_mask):
+ extra = padding_mask.size(1) % features.size(1)
+ if extra > 0:
+ padding_mask = padding_mask[:, :-extra]
+ return padding_mask.view(padding_mask.size(0), features.size(1), -1).all(-1)
+
+ def forward(self, source, target_list=None, padding_mask=None, mask=True, features_only=False, output_layer=None):
+ features = self.forward_features(source)
+ if target_list is not None:
+ features, target_list = self.forward_targets(features, target_list)
+ features_pen = features.float().pow(2).mean()
+ features = self.layer_norm(features.transpose(1, 2))
+ unmasked_features = features.clone()
+ if padding_mask is not None:
+ padding_mask = self.forward_padding_mask(features, padding_mask)
+ if self.post_extract_proj is not None:
+ features = self.post_extract_proj(features)
+ features = self.dropout_input(features)
+ unmasked_features = self.dropout_features(unmasked_features)
+ if mask:
+ x, mask_indices = self.apply_mask(features, padding_mask, target_list)
+ else:
+ x, mask_indices = features, None
+ x, _ = self.encoder(x, padding_mask=padding_mask, layer=None if output_layer is None else output_layer - 1)
+ if features_only:
+ return {"x": x, "padding_mask": padding_mask, "features": features}
+
+ def compute_pred(proj_x, target, label_embs):
+ y = torch.index_select(label_embs, 0, target.long())
+ negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
+ if self.target_glu:
+ y = self.target_glu(y)
+ negs = self.target_glu(negs)
+
+ return self.compute_nce(proj_x, y, negs)
+
+ label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
+ if not self.skip_masked:
+ masked_indices = torch.logical_and(~padding_mask, mask_indices)
+ proj_x_m = self.final_proj(x[masked_indices])
+ logit_m_list = [
+ compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
+ for i, (proj_x_m, t) in enumerate(
+ zip(
+ proj_x_m.chunk(len(target_list), dim=-1) if self.untie_final_proj else [proj_x_m for _ in range(len(target_list))],
+ target_list,
+ strict=False,
+ ),
+ )
+ ]
+ else:
+ logit_m_list = [None for _ in target_list]
+
+ if not self.skip_nomask:
+ nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
+ proj_x_u = self.final_proj(x[nomask_indices])
+ logit_u_list = [
+ compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
+ for i, (proj_x_u, t) in enumerate(
+ zip(
+ proj_x_u.chunk(len(target_list), dim=-1) if self.untie_final_proj else [proj_x_u for _ in range(len(target_list))],
+ target_list,
+ strict=False,
+ ),
+ )
+ ]
+ else:
+ logit_u_list = [None for _ in target_list]
+
+ return {"logit_m_list": logit_m_list, "logit_u_list": logit_u_list, "padding_mask": padding_mask, "features_pen": features_pen}
+
+ def extract_features(self, source, padding_mask=None, mask=False, ret_conv=False, output_layer=None):
+ res = self.forward(source, padding_mask=padding_mask, mask=mask, features_only=True, output_layer=output_layer)
+ return res["features"] if ret_conv else res["x"], res["padding_mask"]
+
+ def get_logits(self, net_output, is_masked=True):
+ return [x.float() for x in (net_output["logit_m_list"] if is_masked else net_output["logit_u_list"]) if x is not None]
+
+ def get_targets(self, net_output, is_masked=True):
+ return [x.new_zeros(x.size(0), dtype=torch.long) for x in self.get_logits(net_output, is_masked)]
+
+ def get_extra_losses(self, net_output):
+ extra_losses, names = [], []
+ if "features_pen" in net_output:
+ extra_losses.append(net_output["features_pen"])
+ names.append("features_pen")
+
+ return extra_losses, names
+
+ def remove_pretraining_modules(self):
+ self.target_glu = None
+ self.final_proj = None
+
+
+# Appended to the end of fairseq.py
+
+def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
+ """Упрощенная версия загрузки checkpoint для CPU"""
+ state = torch.load(path, map_location=torch.device("cpu"), weights_only=False)
+ return state
+
+def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1, state=None):
+ """
+ Simplified model loading, compatible with infer.py
+ """
+ if isinstance(filenames, str):
+ filenames = [filenames]
+
+ ensemble = []
+ for filename in filenames:
+ # Reuse the existing load_model helper
+ model = load_model(filename)
+ ensemble.append(model)
+
+ # Return in the format expected by infer.py: (ensemble, saved_cfg, task)
+ return ensemble, None, None
+
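+# A minimal usage sketch (illustrative only; "hubert_base.pt", `waveform`, and the layer index
+# are placeholders, and it assumes load_model returns a ready HubertModel as defined above):
+#
+#   models, _, _ = load_model_ensemble_and_task(["hubert_base.pt"])
+#   hubert = models[0].eval()
+#   feats, _ = hubert.extract_features(source=waveform, padding_mask=None, output_layer=9)
+#
+# Unlike fairseq, saved_cfg and task are not reconstructed here, so callers should rely only
+# on the first element of the returned tuple.
+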
+# Alias kept for backward compatibility
+load_model_ensemble = load_model_ensemble_and_task
\ No newline at end of file
diff --git a/mvsepless/vbach_lib/predictors/FCPE.py b/mvsepless/vbach_lib/predictors/FCPE.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ae0d956476f3c1d9d7967b05bca84a280d6bfe2
--- /dev/null
+++ b/mvsepless/vbach_lib/predictors/FCPE.py
@@ -0,0 +1,892 @@
+
+from typing import Union
+
+import torch.nn.functional as F
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn.utils.parametrizations import weight_norm
+from torchaudio.transforms import Resample
+import os
+import librosa
+import soundfile as sf
+import torch.utils.data
+from librosa.filters import mel as librosa_mel_fn
+import math
+from functools import partial
+
+from einops import rearrange, repeat
+from local_attention import LocalAttention
+
+os.environ["LRU_CACHE_CAPACITY"] = "3"
+
+
+def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
+ sample_rate = None # avoid an unbound name in the except branch below if sf.read fails
+ try:
+ data, sample_rate = sf.read(full_path, always_2d=True)
+ except Exception as error:
+ print(f"An error occurred loading {full_path}: {error}")
+ if return_empty_on_exception:
+ return [], sample_rate or target_sr or 48000
+ else:
+ raise
+
+ data = data[:, 0] if len(data.shape) > 1 else data
+ assert len(data) > 2
+
+ max_mag = (
+ -np.iinfo(data.dtype).min
+ if np.issubdtype(data.dtype, np.integer)
+ else max(np.amax(data), -np.amin(data))
+ )
+ max_mag = (
+ (2**31) + 1 if max_mag > (2**15) else ((2**15) + 1 if max_mag > 1.01 else 1.0)
+ )
+ data = torch.FloatTensor(data.astype(np.float32)) / max_mag
+
+ if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:
+ return [], sample_rate or target_sr or 48000
+ if target_sr is not None and sample_rate != target_sr:
+ data = torch.from_numpy(
+ librosa.core.resample(data.numpy(), orig_sr=sample_rate, target_sr=target_sr)
+ )
+ sample_rate = target_sr
+
+ return data, sample_rate
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+ return np.exp(x) / C
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+ return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+ return torch.exp(x) / C
+
+
+class STFT:
+ def __init__(
+ self,
+ sr=22050,
+ n_mels=80,
+ n_fft=1024,
+ win_size=1024,
+ hop_length=256,
+ fmin=20,
+ fmax=11025,
+ clip_val=1e-5,
+ ):
+ self.target_sr = sr
+ self.n_mels = n_mels
+ self.n_fft = n_fft
+ self.win_size = win_size
+ self.hop_length = hop_length
+ self.fmin = fmin
+ self.fmax = fmax
+ self.clip_val = clip_val
+ self.mel_basis = {}
+ self.hann_window = {}
+
+ def get_mel(self, y, keyshift=0, speed=1, center=False, train=False):
+ sample_rate = self.target_sr
+ n_mels = self.n_mels
+ n_fft = self.n_fft
+ win_size = self.win_size
+ hop_length = self.hop_length
+ fmin = self.fmin
+ fmax = self.fmax
+ clip_val = self.clip_val
+
+ factor = 2 ** (keyshift / 12)
+ n_fft_new = int(np.round(n_fft * factor))
+ win_size_new = int(np.round(win_size * factor))
+ hop_length_new = int(np.round(hop_length * speed))
+
+ mel_basis = self.mel_basis if not train else {}
+ hann_window = self.hann_window if not train else {}
+
+ mel_basis_key = str(fmax) + "_" + str(y.device)
+ if mel_basis_key not in mel_basis:
+ mel = librosa_mel_fn(
+ sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax
+ )
+ mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device)
+
+ keyshift_key = str(keyshift) + "_" + str(y.device)
+ if keyshift_key not in hann_window:
+ hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)
+
+ pad_left = (win_size_new - hop_length_new) // 2
+ pad_right = max(
+ (win_size_new - hop_length_new + 1) // 2,
+ win_size_new - y.size(-1) - pad_left,
+ )
+ mode = "reflect" if pad_right < y.size(-1) else "constant"
+ y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode)
+ y = y.squeeze(1)
+
+ spec = torch.stft(
+ y,
+ n_fft_new,
+ hop_length=hop_length_new,
+ win_length=win_size_new,
+ window=hann_window[keyshift_key],
+ center=center,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=True,
+ )
+ spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9))
+
+ if keyshift != 0:
+ size = n_fft // 2 + 1
+ resize = spec.size(1)
+ spec = (
+ F.pad(spec, (0, 0, 0, size - resize))
+ if resize < size
+ else spec[:, :size, :]
+ )
+ spec = spec * win_size / win_size_new
+ spec = torch.matmul(mel_basis[mel_basis_key], spec)
+ spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
+ return spec
+
+ def __call__(self, audiopath):
+ audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
+ spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
+ return spect
+
+
+stft = STFT()
+
+
+def softmax_kernel(
+ data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device=None
+):
+ b, h, *_ = data.shape
+
+ data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.0
+
+ ratio = projection_matrix.shape[0] ** -0.5
+ projection = repeat(projection_matrix, "j d -> b h j d", b=b, h=h)
+ projection = projection.type_as(data)
+ data_dash = torch.einsum("...id,...jd->...ij", (data_normalizer * data), projection)
+
+ diag_data = data**2
+ diag_data = torch.sum(diag_data, dim=-1)
+ diag_data = (diag_data / 2.0) * (data_normalizer**2)
+ diag_data = diag_data.unsqueeze(dim=-1)
+
+ if is_query:
+ data_dash = ratio * (
+ torch.exp(
+ data_dash - diag_data - torch.max(data_dash, dim=-1, keepdim=True).values
+ )
+ + eps
+ )
+ else:
+ data_dash = ratio * (torch.exp(data_dash - diag_data + eps))
+
+ return data_dash.type_as(data)
+
+
+def orthogonal_matrix_chunk(cols, qr_uniform_q=False, device=None):
+ unstructured_block = torch.randn((cols, cols), device=device)
+ q, r = torch.linalg.qr(unstructured_block.cpu(), mode="reduced")
+ q, r = map(lambda t: t.to(device), (q, r))
+
+ if qr_uniform_q:
+ d = torch.diag(r, 0)
+ q *= d.sign()
+ return q.t()
+
+
+def exists(val):
+ return val is not None
+
+
+def empty(tensor):
+ return tensor.numel() == 0
+
+
+def default(val, d):
+ return val if exists(val) else d
+
+
+def cast_tuple(val):
+ return (val,) if not isinstance(val, tuple) else val
+
+
+class PCmer(nn.Module):
+ def __init__(
+ self,
+ num_layers,
+ num_heads,
+ dim_model,
+ dim_keys,
+ dim_values,
+ residual_dropout,
+ attention_dropout,
+ ):
+ super().__init__()
+ self.num_layers = num_layers
+ self.num_heads = num_heads
+ self.dim_model = dim_model
+ self.dim_values = dim_values
+ self.dim_keys = dim_keys
+ self.residual_dropout = residual_dropout
+ self.attention_dropout = attention_dropout
+
+ self._layers = nn.ModuleList([_EncoderLayer(self) for _ in range(num_layers)])
+
+ def forward(self, phone, mask=None):
+ for layer in self._layers:
+ phone = layer(phone, mask)
+ return phone
+
+
+class _EncoderLayer(nn.Module):
+ def __init__(self, parent: PCmer):
+ super().__init__()
+ self.conformer = ConformerConvModule(parent.dim_model)
+ self.norm = nn.LayerNorm(parent.dim_model)
+ self.dropout = nn.Dropout(parent.residual_dropout)
+ self.attn = SelfAttention(
+ dim=parent.dim_model, heads=parent.num_heads, causal=False
+ )
+
+ def forward(self, phone, mask=None):
+ phone = phone + (self.attn(self.norm(phone), mask=mask))
+ phone = phone + (self.conformer(phone))
+ return phone
+
+
+def calc_same_padding(kernel_size):
+ pad = kernel_size // 2
+ return (pad, pad - (kernel_size + 1) % 2)
+
+
+class Swish(nn.Module):
+ def forward(self, x):
+ return x * x.sigmoid()
+
+
+class Transpose(nn.Module):
+ def __init__(self, dims):
+ super().__init__()
+ assert len(dims) == 2, "dims must be a tuple of two dimensions"
+ self.dims = dims
+
+ def forward(self, x):
+ return x.transpose(*self.dims)
+
+
+class GLU(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x):
+ out, gate = x.chunk(2, dim=self.dim)
+ return out * gate.sigmoid()
+
+
+class DepthWiseConv1d(nn.Module):
+ def __init__(self, chan_in, chan_out, kernel_size, padding):
+ super().__init__()
+ self.padding = padding
+ self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)
+
+ def forward(self, x):
+ x = F.pad(x, self.padding)
+ return self.conv(x)
+
+
+class ConformerConvModule(nn.Module):
+ def __init__(
+ self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0
+ ):
+ super().__init__()
+
+ inner_dim = dim * expansion_factor
+ padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)
+
+ self.net = nn.Sequential(
+ nn.LayerNorm(dim),
+ Transpose((1, 2)),
+ nn.Conv1d(dim, inner_dim * 2, 1),
+ GLU(dim=1),
+ DepthWiseConv1d(
+ inner_dim, inner_dim, kernel_size=kernel_size, padding=padding
+ ),
+ Swish(),
+ nn.Conv1d(inner_dim, dim, 1),
+ Transpose((1, 2)),
+ nn.Dropout(dropout),
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def linear_attention(q, k, v):
+ if v is None:
+ out = torch.einsum("...ed,...nd->...ne", k, q)
+ return out
+ else:
+ k_cumsum = k.sum(dim=-2)
+ D_inv = 1.0 / (torch.einsum("...nd,...d->...n", q, k_cumsum.type_as(q)) + 1e-8)
+ context = torch.einsum("...nd,...ne->...de", k, v)
+ out = torch.einsum("...de,...nd,...n->...ne", context, q, D_inv)
+ return out
+
+
+def gaussian_orthogonal_random_matrix(
+ nb_rows, nb_columns, scaling=0, qr_uniform_q=False, device=None
+):
+ nb_full_blocks = int(nb_rows / nb_columns)
+ block_list = []
+
+ for _ in range(nb_full_blocks):
+ q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device)
+ block_list.append(q)
+
+ remaining_rows = nb_rows - nb_full_blocks * nb_columns
+ if remaining_rows > 0:
+ q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device)
+ block_list.append(q[:remaining_rows])
+
+ final_matrix = torch.cat(block_list)
+
+ if scaling == 0:
+ multiplier = torch.randn((nb_rows, nb_columns), device=device).norm(dim=1)
+ elif scaling == 1:
+ multiplier = math.sqrt((float(nb_columns))) * torch.ones(
+ (nb_rows,), device=device
+ )
+ else:
+ raise ValueError(f"Invalid scaling {scaling}")
+
+ return torch.diag(multiplier) @ final_matrix
+
+
+class FastAttention(nn.Module):
+ def __init__(
+ self,
+ dim_heads,
+ nb_features=None,
+ ortho_scaling=0,
+ causal=False,
+ generalized_attention=False,
+ kernel_fn=nn.ReLU(),
+ qr_uniform_q=False,
+ no_projection=False,
+ ):
+ super().__init__()
+ nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
+
+ self.dim_heads = dim_heads
+ self.nb_features = nb_features
+ self.ortho_scaling = ortho_scaling
+
+ self.create_projection = partial(
+ gaussian_orthogonal_random_matrix,
+ nb_rows=self.nb_features,
+ nb_columns=dim_heads,
+ scaling=ortho_scaling,
+ qr_uniform_q=qr_uniform_q,
+ )
+ projection_matrix = self.create_projection()
+ self.register_buffer("projection_matrix", projection_matrix)
+
+ self.generalized_attention = generalized_attention
+ self.kernel_fn = kernel_fn
+ self.no_projection = no_projection
+ self.causal = causal
+
+ @torch.no_grad()
+ def redraw_projection_matrix(self):
+ projections = self.create_projection()
+ self.projection_matrix.copy_(projections)
+ del projections
+
+ def forward(self, q, k, v):
+ device = q.device
+
+ if self.no_projection:
+ q = q.softmax(dim=-1)
+ k = torch.exp(k) if self.causal else k.softmax(dim=-2)
+ else:
+ create_kernel = partial(
+ softmax_kernel, projection_matrix=self.projection_matrix, device=device
+ )
+ q = create_kernel(q, is_query=True)
+ k = create_kernel(k, is_query=False)
+
+ attn_fn = linear_attention if not self.causal else self.causal_linear_fn # causal_linear_fn is not defined in this trimmed copy; only causal=False is exercised here
+
+ if v is None:
+ out = attn_fn(q, k, None)
+ return out
+ else:
+ out = attn_fn(q, k, v)
+ return out
+
+
+class SelfAttention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ causal=False,
+ heads=8,
+ dim_head=64,
+ local_heads=0,
+ local_window_size=256,
+ nb_features=None,
+ feature_redraw_interval=1000,
+ generalized_attention=False,
+ kernel_fn=nn.ReLU(),
+ qr_uniform_q=False,
+ dropout=0.0,
+ no_projection=False,
+ ):
+ super().__init__()
+ assert dim % heads == 0, "dimension must be divisible by number of heads"
+ dim_head = default(dim_head, dim // heads)
+ inner_dim = dim_head * heads
+ self.fast_attention = FastAttention(
+ dim_head,
+ nb_features,
+ causal=causal,
+ generalized_attention=generalized_attention,
+ kernel_fn=kernel_fn,
+ qr_uniform_q=qr_uniform_q,
+ no_projection=no_projection,
+ )
+
+ self.heads = heads
+ self.global_heads = heads - local_heads
+ self.local_attn = (
+ LocalAttention(
+ window_size=local_window_size,
+ causal=causal,
+ autopad=True,
+ dropout=dropout,
+ look_forward=int(not causal),
+ rel_pos_emb_config=(dim_head, local_heads),
+ )
+ if local_heads > 0
+ else None
+ )
+
+ self.to_q = nn.Linear(dim, inner_dim)
+ self.to_k = nn.Linear(dim, inner_dim)
+ self.to_v = nn.Linear(dim, inner_dim)
+ self.to_out = nn.Linear(inner_dim, dim)
+ self.dropout = nn.Dropout(dropout)
+
+ @torch.no_grad()
+ def redraw_projection_matrix(self):
+ self.fast_attention.redraw_projection_matrix()
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ context_mask=None,
+ name=None,
+ inference=False,
+ **kwargs,
+ ):
+ _, _, _, h, gh = *x.shape, self.heads, self.global_heads
+
+ cross_attend = exists(context)
+ context = default(context, x)
+ context_mask = default(context_mask, mask) if not cross_attend else context_mask
+ q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+ (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))
+
+ attn_outs = []
+ if not empty(q):
+ if exists(context_mask):
+ global_mask = context_mask[:, None, :, None]
+ v.masked_fill_(~global_mask, 0.0)
+ if cross_attend:
+ pass
+ else:
+ out = self.fast_attention(q, k, v)
+ attn_outs.append(out)
+
+ if not empty(lq):
+ assert (
+ not cross_attend
+ ), "local attention is not compatible with cross attention"
+ out = self.local_attn(lq, lk, lv, input_mask=mask)
+ attn_outs.append(out)
+
+ out = torch.cat(attn_outs, dim=1)
+ out = rearrange(out, "b h n d -> b n (h d)")
+ out = self.to_out(out)
+ return self.dropout(out)
+
+
+def l2_regularization(model, l2_alpha):
+ l2_loss = []
+ for module in model.modules():
+ if type(module) is nn.Conv2d:
+ l2_loss.append((module.weight**2).sum() / 2.0)
+ return l2_alpha * sum(l2_loss)
+
+
+class FCPE(nn.Module):
+ def __init__(
+ self,
+ input_channel=128,
+ out_dims=360,
+ n_layers=12,
+ n_chans=512,
+ use_siren=False,
+ use_full=False,
+ loss_mse_scale=10,
+ loss_l2_regularization=False,
+ loss_l2_regularization_scale=1,
+ loss_grad1_mse=False,
+ loss_grad1_mse_scale=1,
+ f0_max=1975.5,
+ f0_min=32.70,
+ confidence=False,
+ threshold=0.05,
+ use_input_conv=True,
+ ):
+ super().__init__()
+ if use_siren is True:
+ raise ValueError("Siren is not supported yet.")
+ if use_full is True:
+ raise ValueError("Full model is not supported yet.")
+
+ self.loss_mse_scale = loss_mse_scale if (loss_mse_scale is not None) else 10
+ self.loss_l2_regularization = (
+ loss_l2_regularization if (loss_l2_regularization is not None) else False
+ )
+ self.loss_l2_regularization_scale = (
+ loss_l2_regularization_scale
+ if (loss_l2_regularization_scale is not None)
+ else 1
+ )
+ self.loss_grad1_mse = loss_grad1_mse if (loss_grad1_mse is not None) else False
+ self.loss_grad1_mse_scale = (
+ loss_grad1_mse_scale if (loss_grad1_mse_scale is not None) else 1
+ )
+ self.f0_max = f0_max if (f0_max is not None) else 1975.5
+ self.f0_min = f0_min if (f0_min is not None) else 32.70
+ self.confidence = confidence if (confidence is not None) else False
+ self.threshold = threshold if (threshold is not None) else 0.05
+ self.use_input_conv = use_input_conv if (use_input_conv is not None) else True
+
+ self.cent_table_b = torch.Tensor(
+ np.linspace(
+ self.f0_to_cent(torch.Tensor([f0_min]))[0],
+ self.f0_to_cent(torch.Tensor([f0_max]))[0],
+ out_dims,
+ )
+ )
+ self.register_buffer("cent_table", self.cent_table_b)
+
+ _leaky = nn.LeakyReLU()
+ self.stack = nn.Sequential(
+ nn.Conv1d(input_channel, n_chans, 3, 1, 1),
+ nn.GroupNorm(4, n_chans),
+ _leaky,
+ nn.Conv1d(n_chans, n_chans, 3, 1, 1),
+ )
+
+ self.decoder = PCmer(
+ num_layers=n_layers,
+ num_heads=8,
+ dim_model=n_chans,
+ dim_keys=n_chans,
+ dim_values=n_chans,
+ residual_dropout=0.1,
+ attention_dropout=0.1,
+ )
+ self.norm = nn.LayerNorm(n_chans)
+
+ self.n_out = out_dims
+ self.dense_out = weight_norm(nn.Linear(n_chans, self.n_out))
+
+ def forward(
+ self, mel, infer=True, gt_f0=None, return_hz_f0=False, cdecoder="local_argmax"
+ ):
+ if cdecoder == "argmax":
+ self.cdecoder = self.cents_decoder
+ elif cdecoder == "local_argmax":
+ self.cdecoder = self.cents_local_decoder
+
+ x = (
+ self.stack(mel.transpose(1, 2)).transpose(1, 2)
+ if self.use_input_conv
+ else mel
+ )
+ x = self.decoder(x)
+ x = self.norm(x)
+ x = self.dense_out(x)
+ x = torch.sigmoid(x)
+
+ if not infer:
+ gt_cent_f0 = self.f0_to_cent(gt_f0)
+ gt_cent_f0 = self.gaussian_blurred_cent(gt_cent_f0)
+ loss_all = self.loss_mse_scale * F.binary_cross_entropy(x, gt_cent_f0)
+ if self.loss_l2_regularization:
+ loss_all = loss_all + l2_regularization(
+ model=self, l2_alpha=self.loss_l2_regularization_scale
+ )
+ x = loss_all
+ if infer:
+ x = self.cdecoder(x)
+ x = self.cent_to_f0(x)
+ x = (1 + x / 700).log() if not return_hz_f0 else x
+
+ return x
+
+ def cents_decoder(self, y, mask=True):
+ B, N, _ = y.size()
+ ci = self.cent_table[None, None, :].expand(B, N, -1)
+ rtn = torch.sum(ci * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True)
+ if mask:
+ confident = torch.max(y, dim=-1, keepdim=True)[0]
+ confident_mask = torch.ones_like(confident)
+ confident_mask[confident <= self.threshold] = float("-INF")
+ rtn = rtn * confident_mask
+ return (rtn, confident) if self.confidence else rtn
+
+ def cents_local_decoder(self, y, mask=True):
+ B, N, _ = y.size()
+ ci = self.cent_table[None, None, :].expand(B, N, -1)
+ confident, max_index = torch.max(y, dim=-1, keepdim=True)
+ local_argmax_index = torch.arange(0, 9).to(max_index.device) + (max_index - 4)
+ local_argmax_index = torch.clamp(local_argmax_index, 0, self.n_out - 1)
+ ci_l = torch.gather(ci, -1, local_argmax_index)
+ y_l = torch.gather(y, -1, local_argmax_index)
+ rtn = torch.sum(ci_l * y_l, dim=-1, keepdim=True) / torch.sum(
+ y_l, dim=-1, keepdim=True
+ )
+ if mask:
+ confident_mask = torch.ones_like(confident)
+ confident_mask[confident <= self.threshold] = float("-INF")
+ rtn = rtn * confident_mask
+ return (rtn, confident) if self.confidence else rtn
+
+ def cent_to_f0(self, cent):
+ return 10.0 * 2 ** (cent / 1200.0)
+
+ def f0_to_cent(self, f0):
+ return 1200.0 * torch.log2(f0 / 10.0)
+
+ def gaussian_blurred_cent(self, cents):
+ mask = (cents > 0.1) & (cents < (1200.0 * np.log2(self.f0_max / 10.0)))
+ B, N, _ = cents.size()
+ ci = self.cent_table[None, None, :].expand(B, N, -1)
+ return torch.exp(-torch.square(ci - cents) / 1250) * mask.float()
+
+
+class FCPEInfer:
+ def __init__(self, model_path, device=None, dtype=torch.float32):
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.device = device
+ ckpt = torch.load(model_path, map_location=torch.device(self.device))
+ self.args = DotDict(ckpt["config"])
+ self.dtype = dtype
+ model = FCPE(
+ input_channel=self.args.model.input_channel,
+ out_dims=self.args.model.out_dims,
+ n_layers=self.args.model.n_layers,
+ n_chans=self.args.model.n_chans,
+ use_siren=self.args.model.use_siren,
+ use_full=self.args.model.use_full,
+ loss_mse_scale=self.args.loss.loss_mse_scale,
+ loss_l2_regularization=self.args.loss.loss_l2_regularization,
+ loss_l2_regularization_scale=self.args.loss.loss_l2_regularization_scale,
+ loss_grad1_mse=self.args.loss.loss_grad1_mse,
+ loss_grad1_mse_scale=self.args.loss.loss_grad1_mse_scale,
+ f0_max=self.args.model.f0_max,
+ f0_min=self.args.model.f0_min,
+ confidence=self.args.model.confidence,
+ )
+ model.to(self.device).to(self.dtype)
+ model.load_state_dict(ckpt["model"])
+ model.eval()
+ self.model = model
+ self.wav2mel = Wav2Mel(self.args, dtype=self.dtype, device=self.device)
+
+ @torch.no_grad()
+ def __call__(self, audio, sr, threshold=0.05):
+ self.model.threshold = threshold
+ audio = audio[None, :]
+ mel = self.wav2mel(audio=audio, sample_rate=sr).to(self.dtype)
+ f0 = self.model(mel=mel, infer=True, return_hz_f0=True)
+ return f0
+
+
+class Wav2Mel:
+ def __init__(self, args, device=None, dtype=torch.float32):
+ self.sample_rate = args.mel.sampling_rate
+ self.hop_size = args.mel.hop_size
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.device = device
+ self.dtype = dtype
+ self.stft = STFT(
+ args.mel.sampling_rate,
+ args.mel.num_mels,
+ args.mel.n_fft,
+ args.mel.win_size,
+ args.mel.hop_size,
+ args.mel.fmin,
+ args.mel.fmax,
+ )
+ self.resample_kernel = {}
+
+ def extract_nvstft(self, audio, keyshift=0, train=False):
+ mel = self.stft.get_mel(audio, keyshift=keyshift, train=train).transpose(1, 2)
+ return mel
+
+ def extract_mel(self, audio, sample_rate, keyshift=0, train=False):
+ audio = audio.to(self.dtype).to(self.device)
+ if sample_rate == self.sample_rate:
+ audio_res = audio
+ else:
+ key_str = str(sample_rate)
+ if key_str not in self.resample_kernel:
+ self.resample_kernel[key_str] = Resample(
+ sample_rate, self.sample_rate, lowpass_filter_width=128
+ )
+ self.resample_kernel[key_str] = (
+ self.resample_kernel[key_str].to(self.dtype).to(self.device)
+ )
+ audio_res = self.resample_kernel[key_str](audio)
+
+ mel = self.extract_nvstft(audio_res, keyshift=keyshift, train=train)
+ n_frames = int(audio.shape[1] // self.hop_size) + 1
+ mel = torch.cat((mel, mel[:, -1:, :]), 1) if n_frames > int(mel.shape[1]) else mel
+ mel = mel[:, :n_frames, :] if n_frames < int(mel.shape[1]) else mel
+ return mel
+
+ def __call__(self, audio, sample_rate, keyshift=0, train=False):
+ return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train)
+
+
+class DotDict(dict):
+ def __getattr__(*args):
+ val = dict.get(*args)
+ return DotDict(val) if type(val) is dict else val
+
+ __setattr__ = dict.__setitem__
+ __delattr__ = dict.__delitem__
+
+
+class F0Predictor(object):
+ def compute_f0(self, wav, p_len):
+ pass
+
+ def compute_f0_uv(self, wav, p_len):
+ pass
+
+
+class FCPEF0Predictor(F0Predictor):
+ def __init__(
+ self,
+ model_path,
+ hop_length=512,
+ f0_min=50,
+ f0_max=1100,
+ dtype=torch.float32,
+ device=None,
+ sample_rate=44100,
+ threshold=0.05,
+ ):
+ self.fcpe = FCPEInfer(model_path, device=device, dtype=dtype)
+ self.hop_length = hop_length
+ self.f0_min = f0_min
+ self.f0_max = f0_max
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+ self.threshold = threshold
+ self.sample_rate = sample_rate
+ self.dtype = dtype
+ self.name = "fcpe"
+
+ def repeat_expand(
+ self,
+ content: Union[torch.Tensor, np.ndarray],
+ target_len: int,
+ mode: str = "nearest",
+ ):
+ ndim = content.ndim
+ content = (
+ content[None, None] if ndim == 1 else content[None] if ndim == 2 else content
+ )
+ assert content.ndim == 3
+ is_np = isinstance(content, np.ndarray)
+ content = torch.from_numpy(content) if is_np else content
+ results = torch.nn.functional.interpolate(content, size=target_len, mode=mode)
+ results = results.numpy() if is_np else results
+ return results[0, 0] if ndim == 1 else results[0] if ndim == 2 else results
+
+ def post_process(self, x, sample_rate, f0, pad_to):
+ f0 = (
+ torch.from_numpy(f0).float().to(x.device)
+ if isinstance(f0, np.ndarray)
+ else f0
+ )
+ f0 = self.repeat_expand(f0, pad_to) if pad_to is not None else f0
+
+ vuv_vector = torch.zeros_like(f0)
+ vuv_vector[f0 > 0.0] = 1.0
+ vuv_vector[f0 <= 0.0] = 0.0
+
+ nzindex = torch.nonzero(f0).squeeze()
+ f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
+ time_org = self.hop_length / sample_rate * nzindex.cpu().numpy()
+ time_frame = np.arange(pad_to) * self.hop_length / sample_rate
+
+ vuv_vector = F.interpolate(vuv_vector[None, None, :], size=pad_to)[0][0]
+
+ if f0.shape[0] <= 0:
+ return np.zeros(pad_to), vuv_vector.cpu().numpy()
+ if f0.shape[0] == 1:
+ return np.ones(pad_to) * f0[0], vuv_vector.cpu().numpy()
+
+ f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1])
+ return f0, vuv_vector.cpu().numpy()
+
+ def compute_f0(self, wav, p_len=None):
+ x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
+ p_len = x.shape[0] // self.hop_length if p_len is None else p_len
+ f0 = self.fcpe(x, sr=self.sample_rate, threshold=self.threshold)[0, :, 0]
+ if torch.all(f0 == 0):
+ return f0.cpu().numpy() if p_len is None else np.zeros(p_len), (
+ f0.cpu().numpy() if p_len is None else np.zeros(p_len)
+ )
+ return self.post_process(x, self.sample_rate, f0, p_len)[0]
+
+ def compute_f0_uv(self, wav, p_len=None):
+ x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
+ p_len = x.shape[0] // self.hop_length if p_len is None else p_len
+ f0 = self.fcpe(x, sr=self.sample_rate, threshold=self.threshold)[0, :, 0]
+ if torch.all(f0 == 0):
+ return f0.cpu().numpy() if p_len is None else np.zeros(p_len), (
+ f0.cpu().numpy() if p_len is None else np.zeros(p_len)
+ )
+ return self.post_process(x, self.sample_rate, f0, p_len)
+
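+
+# A minimal usage sketch (illustrative only; "fcpe.pt" and `wav` are placeholders for a real
+# checkpoint path and a mono float waveform at `sample_rate`):
+#
+#   predictor = FCPEF0Predictor("fcpe.pt", hop_length=160, sample_rate=16000, device="cpu")
+#   f0, uv = predictor.compute_f0_uv(wav, p_len=None) # f0 in Hz per frame, uv = voiced mask
+#
+# compute_f0 returns only the f0 track; compute_f0_uv also returns the voiced/unvoiced vector.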
diff --git a/mvsepless/vbach_lib/predictors/RMVPE.py b/mvsepless/vbach_lib/predictors/RMVPE.py
new file mode 100644
index 0000000000000000000000000000000000000000..88716914ad4754dda0a99bab6ac66bb0cb49ff92
--- /dev/null
+++ b/mvsepless/vbach_lib/predictors/RMVPE.py
@@ -0,0 +1,518 @@
+
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+from librosa.filters import mel
+from scipy.signal import get_window
+from librosa.util import pad_center, tiny, normalize
+
+
+def window_sumsquare(
+ window,
+ n_frames,
+ hop_length=200,
+ win_length=800,
+ n_fft=800,
+ dtype=np.float32,
+ norm=None,
+):
+ if win_length is None:
+ win_length = n_fft
+
+ n = n_fft + hop_length * (n_frames - 1)
+ x = np.zeros(n, dtype=dtype)
+
+ win_sq = get_window(window, win_length, fftbins=True)
+ win_sq = normalize(win_sq, norm=norm) ** 2
+ win_sq = pad_center(win_sq, size=n_fft)
+
+ for i in range(n_frames):
+ sample = i * hop_length
+ x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
+ return x
+
+
+class STFT(nn.Module):
+ def __init__(
+ self, filter_length=1024, hop_length=512, win_length=None, window="hann"
+ ):
+ super(STFT, self).__init__()
+ self.filter_length = filter_length
+ self.hop_length = hop_length
+ self.win_length = win_length if win_length else filter_length
+ self.window = window
+ self.pad_amount = int(self.filter_length / 2)
+ scale = self.filter_length / self.hop_length
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
+
+ cutoff = int((self.filter_length / 2 + 1))
+ fourier_basis = np.vstack(
+ [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
+ )
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
+ inverse_basis = torch.FloatTensor(
+ np.linalg.pinv(scale * fourier_basis).T[:, None, :]
+ )
+
+ assert filter_length >= self.win_length
+ fft_window = get_window(window, self.win_length, fftbins=True)
+ fft_window = pad_center(fft_window, size=filter_length)
+ fft_window = torch.from_numpy(fft_window).float()
+
+ forward_basis *= fft_window
+ inverse_basis *= fft_window
+
+ self.register_buffer("forward_basis", forward_basis.float())
+ self.register_buffer("inverse_basis", inverse_basis.float())
+
+ def transform(self, input_data):
+ num_batches = input_data.shape[0]
+ num_samples = input_data.shape[-1]
+ self.num_samples = num_samples # cached so inverse() can trim its output back to the input length
+
+ input_data = input_data.view(num_batches, 1, num_samples)
+ input_data = F.pad(
+ input_data.unsqueeze(1),
+ (self.pad_amount, self.pad_amount, 0, 0, 0, 0),
+ mode="reflect",
+ ).squeeze(1)
+ forward_transform = F.conv1d(
+ input_data, self.forward_basis, stride=self.hop_length, padding=0
+ )
+
+ cutoff = int((self.filter_length / 2) + 1)
+ real_part = forward_transform[:, :cutoff, :]
+ imag_part = forward_transform[:, cutoff:, :]
+ # cache the phase so inverse()/forward() remain usable; transform() itself still
+ # returns only the magnitude, which is all the RMVPE inference path needs
+ self.phase = torch.atan2(imag_part.data, real_part.data)
+ return torch.sqrt(real_part**2 + imag_part**2)
+
+ def inverse(self, magnitude, phase):
+ recombine_magnitude_phase = torch.cat(
+ [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
+ )
+ inverse_transform = F.conv_transpose1d(
+ recombine_magnitude_phase,
+ self.inverse_basis,
+ stride=self.hop_length,
+ padding=0,
+ )
+
+ if self.window is not None:
+ window_sum = window_sumsquare(
+ self.window,
+ magnitude.size(-1),
+ hop_length=self.hop_length,
+ win_length=self.win_length,
+ n_fft=self.filter_length,
+ dtype=np.float32,
+ )
+ approx_nonzero_indices = torch.from_numpy(
+ np.where(window_sum > tiny(window_sum))[0]
+ )
+ window_sum = torch.from_numpy(window_sum).to(inverse_transform.device)
+ inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
+ approx_nonzero_indices
+ ]
+ inverse_transform *= float(self.filter_length) / self.hop_length
+
+ inverse_transform = inverse_transform[..., self.pad_amount :]
+ inverse_transform = inverse_transform[..., : self.num_samples]
+ return inverse_transform.squeeze(1)
+
+ def forward(self, input_data):
+ self.magnitude = self.transform(input_data) # transform() returns magnitude only and caches self.phase
+ return self.inverse(self.magnitude, self.phase)
+
+
+class BiGRU(nn.Module):
+ def __init__(self, input_features, hidden_features, num_layers):
+ super(BiGRU, self).__init__()
+ self.gru = nn.GRU(
+ input_features,
+ hidden_features,
+ num_layers=num_layers,
+ batch_first=True,
+ bidirectional=True,
+ )
+
+ def forward(self, x):
+ return self.gru(x)[0]
+
+
+class ConvBlockRes(nn.Module):
+ def __init__(self, in_channels, out_channels, momentum=0.01):
+ super(ConvBlockRes, self).__init__()
+ self.conv = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(3, 3),
+ stride=(1, 1),
+ padding=(1, 1),
+ bias=False,
+ ),
+ nn.BatchNorm2d(out_channels, momentum=momentum),
+ nn.ReLU(),
+ nn.Conv2d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=(3, 3),
+ stride=(1, 1),
+ padding=(1, 1),
+ bias=False,
+ ),
+ nn.BatchNorm2d(out_channels, momentum=momentum),
+ nn.ReLU(),
+ )
+ self.shortcut = (
+ nn.Conv2d(in_channels, out_channels, (1, 1))
+ if in_channels != out_channels
+ else None
+ )
+
+ def forward(self, x):
+ out = self.conv(x)
+ if self.shortcut is not None:
+ x = self.shortcut(x)
+ return out + x
+
+
+class ResEncoderBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
+ super(ResEncoderBlock, self).__init__()
+ self.conv = nn.ModuleList(
+ [
+ ConvBlockRes(
+ in_channels if i == 0 else out_channels, out_channels, momentum
+ )
+ for i in range(n_blocks)
+ ]
+ )
+ self.pool = (
+ nn.AvgPool2d(kernel_size=kernel_size) if kernel_size is not None else None
+ )
+
+ def forward(self, x):
+ for conv in self.conv:
+ x = conv(x)
+ pooled = self.pool(x) if self.pool is not None else x
+ return pooled, x
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ in_size,
+ n_encoders,
+ kernel_size,
+ n_blocks,
+ out_channels=16,
+ momentum=0.01,
+ ):
+ super(Encoder, self).__init__()
+ self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
+ self.layers = nn.ModuleList()
+ self.latent_channels = []
+ for _ in range(n_encoders):
+ self.layers.append(
+ ResEncoderBlock(
+ in_channels, out_channels, kernel_size, n_blocks, momentum=momentum
+ )
+ )
+ self.latent_channels.append([out_channels, in_size])
+ in_channels = out_channels
+ out_channels *= 2
+ in_size //= 2
+ self.out_size = in_size
+ self.out_channel = out_channels
+
+ def forward(self, x):
+ concat_tensors = []
+ x = self.bn(x)
+ for layer in self.layers:
+ x, skip = layer(x) # ResEncoderBlock returns (pooled, unpooled); keep the unpooled map as the skip connection
+ concat_tensors.append(skip)
+ return x, concat_tensors
+
+
+class Intermediate(nn.Module):
+ def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
+ super(Intermediate, self).__init__()
+ self.layers = nn.ModuleList(
+ [
+ ResEncoderBlock(
+ in_channels if i == 0 else out_channels,
+ out_channels,
+ None,
+ n_blocks,
+ momentum,
+ )
+ for i in range(n_inters)
+ ]
+ )
+
+ def forward(self, x):
+ for layer in self.layers:
+ _, x = layer(x)
+ return x
+
+
+class ResDecoderBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
+ super(ResDecoderBlock, self).__init__()
+ out_padding = (0, 1) if stride == (1, 2) else (1, 1)
+ self.conv1 = nn.Sequential(
+ nn.ConvTranspose2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(3, 3),
+ stride=stride,
+ padding=(1, 1),
+ output_padding=out_padding,
+ bias=False,
+ ),
+ nn.BatchNorm2d(out_channels, momentum=momentum),
+ nn.ReLU(),
+ )
+ self.conv2 = nn.ModuleList(
+ [
+ ConvBlockRes(
+ out_channels * 2 if i == 0 else out_channels, out_channels, momentum
+ )
+ for i in range(n_blocks)
+ ]
+ )
+
+ def forward(self, x, concat_tensor):
+ x = self.conv1(x)
+ x = torch.cat((x, concat_tensor), dim=1)
+ for conv in self.conv2:
+ x = conv(x)
+ return x
+
+
+class Decoder(nn.Module):
+ def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
+ super(Decoder, self).__init__()
+ self.layers = nn.ModuleList()
+ for _ in range(n_decoders):
+ out_channels = in_channels // 2
+ self.layers.append(
+ ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
+ )
+ in_channels = out_channels
+
+ def forward(self, x, concat_tensors):
+ for layer, concat_tensor in zip(self.layers, reversed(concat_tensors)):
+ x = layer(x, concat_tensor)
+ return x
+
+
+class DeepUnet(nn.Module):
+ def __init__(
+ self,
+ kernel_size,
+ n_blocks,
+ en_de_layers=5,
+ inter_layers=4,
+ in_channels=1,
+ en_out_channels=16,
+ ):
+ super(DeepUnet, self).__init__()
+ self.encoder = Encoder(
+ in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
+ )
+ self.intermediate = Intermediate(
+ self.encoder.out_channel // 2,
+ self.encoder.out_channel,
+ inter_layers,
+ n_blocks,
+ )
+ self.decoder = Decoder(
+ self.encoder.out_channel, en_de_layers, kernel_size, n_blocks
+ )
+
+ def forward(self, x):
+ x, concat_tensors = self.encoder(x)
+ x = self.intermediate(x)
+ return self.decoder(x, concat_tensors)
+
+
+class E2E(nn.Module):
+ def __init__(
+ self,
+ n_blocks,
+ n_gru,
+ kernel_size,
+ en_de_layers=5,
+ inter_layers=4,
+ in_channels=1,
+ en_out_channels=16,
+ ):
+ super(E2E, self).__init__()
+ self.unet = DeepUnet(
+ kernel_size,
+ n_blocks,
+ en_de_layers,
+ inter_layers,
+ in_channels,
+ en_out_channels,
+ )
+ self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
+ if n_gru:
+ self.fc = nn.Sequential(
+ BiGRU(3 * 128, 256, n_gru),
+ nn.Linear(512, 360),
+ nn.Dropout(0.25),
+ nn.Sigmoid(),
+ )
+ else:
+ self.fc = nn.Sequential(
+ nn.Linear(3 * 128, 360), nn.Dropout(0.25), nn.Sigmoid() # torch.nn defines no N_MELS/N_CLASS; 128 mel bins, 360 pitch classes (matches the GRU branch)
+ )
+
+ def forward(self, mel):
+ mel = mel.transpose(-1, -2).unsqueeze(1)
+ x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
+ return self.fc(x)
+
+
+class MelSpectrogram(nn.Module):
+ def __init__(
+ self,
+ is_half,
+ n_mel_channels,
+ sample_rate,
+ win_length,
+ hop_length,
+ n_fft=None,
+ mel_fmin=0,
+ mel_fmax=None,
+ clamp=1e-5,
+ ):
+ super(MelSpectrogram, self).__init__()
+ n_fft = win_length if n_fft is None else n_fft
+ self.hann_window = {}
+ mel_basis = mel(
+ sr=sample_rate,
+ n_fft=n_fft,
+ n_mels=n_mel_channels,
+ fmin=mel_fmin,
+ fmax=mel_fmax,
+ htk=True,
+ )
+ self.register_buffer("mel_basis", torch.from_numpy(mel_basis).float())
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.win_length = win_length
+ self.sample_rate = sample_rate
+ self.n_mel_channels = n_mel_channels
+ self.clamp = clamp
+ self.is_half = is_half
+
+ def forward(self, audio, keyshift=0, speed=1, center=True):
+ factor = 2 ** (keyshift / 12)
+ n_fft_new = int(np.round(self.n_fft * factor))
+ win_length_new = int(np.round(self.win_length * factor))
+ hop_length_new = int(np.round(self.hop_length * speed))
+ keyshift_key = f"{keyshift}_{audio.device}"
+ if keyshift_key not in self.hann_window:
+ self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
+ audio.device
+ )
+ if not hasattr(self, "stft"):
+ self.stft = STFT(
+ filter_length=n_fft_new,
+ hop_length=hop_length_new,
+ win_length=win_length_new,
+ window="hann",
+ ).to(audio.device)
+ magnitude = self.stft.transform(audio)
+ if keyshift != 0:
+ size = self.n_fft // 2 + 1
+ resize = magnitude.size(1)
+ if resize < size:
+ magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
+ magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
+ mel_output = torch.matmul(self.mel_basis, magnitude)
+ if self.is_half:
+ mel_output = mel_output.half()
+ return torch.log(torch.clamp(mel_output, min=self.clamp))
+
+
+class RMVPE0Predictor:
+ def __init__(self, model_path, is_half, device=None):
+ self.resample_kernel = {}
+ self.is_half = is_half
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.device = device
+ self.mel_extractor = MelSpectrogram(
+ is_half, 128, 16000, 1024, 160, None, 30, 8000
+ ).to(device)
+ model = E2E(4, 1, (2, 2))
+ ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
+ model.load_state_dict(ckpt)
+ model.eval()
+ if is_half:
+ model = model.half()
+ self.model = model.to(device)
+ self.cents_mapping = np.pad(20 * np.arange(360) + 1997.3794084376191, (4, 4))
+
+ def mel2hidden(self, mel):
+ with torch.no_grad():
+ n_frames = mel.shape[-1]
+ mel = mel.float()
+ padding = min(32 * ((n_frames - 1) // 32 + 1) - n_frames, n_frames)
+ mel = F.pad(mel, (0, padding), mode="reflect")
+ if self.is_half:
+ mel = mel.half()
+ hidden = self.model(mel)
+ return hidden[:, :n_frames]
+
+ def decode(self, hidden, thred=0.03):
+ cents_pred = self.to_local_average_cents(hidden, thred=thred)
+ f0 = 10 * (2 ** (cents_pred / 1200))
+ f0[f0 == 10] = 0
+ return f0
+
+ def infer_from_audio(self, audio, thred=0.03):
+ audio = torch.from_numpy(audio).float().to(self.device).unsqueeze(0)
+ mel = self.mel_extractor(audio, center=True)
+ hidden = self.mel2hidden(mel)
+ hidden = hidden.squeeze(0).cpu().numpy()
+ if self.is_half:
+ hidden = hidden.astype("float32")
+ return self.decode(hidden, thred=thred)
+
+ def infer_from_audio_with_pitch(self, audio, thred=0.03, f0_min=50, f0_max=1100):
+ audio = torch.from_numpy(audio).float().to(self.device).unsqueeze(0)
+ mel = self.mel_extractor(audio, center=True)
+ hidden = self.mel2hidden(mel)
+ hidden = hidden.squeeze(0).cpu().numpy()
+ if self.is_half:
+ hidden = hidden.astype("float32")
+ f0 = self.decode(hidden, thred=thred)
+ f0[(f0 < f0_min) | (f0 > f0_max)] = 0
+ return f0
+
+ def to_local_average_cents(self, salience, thred=0.05):
+ center = np.argmax(salience, axis=1)
+ salience = np.pad(salience, ((0, 0), (4, 4)))
+ center += 4
+ todo_salience = []
+ todo_cents_mapping = []
+ starts = center - 4
+ ends = center + 5
+ for idx in range(salience.shape[0]):
+ todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
+ todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
+ todo_salience = np.array(todo_salience)
+ todo_cents_mapping = np.array(todo_cents_mapping)
+ product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
+ weight_sum = np.sum(todo_salience, 1)
+ divided = product_sum / weight_sum
+ maxx = np.max(salience, axis=1)
+ divided[maxx <= thred] = 0
+ return divided
+
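+
+# A minimal usage sketch (illustrative only; "rmvpe.pt" and `audio_16k` are placeholders for a
+# real checkpoint path and a mono 16 kHz float32 numpy array):
+#
+#   rmvpe = RMVPE0Predictor("rmvpe.pt", is_half=False)
+#   f0 = rmvpe.infer_from_audio(audio_16k, thred=0.03) # per-frame f0 in Hz; 0 marks unvoiced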