diff --git a/mvsepless/__init__.py b/mvsepless/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb0f890adfeb87d73fd58f17cb259619b24617ce --- /dev/null +++ b/mvsepless/__init__.py @@ -0,0 +1,1579 @@ +import os +import sys +import shutil +import logging +import zipfile +import importlib.util +from pathlib import Path +logging.basicConfig(level=logging.WARNING) + +script_dir = os.path.dirname(os.path.abspath(__file__)) +os.chdir(script_dir) + +if not __package__: + from model_manager import MvseplessModelManager + from audio import Audio, Inverter + from namer import Namer + from vbach_infer import vbach_inference, model_manager as VbachModel + from downloader import dw_file, dw_yt_dlp + from ensemble import ensemble_audio_files +else: + from .model_manager import MvseplessModelManager + from .audio import Audio, Inverter + from .namer import Namer + from .vbach_infer import vbach_inference, model_manager as VbachModel + from .downloader import dw_file, dw_yt_dlp + from .ensemble import ensemble_audio_files +from typing import Literal +import gradio as gr +import pandas as pd +import subprocess +import json +import threading +import queue +import time +import argparse +from datetime import datetime +import tempfile +import ast + +class MVSEPLESS: + audio = Audio() + inverter = Inverter() + namer = Namer() + model_manager = MvseplessModelManager() + vbach_model_manager = VbachModel + +class Separator(MVSEPLESS): + + class OutputReader: + def __init__(self, debug=False): + self.debug = debug + + def parse_json_line(self, line): + try: + return json.loads(line) + except json.JSONDecodeError: + return None + + def reaction_line(self, line, progress, add_text): + _add_text = "" + if add_text != "" or add_text is not None: + _add_text = f"| {add_text}" + + data = self.parse_json_line(line) + if data is None: + return None + elif "reading" in data: + progress(0.05, desc=f"Чтение файла {_add_text}") + print("Чтение файла") + return None + elif "processing" in data: + progress_a = data["processing"] + processed = progress_a.get("processed", 0) + total = progress_a.get("total", 1) + # Исправлено: убрано деление на ноль + if total > 0: + progress_ratio = min(0.89, 0.05 + (processed / total * 0.85)) # Оставляем место для этапа записи + percent = int((processed / total) * 100) + progress(progress_ratio, desc=f"Обработано: {percent}% {_add_text}") + print(f"\rОбработано: {percent}%", end="") + return None + elif "writing" in data: + progress(0.9, desc="Запись результатов") + print(f"\rЗапись в файл {data['writing']}", end="") + return None + elif "done" in data: + progress(1.0, desc=f"Завершено {_add_text}") + print("\rЗавершено", end="\n") + return data["done"] + elif "error" in data: + raise Exception(data["error"]) + + def read_stream_to_queue(self, stream, queue_obj, stream_name): + """Чтение потока вывода подпроцесса и запись в очередь""" + try: + for line in iter(stream.readline, ''): + line = line.strip() + if line: + if self.debug: + print(f"[{stream_name}] {line}") # Отладочный вывод + queue_obj.put(line) + stream.close() + except Exception as e: + print(f"Error reading {stream_name}: {e}") + + output_reader = OutputReader() + + def separator_model_loader(self, model_type: str, model_name: str, mdx_denoise: bool, vr_aggr: bool, progress) -> tuple[int, str, str]: + + if model_type in [ + "mel_band_roformer", + "bs_roformer", + "mdx23c", + "mdxnet", + "vr", + "scnet", + "htdemucs", + "bandit", + "bandit_v2", + ]: + info = self.model_manager.models_info[model_type].get(model_name, None) + if not info: + raise ValueError(f"Модель {model_name} не найдена для типа {model_type}") + + id = self.model_manager.get_id(model_type, model_name) + conf, ckpt = self.model_manager.download_model( + self.model_manager.models_cache_dir, + model_name, + model_type, + info["checkpoint_url"], + info["config_url"], + ) + if model_type != "htdemucs": + self.model_manager.conf_editor(conf, mdx_denoise, vr_aggr, model_type) + + return id, conf, ckpt + + else: + raise ValueError("Неподдерживаемый тип модели") + + def separator_base( + self, + input_file: str, + output_dir: str, + model_type: Literal[ + "mel_band_roformer", + "bs_roformer", + "mdx23c", + "mdxnet", + "scnet", + "htdemucs", + "bandit", + "bandit_v2", + ] = "mel_band_roformer", + model_name: str = "Mel-Band-Roformer_Vocals_kimberley_jensen", + ext_inst: bool = True, + output_format: Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] = "mp3", + output_bitrate: str = "320k", + template: str = "NAME_(STEM)_MODEL", + selected_stems: list = None, + ckpt: str = None, + conf: str = None, + id: int = None, + progress: any = gr.Progress(track_tqdm=True), + add_text_progress: str = "" + ) -> list[tuple[str, str]]: + + if model_type in [ + "mel_band_roformer", + "bs_roformer", + "mdx23c", + "mdxnet", + "vr", + "scnet", + "htdemucs", + "bandit", + "bandit_v2", + ]: + + cmd = [ + os.sys.executable, + "-m", + "infer", + "--input", input_file, + "--store_dir", output_dir, + "--model_type", model_type, + "--model_name", model_name, + "--model_id", str(id), + "--config_path", conf, + "--start_check_point", ckpt, + "--output_format", output_format, + "--output_bitrate", str(output_bitrate), + "--template", template + ] + if ext_inst: + cmd.append("--extract_instrumental") + if selected_stems: + cmd.append("--selected_instruments") + cmd.extend(selected_stems) + + try: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + universal_newlines=True, + encoding='utf-8', + errors='replace' + ) + + # Создаем очереди для stdout и stderr + stdout_queue = queue.Queue() + stderr_queue = queue.Queue() + + # Запускаем потоки для чтения stdout и stderr + stdout_thread = threading.Thread( + target=self.output_reader.read_stream_to_queue, + args=(process.stdout, stdout_queue, "stdout") + ) + stderr_thread = threading.Thread( + target=self.output_reader.read_stream_to_queue, + args=(process.stderr, stderr_queue, "stderr") + ) + + stdout_thread.daemon = True + stderr_thread.daemon = True + + stdout_thread.start() + stderr_thread.start() + + results = {'output': None, 'error': None} + process_completed = False + + # Основной цикл обработки сообщений + while not process_completed: + # Проверяем завершение процесса + if process.poll() is not None: + process_completed = True + + # Обрабатываем сообщения из stdout + try: + stdout_line = stdout_queue.get_nowait() + result = self.output_reader.reaction_line(stdout_line, progress, add_text_progress) + if result is not None: + results['output'] = result + break + except queue.Empty: + pass + + # Обрабатываем сообщения из stderr + try: + stderr_line = stderr_queue.get_nowait() + result = self.output_reader.reaction_line(stderr_line, progress, add_text_progress) + if result is not None: + results['output'] = result + break + except queue.Empty: + pass + + # Если процесс еще работает, ждем немного перед следующей проверкой + if not process_completed: + time.sleep(0.1) + + # Дополнительная обработка оставшихся сообщений после завершения процесса + for _ in range(10): # Проверяем несколько раз на случай задержек + try: + stdout_line = stdout_queue.get_nowait() + result = self.output_reader.reaction_line(stdout_line, progress, add_text_progress) + if result is not None: + results['output'] = result + break + except queue.Empty: + pass + + try: + stderr_line = stderr_queue.get_nowait() + result = self.output_reader.reaction_line(stderr_line, progress, add_text_progress) + if result is not None: + results['output'] = result + break + except queue.Empty: + pass + + time.sleep(0.1) + + # Проверяем результаты + if results.get('error'): + raise Exception(results['error']) + + if results.get('output'): + return results['output'] + + # Если процесс завершился с ошибкой + if process.returncode != 0: + # Пытаемся получить последние сообщения об ошибках + error_messages = [] + try: + while True: + error_msg = stderr_queue.get_nowait() + error_messages.append(error_msg) + except queue.Empty: + pass + + error_text = "\n".join(error_messages[-5:]) # Последние 5 сообщений + raise Exception(f"Процесс завершился с ошибкой. Код возврата: {process.returncode}. Сообщения об ошибках:\n{error_text}") + + except Exception as e: + raise e + finally: + # Гарантируем завершение процесса + try: + if process.poll() is None: + process.terminate() + process.wait(timeout=5) + except: + try: + process.kill() + except: + pass + else: + raise ValueError("Неподдерживаемый тип модели") + + def separate( + self, + input: str | list = None, + output_dir: str = None, + model_type: Literal[ + "mel_band_roformer", + "bs_roformer", + "mdx23c", + "mdxnet", + "vr", + "scnet", + "htdemucs", + "bandit", + "bandit_v2", + ] = "mel_band_roformer", + model_name: str = "Mel-Band-Roformer_Vocals_kimberley_jensen", + ext_inst: bool = True, + output_format: Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] = "mp3", + output_bitrate: str = "320k", + template: str = "NAME_(STEM)_MODEL", + selected_stems: list = None, + add_settings: dict = {"mdx_denoise": False, "vr_aggr": 5, "add_single_sep_text_progress": None}, + progress: any = gr.Progress(track_tqdm=True) + ) -> list[tuple[str, str]] | list[str, list[tuple[str, str]]]: + + progress(0, desc="Начало обработки") + + # Валидация параметров + if output_format not in self.audio.output_formats: + output_format = "flac" + + if output_dir is None: + output_dir = os.getcwd() + + if output_dir: + output_dir = os.path.abspath(output_dir) + + if selected_stems is None: + selected_stems = [] + + if not input: + raise ValueError("Входной файл не указан") + + if "STEM" not in template and template is not None: + template = template + "_STEM_" + if not template: + template = "mvsepless_NAME_(STEM)" + + os.makedirs(output_dir, exist_ok=True) + + mdx_denoise = add_settings.get("mdx_denoise", False) + + vr_aggr = add_settings.get("vr_aggr", 5) + + add_progress_text_custom = add_settings.get("add_single_sep_text_progress", "") + + id, conf, ckpt = self.separator_model_loader(model_type, model_name, mdx_denoise, vr_aggr, progress) + + if isinstance(input, str): + if not os.path.exists(input): + raise ValueError(f"Входной файл не найден: {input}") + + if not self.audio.check(input): + raise ValueError("Входной файл не содержит аудио") + + basename = os.path.splitext(os.path.basename(input))[0] + seped = self.separator_base(input_file=input, + output_dir=output_dir, + model_type=model_type, + model_name=model_name, + ext_inst=ext_inst, + output_format=output_format, + output_bitrate=output_bitrate, + template=template, + selected_stems=selected_stems, + ckpt=ckpt, + conf=conf, + id=id, + progress=progress, + add_text_progress=add_progress_text_custom) + return seped + + elif isinstance(input, list): + results = [] + for i, f in enumerate(input, 1): + print(f"Файл {i} из {len(input)}: {f}") + if os.path.exists(f): + if self.audio.check(f): + basename = os.path.splitext(os.path.basename(f))[0] + seped = self.separator_base(input_file=f, + output_dir=output_dir, + model_type=model_type, + model_name=model_name, + ext_inst=ext_inst, + output_format=output_format, + output_bitrate=output_bitrate, + template=template, + selected_stems=selected_stems, + ckpt=ckpt, + conf=conf, + id=id, + progress=progress, + add_text_progress=f"({i} из {len(input)}) ") + results.append([basename, seped]) + return results + + + def UI(self, output_base_dir, output_temp_dir_check): + default_mts = self.model_manager.get_mt() + default_mt = self.model_manager.get_mt()[0] + default_mns = self.model_manager.get_mn(default_mt) + default_mn = default_mns[0] + default_stems = self.model_manager.get_stems(default_mt, default_mn) + default_tgt_inst = self.model_manager.get_tgt_inst(default_mt, default_mn) + with gr.Blocks(): + with gr.Row(): + with gr.Column(): + with gr.Group(visible=False) as add_inputs: + input_path = gr.Textbox(label="Путь к входному файлу", interactive=True) + add_inputs_btn = gr.Button("Добавить файл", variant="primary") + with gr.Group(visible=False) as add_inputs_from_url: + input_url = gr.Textbox(label="URL входного файла", interactive=True) + with gr.Row(): + inputs_url_format = gr.Dropdown(label="Формат входного файла", interactive=True, + choices=self.audio.output_formats, + value="mp3", filterable=False) + inputs_url_bitrate = gr.Slider(label="Битрейт входного файла", minimum=64, maximum=512, step=32, value=320, interactive=True) + with gr.Row(): + inputs_url_cookie = gr.UploadButton(label="Файл cookie (необязательно)", interactive=True, type="filepath", file_count="single", file_types=[".txt", ".cookies"], variant="secondary") + add_inputs_url_btn = gr.Button("Добавить файл", variant="primary") + with gr.Row(visible=True) as add_buttons_row: + add_path_btn = gr.Button("Добавить файл по пути", variant="secondary") + add_url_btn = gr.Button("Добавить файл по URL", variant="secondary") + with gr.Group(): + input_audio = gr.File(label="Входные аудио", interactive=True, type="filepath", file_count="multiple", file_types=[f".{of}" for of in self.audio.input_formats]) + sep_state = gr.Textbox(label="Состояние разделения", interactive=False, value="", visible=False) + status = gr.Textbox(container=False, lines=3, interactive=False, max_lines=3) + input_preview_check = gr.Checkbox(label="Показать плееры для входных аудио", interactive=True, value=False) + @gr.render(inputs=[input_preview_check, input_audio]) + def show_input_players(preview, audios): + if preview: + if audios: + with gr.Group(): + for file in audios: + gr.Audio(label=os.path.splitext(os.path.basename(file))[0], value=file, interactive=False, show_download_button=False, type="filepath") + with gr.Column(): + with gr.Group(): + with gr.Row(): + model_type = gr.Dropdown(label="Тип модели", interactive=True, filterable=False, + choices=default_mts, + value=default_mt) + model_name = gr.Dropdown(label="Имя модели", interactive=True, filterable=False, + choices=default_mns, value=default_mn) + with gr.Group(): + extract_instrumental = gr.Checkbox(label="Извлечь инструментал", interactive=True, value=True) + selected_stems = gr.CheckboxGroup(label="Выбранные стемы для разделения", interactive=False, + choices=default_stems, value=[]) + + with gr.Accordion(label="Дополнительные настройки", open=False): + vr_aggr_slider = gr.Slider(label="Сила подавления для VR моделей", minimum=-100, maximum=100, value=5, step=1) + mdx_denoise_check = gr.Checkbox(label="Включить шумоподавление для MDX-NET моделей (это повышает потребление памяти в два раза)", value=False) + with gr.Row(): + output_format = gr.Dropdown(label="Формат выходного файла", interactive=True, + choices=self.audio.output_formats, + value="mp3", filterable=False) + output_bitrate = gr.Slider(label="Битрейт выходного файла", minimum=64, maximum=512, step=32, value=320, interactive=True) + + template = gr.Textbox(label="Шаблон именования выходных файлов", interactive=True, value="NAME (STEM) MODEL", info="Используйте ключи: \nNAME - имя входного файла без расширения, \nSTEM - имя стема, \nMODEL - имя модели разделения") + separate_btn = gr.Button("Разделить", variant="primary") + + @gr.render(inputs=[sep_state], triggers=[sep_state.change]) + def players(state): + def create_archive_advanced(file_list, archive_name="archive.zip"): + """ + Создает архив с расширенной обработкой ошибок + """ + try: + print("Генерация ZIP-архива с результатами разделения...") + with zipfile.ZipFile(archive_name, 'w', zipfile.ZIP_DEFLATED) as zipf: + successful_files = 0 + + for basename, stems in file_list: + for stem_name, stem_path in stems: + try: + if os.path.exists(stem_path) and os.path.isfile(stem_path): + basename_ = os.path.basename(stem_path) + zipf.write(stem_path, basename_) + successful_files += 1 + print(f"✓ Добавлен: {stem_path} -> {basename}") + else: + print(f"✗ Файл не найден или не является файлом: {stem_path}") + + except Exception as e: + print(f"✗ Ошибка при добавлении {stem_path}: {e}") + + print(f"\nАрхив создан: {archive_name}") + print(f"Успешно добавлено файлов: {successful_files}") + return os.path.abspath(archive_name) + + except Exception as e: + print(f"Ошибка при создании архива: {e}") + if state != "": + state_loaded = ast.literal_eval(state) + archive_stems = create_archive_advanced(state_loaded, os.path.join(tempfile.tempdir, f"mvsepless_output_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip")) + for basename, stems in state_loaded: + with gr.Group(): + gr.Markdown(f"

{basename}

") + for stem_name, stem_path in stems: + with gr.Row(equal_height=True): + output_stem = gr.Audio(value=stem_path, label=stem_name, type="filepath", interactive=False, show_download_button=True, scale=15) + reuse_btn = gr.Button("Использовать снова", variant="secondary") + @reuse_btn.click( + inputs=[output_stem, input_audio], + outputs=input_audio + ) + def reuse_fn(stem_audio, input_a): + if input_a is None: + input_a = [] + if isinstance(input_a, str): + input_a = [input_a] + if os.path.exists(stem_audio): + if self.audio.check(stem_audio): + input_a.append(stem_audio) + return input_a + + gr.DownloadButton(label="Скачать как ZIP", value=archive_stems, interactive=True) + + @add_inputs_btn.click( + inputs=[input_path, input_audio], + outputs=[add_inputs, input_audio, add_buttons_row]) + def add_inputs_fn(input_p, input_a): + if input_p and os.path.exists(input_p): + if input_a is None: + input_a = [] + if isinstance(input_a, str): + input_a = [input_a] + if self.audio.check(input_p): + input_a.append(input_p) + return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True) + return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True) + + @add_inputs_url_btn.click( + inputs=[input_url, input_audio, inputs_url_format, inputs_url_bitrate, inputs_url_cookie], + outputs=[add_inputs_from_url, input_audio, add_buttons_row]) + def add_inputs_from_url_fn(input_u, input_a, fmt, br, cookie): + if input_u: + if input_a is None: + input_a = [] + if isinstance(input_a, str): + input_a = [input_a] + downloaded_file = dw_yt_dlp( + url=input_u, + output_format=fmt, + output_bitrate=str(int(br)), + cookie=cookie + ) + if downloaded_file and os.path.exists(downloaded_file): + if self.audio.check(downloaded_file): + input_a.append(downloaded_file) + return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True) + return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True) + + add_path_btn.click( + lambda: (gr.update(visible=True), gr.update(visible=False)), + outputs=[add_inputs, add_buttons_row]) + + add_url_btn.click( + lambda: (gr.update(visible=True), gr.update(visible=False)), + outputs=[add_inputs_from_url, add_buttons_row]) + + model_type.change(lambda x: gr.update(choices=self.model_manager.get_mn(x), value=self.model_manager.get_mn(x)[0]), + inputs=model_type, outputs=model_name) + model_name.change(lambda mt, mn: (gr.update(choices=self.model_manager.get_stems(mt, mn), value=[], interactive=False if self.model_manager.get_tgt_inst(mt, mn) else True), gr.update(value=True if self.model_manager.get_tgt_inst(mt, mn) else False)), inputs=[model_type, model_name], outputs=[selected_stems, extract_instrumental]) + output_format.change(lambda x: gr.update(visible=False if x in ["wav", "flac", "aiff"] else True), inputs=output_format, outputs=output_bitrate) + inputs_url_format.change(lambda x: gr.update(visible=False if x in ["wav", "flac", "aiff"] else True), inputs=inputs_url_format, outputs=inputs_url_bitrate) + @separate_btn.click( + inputs=[ + input_audio, model_type, model_name, + extract_instrumental, output_format, output_bitrate, + template, selected_stems, output_base_dir, output_temp_dir_check, mdx_denoise_check, vr_aggr_slider + ], + outputs=[sep_state, status], + show_progress="full" + ) + def wrap(i, mt, mn, ei, of, ob, t, stems, o_dir, temp_save, mdx_denoise, vr_aggr, progress=gr.Progress(track_tqdm=True)): + if o_dir.strip() != "" and not temp_save: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + o = os.path.join(o_dir, f"mvsepless_outputs_{timestamp}") + os.makedirs(o, exist_ok=True) + else: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + o = tempfile.mkdtemp(prefix=f"mvsepless_outputs_{timestamp}_") + os.makedirs(o, exist_ok=True) + results = self.separate(i, o, mt, mn, ei, of, ob, t, stems, add_settings={"mdx_denoise": mdx_denoise, "vr_aggr": int(vr_aggr)}, progress=progress) + return str(results), "" + +class AutoEnsembless(Separator): + + class ModelManager(MVSEPLESS): + def __init__(self): + self.data = [] + self.ensemble_methods = ("min_fft", "max_fft", "avg_fft", "median_fft") + self.ensemble_invert_methods_map = {"min_fft": "max_fft", "max_fft": "min_fft", "avg_fft": "avg_fft", "median_fft": "median_fft"} + self.dir_presets = os.path.join(tempfile.tempdir, "presets") + os.makedirs(self.dir_presets, exist_ok=True) + + def save(self, name): + if not name: + name = "ensembless_preset" + filepath = os.path.join(self.dir_presets, f"{self.namer.short(self.namer.sanitize(name), length=50)}.json") + with open(filepath, "w") as f: + json.dump(self.data, f, indent=4, ensure_ascii=False) + return filepath + + def load(self, filepath): + with open(filepath, "r") as f: + ensemble_data_temp = json.load(f) + self.data = [] + for (mt, mn, s_stem, i_stem, weight) in ensemble_data_temp: + if {mt, mn} not in [{model[0], model[1]} for model in self.data]: + self.data.append((mt, mn, s_stem, i_stem, weight)) + + def add(self, mt, mn, s_stem, i_stem, weight): + if {mt, mn} not in [{model[0], model[1]} for model in self.data]: + if s_stem and i_stem: + self.data.append((mt, mn, s_stem, i_stem, weight)) + + def replace(self, mt, mn, s_stem, i_stem, weight, index=1): + if self.data: + len_data = len(self.data) + if index >= 1: + if index <= len_data: + self.data[index - 1] = (mt, mn, s_stem, i_stem, weight) + elif index == 0: + self.data[0] = (mt, mn, s_stem, i_stem, weight) + + def remove(self, index=1): + if self.data: + len_data = len(self.data) + if index >= 1: + if index <= len_data: + del self.data[index - 1] + elif index == 0: + del self.data[0] + + def clear(self): + self.data = [] + + def get_df(self): + if not self.data: + columns = ["#", "Имя модели", "Основной стем", "Инверсия", "Вес"] + return pd.DataFrame(columns=columns) + + data = [] + for i, model in enumerate(self.data): + data.append( + [ + f"{i+1}", + model[1], + model[2], + model[3], + model[4], + ] + ) + columns = ["#", "Имя модели", "Основной стем", "Инверсия", "Вес"] + return pd.DataFrame(data, columns=columns) + + def UI(self, output_base_dir, output_temp_dir_check): + ensemble_model_manager = self.ModelManager() + def get_stems(mt, mn): + stems = [] + for stem in self.model_manager.get_stems(mt, mn): + stems.append(stem) + + if not self.model_manager.get_tgt_inst(mt, mn): + if set(stems) == {"bass", "drums", "other", "vocals"} or set(stems) == {"bass", "drums", "other", "vocals", "piano", "guitar"}: + stems.append("instrumental +") + stems.append("instrumental -") + + return stems + + def get_invert_stems(mt, mn, s_stem): + orig_stems = [] + stems = [] + for stem in self.model_manager.get_stems(mt, mn): + orig_stems.append(stem) + + for stem in orig_stems: + if stem != s_stem: + stems.append(stem) + + if not self.model_manager.get_tgt_inst(mt, mn): + if len(orig_stems) > 2: + if s_stem not in ["instrumental +", "instrumental -"]: + stems.append("inverted +") + stems.append("inverted -") + + return stems + + default_model = { + "mt": self.model_manager.get_mt(), + "mn": self.model_manager.get_mn(self.model_manager.get_mt()[0]), + "stem": get_stems( + self.model_manager.get_mt()[0], + self.model_manager.get_mn(self.model_manager.get_mt()[0])[0], + ), + "invert_stem": get_invert_stems( + self.model_manager.get_mt()[0], + self.model_manager.get_mn(self.model_manager.get_mt()[0])[0], + "vocals", + ), + "weight": 1, + } + + gr.Markdown("

Пресет

") + with gr.Group(): + with gr.Row(equal_height=True): + export_preset_name = gr.Textbox( + label="Имя пресета", + interactive=True, + value="ensembless_preset", scale=9 + ) + export_btn = gr.DownloadButton("Экспорт", variant="secondary", scale=3, interactive=True) + import_btn = gr.UploadButton( + "Импорт", file_types=[".json"], file_count="single", scale=3, interactive=True + ) + gr.Markdown("

Ансамбль

") + with gr.Row(): + with gr.Column(scale=3): # логика добавлеия моделей + model_type = gr.Dropdown(label="Тип модели", choices=default_model["mt"], value=default_model["mt"][0], interactive=True, filterable=False) + model_name = gr.Dropdown(label="Имя модели", choices=default_model["mn"], value=default_model["mn"][0], interactive=True, filterable=False) + primary_stem = gr.Dropdown(label="Основной стем", choices=default_model["stem"], value=default_model["stem"][0], interactive=True, filterable=False) + secondary_stem = gr.Dropdown(label="Инверсия", choices=default_model["invert_stem"], value=default_model["invert_stem"][0], interactive=True, filterable=False) + weight = gr.Slider(label="Вес", minimum=0, maximum=10, step=0.01, value=1, interactive=True) + @model_type.change( + inputs=[model_type], + outputs=[model_name] + ) + def update_model_names(mt): + model_names = self.model_manager.get_mn(mt) + new_mn = model_names[0] if model_names else "" + + return gr.update(choices=model_names, value=new_mn) + @model_name.change( + inputs=[model_type, model_name], + outputs=[primary_stem, secondary_stem] + ) + def update_stems_after_model_change(mt, mn): + stems = get_stems(mt, mn) + invert_stems = get_invert_stems(mt, mn, stems[0]) if stems else [] + + new_s_stem = stems[0] if stems else "" + new_i_stem = invert_stems[0] if invert_stems else "" + + return ( + gr.update(choices=stems, value=new_s_stem), + gr.update(choices=invert_stems, value=new_i_stem) + ) + @primary_stem.change( + inputs=[model_type, model_name, primary_stem], + outputs=[secondary_stem] + ) + def update_invert_stems(mt, mn, s_stem): + stems = get_invert_stems(mt, mn, s_stem) + new_i_stem = stems[0] if stems else "" + return gr.update(choices=stems, value=new_i_stem) + + model_add_button = gr.Button("Добавить", interactive=True) + with gr.Column(scale=10): + df = gr.DataFrame( + value=ensemble_model_manager.get_df(), + headers=["#", "Имя модели", "Основной стем", "Инверсия", "Вес"], + datatype=["number", "str", "str", "str", "number"], + interactive=False + ) + + with gr.Group(): + with gr.Row(equal_height=True): + with gr.Column(): + model_index = gr.Number(label="Индекс модели", value=1, interactive=True) + model_clear_btn = gr.Button("Очистить", variant="stop", interactive=True) + with gr.Column(): + model_replace_btn = gr.Button("Заменить", variant="primary", interactive=True) + model_delete_btn = gr.Button("Удалить", variant="stop", interactive=True) + + @model_add_button.click( + inputs=[model_type, model_name, primary_stem, secondary_stem, weight], + outputs=df + ) + def add_model_to_auto_ensemble(mt, mn, s_stem, i_stem, weight): + ensemble_model_manager.add(mt, mn, s_stem, i_stem, weight) + return ensemble_model_manager.get_df() + + @model_replace_btn.click( + inputs=[model_type, model_name, primary_stem, secondary_stem, weight, model_index], + outputs=df + ) + def replace_model_to_auto_ensemble(mt, mn, s_stem, i_stem, weight, index): + ensemble_model_manager.replace(mt, mn, s_stem, i_stem, weight, index) + return ensemble_model_manager.get_df() + + @model_delete_btn.click( + inputs=[model_index], + outputs=df + ) + def delete_model_to_auto_ensemble(index): + ensemble_model_manager.remove(index) + return ensemble_model_manager.get_df() + + @model_clear_btn.click( + outputs=df + ) + def clear_model_to_auto_ensemble(): + ensemble_model_manager.clear() + return ensemble_model_manager.get_df() + + gr.on(fn=ensemble_model_manager.get_df, outputs=df) + + df.change( + fn=ensemble_model_manager.save, + inputs=export_preset_name, + outputs=export_btn + ) + + export_preset_name.change( + fn=ensemble_model_manager.save, + inputs=export_preset_name, + outputs=export_btn + ) + + @import_btn.upload( + inputs=import_btn, + outputs=df + ) + def load_ensemble_preset(filepath): + ensemble_model_manager.load(filepath) + return ensemble_model_manager.get_df() + with gr.Row(): + with gr.Column(): + gr.Markdown("

Входное аудио

") + with gr.Group(): + with gr.Group(visible=False) as add_inputs: + input_path = gr.Textbox(label="Путь к входному файлу", interactive=True) + add_inputs_btn = gr.Button("Загрузить файл", variant="primary") + with gr.Group(visible=False) as add_inputs_from_url: + input_url = gr.Textbox(label="URL входного файла", interactive=True) + with gr.Row(): + inputs_url_format = gr.Dropdown(label="Формат входного файла", interactive=True, + choices=self.audio.output_formats, + value="mp3", filterable=False) + inputs_url_bitrate = gr.Slider(label="Битрейт входного файла", minimum=64, maximum=512, step=32, value=320, interactive=True) + with gr.Row(): + inputs_url_cookie = gr.UploadButton(label="Файл cookie (необязательно)", interactive=True, type="filepath", file_count="single", file_types=[".txt", ".cookies"], variant="secondary") + add_inputs_url_btn = gr.Button("Загрузить файл", variant="primary") + with gr.Row(visible=True) as add_buttons_row: + add_path_btn = gr.Button("Загрузить файл по пути", variant="secondary") + add_url_btn = gr.Button("Загрузить файл по URL", variant="secondary") + with gr.Group(): + input_audio = gr.File(label="Входное аудио", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.audio.input_formats]) + with gr.Column(): + gr.Markdown("

Настройки

") + with gr.Group(): + method = gr.Dropdown(label="Алгоритм склеивания", choices=["min_fft", "max_fft", "avg_fft", "median_fft"], value="avg_fft", filterable=False) + invert_ensemble = gr.Checkbox(label="Инверсия ансамбля", interactive=True, value=False) + output_format = gr.Dropdown(label="Формат выходного файла", interactive=True, + choices=self.audio.output_formats, + value="mp3", filterable=False) + run_btn = gr.Button("Создать ансамбль", variant="primary", interactive=True) + + with gr.Row(): + with gr.Column(): + gr.Markdown("

Результаты

") + output_audio = gr.Audio(label="Результат", type="filepath", interactive=False, show_download_button=True) + output_audio_wav = gr.Textbox(label="Результат в WAV", interactive=False, visible=False) + with gr.Group(): + invert_method = gr.Radio( + choices=["waveform", "spectrogram"], + label="Метод создания инверсии", + value="waveform", + ) + invert_btn = gr.Button("Инвертировать") + output_inverted_audio = gr.Audio(label="Инверсия", type="filepath", interactive=False, show_download_button=True) + @invert_btn.click(inputs=[input_audio, output_audio_wav, invert_method, output_format], outputs=[output_inverted_audio]) + def invert_result_ensemble(input_file, output_file, method, out_format): + if input_file and output_file: + o_dir = os.path.dirname(output_file) + basename = os.path.splitext(os.path.basename(input_file))[0] + output_path = os.path.join(o_dir, f"ensembless_{self.namer.short(basename, length=50)}_{method}_invert.{out_format}") + inverted = self.inverter.process_audio(audio1_path=input_file, audio2_path=output_file, out_format=out_format, method=method, output_path=output_path) + return inverted + else: + return None + + with gr.Column(): + gr.Markdown("

Исходники ансамбля (WAV)

") + output_source_files = gr.Files(type="filepath", interactive=False, show_label=False) + output_source_preview_check = gr.Checkbox(label="Показать плееры для исходников ансамбля", interactive=True, value=False) + @gr.render(inputs=[output_source_preview_check, output_source_files]) + def show_output_auto_ensemble_players(preview, audios): + if preview: + if audios: + with gr.Group(): + for file in audios: + gr.Audio(label=os.path.splitext(os.path.basename(file))[0], value=file, interactive=False, show_download_button=False, type="filepath") + + @run_btn.click( + inputs=[input_audio, method, output_format, invert_ensemble, output_base_dir, output_temp_dir_check], + outputs=[output_audio, output_audio_wav, output_inverted_audio, output_source_files] + ) + def auto_ensemble_run(input_file, method, out_format, invert_ensemble, o_dir, temp_save, progress=gr.Progress(track_tqdm=True)): + ensemble_state = ensemble_model_manager.data + invert_methods_map = ensemble_model_manager.ensemble_invert_methods_map + if not input_file: + return None, None, None, None, [] + if not os.path.exists(input_file): + return None, None, None, None, [] + if not self.audio.check(input_file): + return None, None, None, None, [] + if o_dir.strip() != "" and not temp_save: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + o = os.path.join(o_dir, f"ensembless_outputs_{timestamp}") + os.makedirs(o, exist_ok=True) + else: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + o = tempfile.mkdtemp(prefix=f"ensembless_outputs_{timestamp}_") + os.makedirs(o, exist_ok=True) + + basename = os.path.splitext(os.path.basename(input_file))[0] + def invert_weights(weights): + total_weight = sum(weights) + return [total_weight - w for w in weights] + # print(json.dumps(ensemble_state, indent=4, ensure_ascii=False)) + success_separations = [] # list[tuple[str, str, float]] [(s_stem, i_stem, weight)] + ensemble_sources_list = [] # list[str, str, str, ...] + if ensemble_state: + total_ensemble_models = len(ensemble_state) + for i, model in enumerate(ensemble_state, start=1): + + ens_mt = model[0] + ens_mn = model[1] + ens_s_stem = model[2] + ens_i_stem = model[3] + weight = model[4] + + s_stem = None #path to primary stem + i_stem = None #path to invert stem + + try: + result_seped_auto_ensemble = self.separate(input=input_file, output_dir=os.path.join(o, ens_mn), model_type=ens_mt, model_name=ens_mn, ext_inst=True, template="NAME - MODEL - STEM", output_format="wav", add_settings={"add_single_sep_text_progress": f"{i} из {total_ensemble_models}"}, progress=progress) + if result_seped_auto_ensemble: + for stem, path in result_seped_auto_ensemble: + ensemble_sources_list.append(path) + if stem == ens_s_stem: + s_stem = path + elif stem == ens_i_stem: + i_stem = path + + if invert_ensemble: + if not i_stem: + result_seped_auto_ensemble_invert = self.separate(input=input_file, output_dir=os.path.join(o, f"{ens_mn}_invert"), model_type=ens_mt, model_name=ens_mn, ext_inst=True, template="NAME - MODEL - STEM", output_format="wav", selected_stems=[ens_s_stem], add_settings={"add_single_sep_text_progress": f"{i} из {total_ensemble_models} (инверт.)"}, progress=progress) + if result_seped_auto_ensemble_invert: + for stem, path in result_seped_auto_ensemble: + if stem == ens_i_stem: + i_stem = path + ensemble_sources_list.append(path) + + except Exception as e: + print(f"\nПроизошла ошибка при разделении: {e}") + progress(0, desc="Произошла ошибка при разделении, модель пропускается...") + continue + finally: + if s_stem: + success_separations.append((ens_mn, s_stem, i_stem, weight)) + + ensemble_sources_stems = [] + ensemble_sources_invert_stems = [] + weights = [] + + for out_mn, out_s_stem, out_i_stem, out_weight in success_separations: + ensemble_sources_stems.append(out_s_stem) + ensemble_sources_invert_stems.append(out_i_stem) + weights.append(out_weight) + + + auto_ensemble_invout_file = None + auto_ensemble_invout_file_wav = None + + auto_ensemble_output_name = f"ensembless_{self.namer.short(basename, length=50)}_{len(ensemble_sources_stems)}_{method}" + auto_ensemble_inverted_output_name = f"ensembless_{self.namer.short(basename, length=50)}_{len(ensemble_sources_stems)}_{invert_methods_map[method]}_invert" + auto_ensemble_out_file, auto_ensemble_out_file_wav = ensemble_audio_files(files=ensemble_sources_stems, weights=weights, output=os.path.join(o, auto_ensemble_output_name), ensemble_type=method, out_format=out_format, add_wav=True) + + if invert_ensemble: + auto_ensemble_invout_file, auto_ensemble_invout_file_wav = ensemble_audio_files(files=ensemble_sources_invert_stems, weights=invert_weights(weights), output=os.path.join(o, auto_ensemble_inverted_output_name), ensemble_type=invert_methods_map[method], out_format=out_format, add_wav=True) + + return auto_ensemble_out_file, auto_ensemble_out_file_wav, auto_ensemble_invout_file, ensemble_sources_list + + @add_inputs_btn.click( + inputs=[input_path, input_audio], + outputs=[add_inputs, input_audio, add_buttons_row]) + def add_inputs_fn(input_p, input_a): + if input_p and os.path.exists(input_p): + if input_a is None: + input_a = None + if self.audio.check(input_p): + input_a = input_p + return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True) + return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True) + + @add_inputs_url_btn.click( + inputs=[input_url, input_audio, inputs_url_format, inputs_url_bitrate, inputs_url_cookie], + outputs=[add_inputs_from_url, input_audio, add_buttons_row]) + def add_inputs_from_url_fn(input_u, input_a, fmt, br, cookie): + if input_u: + if input_a is None: + input_a = None + downloaded_file = dw_yt_dlp( + url=input_u, + output_format=fmt, + output_bitrate=str(int(br)), + cookie=cookie + ) + if downloaded_file and os.path.exists(downloaded_file): + if self.audio.check(downloaded_file): + input_a = downloaded_file + return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True) + return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True) + + add_path_btn.click( + lambda: (gr.update(visible=True), gr.update(visible=False)), + outputs=[add_inputs, add_buttons_row]) + + add_url_btn.click( + lambda: (gr.update(visible=True), gr.update(visible=False)), + outputs=[add_inputs_from_url, add_buttons_row]) + + inputs_url_format.change(lambda x: gr.update(visible=False if x in ["wav", "flac", "aiff"] else True), inputs=inputs_url_format, outputs=inputs_url_bitrate) + +class ManualEnsembless(MVSEPLESS): + def UI(self, output_base_dir, output_temp_dir_check): + with gr.Row(): + with gr.Column(): + with gr.Group(visible=False) as add_ensemble_inputs: + input_ensemble_path = gr.Textbox(label="Путь к входному файлу", interactive=True) + add_ensemble_inputs_btn = gr.Button("Добавить файл", variant="primary") + add_ensemble_path_btn = gr.Button("Добавить файл по пути", variant="secondary") + input_ensemble_files = gr.File(label="Входное аудио", interactive=True, type="filepath", file_count="multiple", file_types=[f".{of}" for of in self.audio.input_formats]) + input_ensemble_preview_check = gr.Checkbox(label="Показать плееры для входных аудио", interactive=True, value=False) + @gr.render(inputs=[input_ensemble_preview_check, input_ensemble_files]) + def show_input_ensemble_players(preview, audios): + if preview: + if audios: + with gr.Group(): + for file in audios: + gr.Audio(label=os.path.splitext(os.path.basename(file))[0], value=file, interactive=False, show_download_button=False, type="filepath") + + with gr.Column(): + @gr.render(inputs=[input_ensemble_files]) + def input_ensemble_files_fn(input_files): + check_ensemble_files_status = f"""Анализ входных файлов +---""" + hz_ = [] + err_list = [] + if input_files: + for file in input_files: + basename = os.path.splitext(os.path.basename(file))[0] + if os.path.exists(file): + if self.audio.check(file): + info = self.audio.get_info(file) + hz = info[0].get("sample_rate") + check_ensemble_files_status += f"\n{basename} - {hz} hz" + hz_.append(hz) + else: + check_ensemble_files_status += f"\n{basename} - Нет аудио" + err_list.append(file) + else: + check_ensemble_files_status += f"\n{basename} - Файл не найден" + err_list.append(file) + + check_ensemble_files_result = f"Действительных файлов: {len(hz_)}" + + all_same = True + + common_rate = None + + for hz_hz in hz_: + if common_rate is None: + common_rate = hz_hz + elif common_rate != hz_hz: + all_same = False + + if hz_ and len(hz_) > 1: + check_ensemble_files_result += "\nВсе действительные файлы имеют одинаковую частоту дискретизации" if all_same else "\nОшибка! Все действительные файлы имеют РАЗНУЮ частоту дискретизации" + else: + check_ensemble_files_result += "\nДля создания ансамбля нужно загрузить, как минимум - 2 файла, содержащие аудио" + + + check_ensemble_files_status += f"\n \n{check_ensemble_files_result}" + + gr.Textbox(container=False, lines=len(check_ensemble_files_status.split("\n")), interactive=False, value=check_ensemble_files_status) + + weights = gr.Textbox(label="Веса", value="1.0,1.0") + + method = gr.Dropdown(label="Алгоритм склеивания", choices=["min_fft", "max_fft", "avg_fft", "median_fft"], value="avg_fft", filterable=False) + + output_format = gr.Dropdown(label="Формат выходного файла", interactive=True, + choices=self.audio.output_formats, + value="mp3", filterable=False) + + output_manual_ensemble_filename = gr.Textbox(label="Имя выходного файла", value="ensemble", interactive=True) + + make_manual_ensemble_btn = gr.Button(value="Создать ансамбль", variant="primary") + + manual_ensemble_output_audio = gr.Audio(label="Результат", type="filepath", interactive=False, show_download_button=True) + + @make_manual_ensemble_btn.click( + inputs=[input_ensemble_files, method, output_format, output_base_dir, output_temp_dir_check, output_manual_ensemble_filename, weights], outputs=manual_ensemble_output_audio + ) + def make_manual_ensemble_fn(input_files_list, method, out_format, o_dir, temp_save, o_filename, weights: str): + if o_dir.strip() != "" and not temp_save: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + o = os.path.join(o_dir, f"ensembless_outputs_{timestamp}") + os.makedirs(o, exist_ok=True) + else: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + o = tempfile.mkdtemp(prefix=f"ensembless_outputs_{timestamp}_") + os.makedirs(o, exist_ok=True) + + o_filename = self.namer.sanitize(o_filename) + o_filename = self.namer.short(o_filename) + + output_file = ensemble_audio_files(files=input_files_list, output=os.path.join(o, o_filename), weights=[float(x) for x in weights.split(",")], ensemble_type=method, out_format=out_format) + return output_file + + add_ensemble_path_btn.click( + lambda: (gr.update(visible=True), gr.update(visible=False)), + outputs=[add_ensemble_inputs, add_ensemble_path_btn]) + + @add_ensemble_inputs_btn.click( + inputs=[input_ensemble_path, input_ensemble_files], + outputs=[add_ensemble_inputs, input_ensemble_files, add_ensemble_path_btn]) + def add_ensemble_inputs_fn(input_p, input_a): + if input_p and os.path.exists(input_p): + if input_a is None: + input_a = [] + if isinstance(input_a, str): + input_a = [input_a] + if self.audio.check(input_p): + input_a.append(input_p) + return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True) + return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True) + +class Inverter_UI(MVSEPLESS): + def UI(self): + with gr.Group(): + with gr.Row(): + original_audio = gr.File(label="Оригинал", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.audio.input_formats]) + stem_audio = gr.File(label="Cтем, который будет вычтен из оригинала", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.audio.input_formats]) + with gr.Group(): + output_format = gr.Dropdown(label="Формат выходного файла", interactive=True, + choices=self.audio.output_formats, + value="mp3", filterable=False) + method = gr.Radio( + choices=["waveform", "spectrogram"], + label="Метод вычитания", + value="waveform", + ) + btn = gr.Button("Вычесть") + output_audio = gr.Audio(label="Инверсия", type="filepath", interactive=False, show_download_button=True) + @btn.click(inputs=[original_audio, stem_audio, method, output_format], outputs=[output_audio]) + def invert_result_ensemble(input_file, output_file, method, out_format): + if input_file and output_file: + o_dir = tempfile.mkdtemp(suffix="_inverter") + basename = os.path.splitext(os.path.basename(input_file))[0] + output_path = os.path.join(o_dir, f"inverter_{self.namer.short(basename, length=50)}_{method}.{out_format}") + inverted = self.inverter.process_audio(audio1_path=input_file, audio2_path=output_file, out_format=out_format, method=method, output_path=output_path) + return inverted + else: + return None + +class Vbach(MVSEPLESS): + pitch_methods = ("rmvpe+", "fcpe", "mangio-crepe") + hop_length_values = (8, 512) + index_rates_values = (0, 1) + filter_radius_values = (0, 7) + protect_values = (0, 0.5) + rms_values = (0, 1) + f0_min_values = (50, 3000) + f0_max_values = (300, 6000) + + def UI(self): + with gr.Tab("Инференс"): + with gr.Row(): + with gr.Column(): + with gr.Group(): + input_audio = gr.File(label="Входные аудио", interactive=True, type="filepath", file_count="multiple", file_types=[f".{of}" for of in self.audio.input_formats]) + converted_state = gr.Textbox(label="Состояние разделения", interactive=False, value="", visible=False) + status = gr.Textbox(container=False, lines=3, interactive=False, max_lines=3) + input_preview_check = gr.Checkbox(label="Показать плееры для входных аудио", interactive=True, value=False) + @gr.render(inputs=[input_preview_check, input_audio]) + def show_input_players(preview, audios): + if preview: + if audios: + with gr.Group(): + for file in audios: + gr.Audio(label=os.path.splitext(os.path.basename(file))[0], value=file, interactive=False, show_download_button=False, type="filepath") + + with gr.Column(): + with gr.Group(): + model_name = gr.Dropdown(label="Имя модели", interactive=True) + model_list_refresh_btn = gr.Button("Обновить", variant="secondary", interactive=True) + @model_list_refresh_btn.click( + outputs=[model_name] + ) + def refresh_list_voice_models(): + models = [] + models = self.vbach_model_manager.parse_voice_models() + first_model = None + if len(models) > 0: + first_model = models[0] + return gr.update(choices=models, value=first_model) + with gr.Group(): + pitch_method = gr.Radio(label="Метод извлечения высоты тона", choices=self.pitch_methods, value=self.pitch_methods[0], interactive=True) + pitch = gr.Slider(label="Высота тона", minimum=-48, maximum=48, step=0.5, value=0, interactive=True) + hop_length = gr.Slider(label="Длина шага", info="Длина шага влияет на точность передачи высоты тона\nЧем меньше длина шага - тем точнее будет передана высота тона", minimum=self.hop_length_values[0], maximum=self.hop_length_values[1], step=8, value=128, interactive=True, visible=False) + @pitch_method.change( + inputs=[pitch_method], + outputs=[hop_length] + ) + def show_mangio_crepe_hop_length(pitch_method): + return gr.update(visible=True if pitch_method in ["mangio-crepe"] else False) + stereo_mode = gr.Radio( + choices=["mono", "left/right", "sim/dif"], + label="Стерео режим", + info="mono - монофоническая обработка аудио, \nleft/right - обработка левого и правого каналов отдельно, \nsim/dif - обработка фантомного центра и стерео-базы, разделенную на левый и правый каналы", + value="mono", + interactive=True + ) + with gr.Accordion(label="Дополнительные настройки",open=False): + with gr.Group(): + with gr.Row(): + index_rate = gr.Slider(label="Влияние индекса", info="Чем ниже значение, тем больше голос похож на исходный; чем выше, тем ближе к модели", minimum=self.index_rates_values[0], maximum=self.index_rates_values[1], step=0.05, value=0, interactive=True) + filter_radius = gr.Slider(label="Радиус фильтра", info="Сглаживает результаты извлечения тона\nМожет снизить дыхание и шумы на выходе", minimum=self.filter_radius_values[0], maximum=self.filter_radius_values[1], step=1, value=3, interactive=True) + with gr.Row(): + rms = gr.Slider(label="Соотношение огибающих громкости", info="Значение 0 - огибающая громкости как у входного аудио, 1 - как у выходного сигнала", minimum=self.rms_values[0], maximum=self.rms_values[1], step=0.05, value=0.25, interactive=True) + protect = gr.Slider(label="Защита согласных", info="Предовращает роботизацию дыхания и согласных (Может влиять на четкость речи)\nЗначение 0.5 - выключает защиту, 0 - максимальная защита", minimum=self.protect_values[0], maximum=self.protect_values[1], step=0.05, value=0.35, interactive=True) + with gr.Group(): + with gr.Row(): + f0_min = gr.Slider(label="Нижний предел диапазона определения высоты тона", minimum=self.f0_min_values[0], maximum=self.f0_min_values[1], step=10, value=50, interactive=True) + f0_max = gr.Slider(label="Верхний предел диапазона определения высоты тона", minimum=self.f0_max_values[0], maximum=self.f0_max_values[1], step=10, value=1100, interactive=True) + + with gr.Group(): + output_name = gr.Textbox(label="Имя выходного файла", interactive=True, value="NAME - MODEL - F0METHOD - PITCH") + format_output_name_check = gr.Checkbox(label="Форматировать имя", info="Используйте ключи: \nNAME - имя входного файла без расширения, \nPITCH - высота тона, \nF0METHOD - метод извлечения высота тона, \nMODEL - имя голосовой модели", value=True, interactive=True) + output_format = gr.Dropdown(label="Формат выходного файла", interactive=True, choices=self.audio.output_formats, value=self.audio.output_formats[0], filterable=False) + convert_btn = gr.Button("Преобразовать", variant="primary", interactive=True) + + + @convert_btn.click( + inputs=[ + input_audio, + model_name, + pitch_method, + pitch, + hop_length, + index_rate, + filter_radius, + rms, + protect, + f0_min, + f0_max, + output_name, + format_output_name_check, + output_format, + stereo_mode + ], outputs=[converted_state, status] + ) + def vbach_convert_batch(ifl, mn, pm, p, hl, ir, fr, rms, pr, f0min, f0max, on, fn, of, sm): + output_converted_files = [] + progress = gr.Progress() + if ifl: + for i, file in enumerate(ifl, start=1): + try: + print(f"Файл {i} из {len(ifl)}: {file}") + progress(progress=(i / len(ifl)), desc=f"Файл {i} из {len(ifl)}") + out_conv = vbach_inference(input_file=file, model_name=mn, output_dir=tempfile.mkdtemp(), output_name=on, format_name=True if len(ifl) > 1 else fn, output_format=of, pitch=p, method_pitch=pm, output_bitrate=320, add_params={ "index_rate": ir,"filter_radius": fr,"protect": pr,"rms": rms,"mangio_crepe_hop_length": hl,"f0_min": f0min,"f0_max": f0max,"stereo_mode": sm }) + output_converted_files.append(out_conv) + except Exception as e: + print(e) + return str(output_converted_files), None + + @gr.render(inputs=[converted_state]) + def show_players_converted(state): + if state != "": + output_converted_files = ast.literal_eval(state) + if output_converted_files: + with gr.Group(): + for conv_file in output_converted_files: + basename = os.path.splitext(os.path.basename(conv_file))[0] + gr.Audio( + label=basename, + value=conv_file, + type="filepath", + interactive=False, + show_download_button=True, + ) + + with gr.TabItem("Менеджер"): + with gr.TabItem("Загрузить по ссылке"): + with gr.TabItem("Через zip файл"): + with gr.Row(): + with gr.Column(variant="panel"): + url_zip = gr.Text(label="Ссылка на zip файл") + with gr.Group(): + url_zip_model_name = gr.Text( + label="Имя модели", + ) + url_zip_download_btn = gr.Button("Загрузить", variant="primary") + + url_zip_output = gr.Text(label="Статус", interactive=False, lines=5) + url_zip_download_btn.click( + (lambda x, y: self.vbach_model_manager.install_model_zip(x, self.namer.short(self.namer.sanitize(y), length=40), "url")), + inputs=[url_zip, url_zip_model_name], + outputs=url_zip_output, + ) + + with gr.TabItem("Через отдельные файлы"): + with gr.Row(): + with gr.Column(variant="panel"): + url_pth = gr.Text(label="Ссылка на *.pth файл") + url_index = gr.Text(label="Ссылка на *.index файл (необязательно)") + with gr.Group(): + url_file_model_name = gr.Text( + label="Имя модели", + ) + url_file_download_btn = gr.Button("Загрузить", variant="primary") + + url_file_output = gr.Text(label="Статус", interactive=False, lines=5) + url_file_download_btn.click( + (lambda x, y, z: self.vbach_model_manager.install_model_files(x, y, self.namer.short(self.namer.sanitize(z), length=40), "url")), + inputs=[url_index, url_pth, url_file_model_name], + outputs=url_file_output, + ) + + with gr.Tab("Загрузить с устройства"): + with gr.Tab("Через zip файл"): + with gr.Row(): + with gr.Column(): + local_zip = gr.File( + label="zip файл", file_types=[".zip"], file_count="single" + ) + with gr.Column(variant="panel"): + with gr.Group(): + local_zip_model_name = gr.Text( + label="Имя модели", + ) + local_zip_upload_btn = gr.Button("Загрузить", variant="primary") + + local_zip_output = gr.Text(label="Статус", interactive=False, lines=5) + local_zip_upload_btn.click( + (lambda x, y: self.vbach_model_manager.install_model_zip(x, self.namer.short(self.namer.sanitize(y), length=40), "local")), + inputs=[local_zip, local_zip_model_name], + outputs=local_zip_output, + ) + + with gr.TabItem("Через отдельные файлы"): + with gr.Group(): + with gr.Row(): + local_pth = gr.File( + label="*.pth файл", file_types=[".pth"], file_count="single" + ) + local_index = gr.File( + label="*.index файл (необязательно)", file_types=[".index"], file_count="single" + ) + with gr.Column(variant="panel"): + with gr.Group(): + local_file_model_name = gr.Text( + label="Имя модели", + ) + local_file_upload_btn = gr.Button("Загрузить", variant="primary") + + local_file_output = gr.Text( + label="Статус", interactive=False + ) + local_file_upload_btn.click( + (lambda x, y, z: self.vbach_model_manager.install_model_files(x, y, self.namer.short(self.namer.sanitize(z), length=40), "local")), + inputs=[local_index, local_pth, local_file_model_name], + outputs=local_file_output, + ) + + with gr.TabItem("Удалить модель"): + with gr.Column(variant="panel"): + with gr.Group(): + delete_model_name = gr.Dropdown( + label="Имя модели", + choices=self.vbach_model_manager.parse_voice_models(), + interactive=True, + filterable=False + ) + delete_refresh_btn = gr.Button("Обновить") + @delete_refresh_btn.click(inputs=None, outputs=delete_model_name) + def refresh_list_voice_models(): + models = [] + models = self.vbach_model_manager.parse_voice_models() + first_model = None + if len(models) > 0: + first_model = models[0] + return gr.update(choices=models, value=first_model) + + delete_output = gr.Text( + label="Статус", interactive=False, lines=5 + ) + delete_btn = gr.Button("Удалить") + delete_btn.click( + fn=self.vbach_model_manager.del_voice_model, + inputs=delete_model_name, + outputs=delete_output + ) + + @gr.on(fn="decorator", inputs=None, outputs=[delete_model_name, model_name]) + def refresh_list_voice_models(): + models = [] + models = self.vbach_model_manager.parse_voice_models() + first_model = None + if len(models) > 0: + first_model = models[0] + return gr.update(choices=models, value=first_model), gr.update(choices=models, value=first_model) + + +class PluginManager(Separator): + plugins_dir = os.path.join(script_dir, "plugins") + os.makedirs(plugins_dir, exist_ok=True) + + def restart_after_install_plugin(self): + subprocess.Popen([os.sys.executable] + sys.argv) + os._exit(0) + + def parse_plugins(self): + for plugin_file in os.listdir(self.plugins_dir): + # Пропускаем не-Python файлы и __init__.py + if not plugin_file.endswith('.py') or plugin_file == '__init__.py': + continue + + # Получаем имя модуля без расширения + plugin_module_name = os.path.splitext(plugin_file)[0] + + try: + # Определяем путь импорта в зависимости от структуры проекта + if __package__: + # Если мы в пакете, используем абсолютный импорт + plugin_module = importlib.import_module(f".plugins.{plugin_module_name}", package=__package__) + else: + # Если не в пакете, пробуем разные варианты + try: + # Попробуем абсолютный импорт + plugin_module = importlib.import_module(f"plugins.{plugin_module_name}") + except ImportError: + # Если не сработало, загружаем из файла + plugin_path = os.path.join(self.plugins_dir, plugin_file) + spec = importlib.util.spec_from_file_location(plugin_module_name, plugin_path) + plugin_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(plugin_module) + + # Получаем класс Plugin из модуля + plugin_class = getattr(plugin_module, 'Plugin') + + # Создаем экземпляр плагина + plugin_instance = plugin_class() + + # Создаем UI плагина + with gr.Tab(plugin_instance.name): + plugin_instance.UI() + + except Exception as e: + print(f"Ошибка загрузки плагина {plugin_module_name}: {e}") + continue + + def UI(self): + with gr.Tab("Установка"): + upload_plugins_files = gr.File(label="Загрузить плагины", file_types=[".py"], file_count="multiple", interactive=True) + install_plugins_btn = gr.Button("Установить", interactive=True) + + @install_plugins_btn.click( + inputs=[upload_plugins_files] + ) + def upload_plugin_list(files): + if not files: + return + for file in files: + try: + # Копируем только .py файлы + if file.name.endswith('.py'): + shutil.copy( + file, os.path.join(self.plugins_dir, os.path.basename(file).replace(" ", "_")) + ) + except Exception as e: + print(f"Ошибка копирования файла {file}: {e}") + time.sleep(2) + self.restart_after_install_plugin() + + self.parse_plugins() + +def mvsepless_app(theme): + css = None + with gr.Blocks(theme=theme, css=css, title="Разделение музыки и вокала") as app: + + output_base_dir_state = gr.State(value=os.path.join(os.getcwd(), "outputs")) + output_temp_dir_state = gr.State(value=False) + + with gr.Tab("Инференс"): + Separator().UI(output_base_dir_state, output_temp_dir_state) + + with gr.Tab("Ансамбль"): + with gr.Tab("Авто-ансамбль"): + AutoEnsembless().UI(output_base_dir_state, output_temp_dir_state) + + with gr.Tab("Ручной ансамбль"): + ManualEnsembless().UI(output_base_dir_state, output_temp_dir_state) + + with gr.Tab("Вычитание"): + Inverter_UI().UI() + + with gr.Tab("Преобразование"): + Vbach().UI() + + with gr.Tab("Плагины"): + PluginManager().UI() + + with gr.Tab("Настройки"): + with gr.Column(): + output_base_dir_ui = gr.Textbox( + label="Базовый каталог для выходных файлов", + interactive=True, + value=os.path.join(os.getcwd(), "outputs"), + lines=5 + ) + output_temp_dir_check_ui = gr.Checkbox( + label="Использовать временный каталог для выходных файлов", + interactive=True, + value=False + ) + + # Связываем UI элементы с состоянием + output_base_dir_ui.change( + lambda x: x, + inputs=[output_base_dir_ui], + outputs=[output_base_dir_state] + ) + output_temp_dir_check_ui.change( + lambda x: x, + inputs=[output_temp_dir_check_ui], + outputs=[output_temp_dir_state] + ) + + return app + diff --git a/mvsepless/__main__.py b/mvsepless/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0009e708a942278fd688793386f12963522a822 --- /dev/null +++ b/mvsepless/__main__.py @@ -0,0 +1,67 @@ +import argparse +import gradio as gr +import os + +if not __package__: + from __init__ import mvsepless_app, Separator +else: + from .__init__ import mvsepless_app, Separator + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="MVSepless") + subparsers = parser.add_subparsers(dest="command") + app_parser = subparsers.add_parser("app", help="Приложение MVSepless") + app_parser.add_argument("--port", type=int, default=None, help="Порт для запуска сервера Gradio.") + app_parser.add_argument("--share", action="store_true", help="Создать публичную ссылку для приложения Gradio.") + cli_parser = subparsers.add_parser("cli", help="CLI MVSepless Lite") + cli_parser.add_argument("--input", type=str, required=True, help="Входной аудиофайл или каталог.") + cli_parser.add_argument("--output_dir", type=str, default=None, help="Каталог для выходных файлов.") + cli_parser.add_argument("--model_type", type=str, default="mel_band_roformer", help="Тип модели разделения.") + cli_parser.add_argument("--model_name", type=str, default="Mel-Band-Roformer_Vocals_kimberley_jensen", help="Имя модели разделения.") + cli_parser.add_argument("--ext_inst", action="store_true", help="Извлечь инструментал.") + cli_parser.add_argument("--output_format", type=str, default="mp3", choices=Separator.audio.output_formats, help="Формат выходного файла.") + cli_parser.add_argument("--output_bitrate", type=str, default="320k", help="Битрейт выходного файла.") + cli_parser.add_argument("--template", type=str, default="NAME (STEM) MODEL", help="Шаблон именования выходных файлов.") + cli_parser.add_argument("--selected_stems", type=str, nargs='*', default=None, help="Выбранные стемы для разделения.") + args = parser.parse_args() + + if args.command == "app": + theme = gr.themes.Citrus( + primary_hue="teal", + secondary_hue="blue", + neutral_hue="blue", + spacing_size="sm", + font=[ + gr.themes.GoogleFont("Montserrat"), + "ui-sans-serif", + "system-ui", + "sans-serif", + ], + ) + mvsepless_lite_app = mvsepless_app(theme=theme) + mvsepless_lite_app.launch(server_name="0.0.0.0", server_port=args.port, share=args.share, allowed_paths=["/"], debug=True) + elif args.command == "cli": + input_data = args.input + if os.path.isdir(input_data): + list_valid_files = [] + for file in os.listdir(args.input): + if os.path.isfile(os.path.join(args.input, file)): + if Separator.audio.check(os.path.join(args.input, file)): + list_valid_files.append(os.path.join(args.input, file)) + + input_files = list_valid_files + else: + input_files = input_data + + results = Separator().separate( + input=input_files, + output_dir=args.output_dir, + model_type=args.model_type, + model_name=args.model_name, + ext_inst=args.ext_inst, + output_format=args.output_format, + output_bitrate=args.output_bitrate, + template=args.template, + selected_stems=args.selected_stems + ) + print("Разделение завершено.") \ No newline at end of file diff --git a/mvsepless/audio.py b/mvsepless/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..15a6a874a99bf87b48c67fa130280fd65401232c --- /dev/null +++ b/mvsepless/audio.py @@ -0,0 +1,781 @@ +import os +from pathlib import Path +import sys +import json +import subprocess +import numpy as np +from typing import Literal +from collections.abc import Callable +from pathlib import Path +from numpy.typing import DTypeLike +import tempfile +import librosa +if not __package__: + from namer import Namer +else: + from .namer import Namer +class NotInputFileSpecified(Exception): pass +class NotOutputFileSpecified(Exception): pass +class NotSupportedDataType(Exception): pass +class ErrorDecode(Exception): pass +class ErrorEncode(Exception): pass +class NotSupportedFormat(Exception): pass +class SampleRateError(Exception): pass +class FileIsNotAudio(Exception): pass + +class Audio(Namer): + def __init__(self): + """ +Чтение и запись аудио файла через ffmpeg + +Поддерживаемые типы данных: - int16, int32, float32, float64 + """ + super().__init__() + self.ffmpeg_path = os.environ.get("MVSEPLESS_FFMPEG", "ffmpeg") + self.ffprobe_path = os.environ.get("MVSEPLESS_FFPROBE", "ffprobe") + self.output_formats = ("mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff") + self.input_formats = ("mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff", "mp4", "mkv", "webm", "avi", "mov", "ts") + self.supported_dtypes = ("int16", "int32", "float32", "float64") + self.dtypes_dict = { + "int16": "s16le", + "int32": "s32le", + "float32": "f32le", + "float64": "f64le", + np.int16: "s16le", + np.int32: "s32le", + np.float32: "f32le", + np.float64: "f64le", + } + self.bitrate_limit = { + "mp3": {"min": 8, "max": 320}, + "aac": {"min": 8, "max": 512}, + "m4a": {"min": 8, "max": 512}, + "ac3": {"min": 32, "max": 640}, + "ogg": {"min": 64, "max": 500}, + "opus": {"min": 6, "max": 512}, + } + self.sample_rates = { + "mp3": { + "supported": (44100, 48000, 32000, 22050, 24000, 16000, 11025, 12000, 8000) + }, + "opus": {"supported": (48000, 24000, 16000, 12000, 8000)}, + "m4a": { + "supported": ( + 96000, + 88200, + 64000, + 48000, + 44100, + 32000, + 24000, + 22050, + 16000, + 12000, + 11025, + 8000, + 7350, + ) + }, + "aac": { + "supported": ( + 96000, + 88200, + 64000, + 48000, + 44100, + 32000, + 24000, + 22050, + 16000, + 12000, + 11025, + 8000, + 7350, + ) + }, + "ac3": { + "supported": ( + 48000, + 44100, + 32000, + ) + }, + "ogg": {"min": 6, "max": 192000}, + "wav": {"min": 0, "max": float("inf")}, + "aiff": {"min": 0, "max": float("inf")}, + "flac": {"min": 0, "max": 192000}, + } + self.check_ffmpeg() + self.check_ffprobe() + + def check_ffmpeg(self): + """ + Проверяет, установлен ли ffmpeg? + """ + try: + ffmpeg_version_output = subprocess.check_output( + [self.ffmpeg_path, "-version"], text=True + ) + except FileNotFoundError: + if "PYTEST_CURRENT_TEST" not in os.environ: + raise FileNotFoundError("FFMPEG не установлен. Укажите путь к установленному FFMPEG через переменную окружения MVSEPLESS_FFMPEG") + + def check_ffprobe(self): + """ + Проверяет, установлен ли ffprobe? + """ + try: + ffmpeg_version_output = subprocess.check_output( + [self.ffprobe_path, "-version"], text=True + ) + except FileNotFoundError: + if "PYTEST_CURRENT_TEST" not in os.environ: + raise FileNotFoundError("FFPROBE не установлен. Укажите путь к установленному FFPROBE через переменную окружения MVSEPLESS_FFPROBE") + + + def fit_sr( + self, + f: str | Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] = "mp3", + sr: int = 44100 + ) -> int: + """ + Исправляет значение частоты дисректизации выходного файла + + Параметры: + f: Формат вывода + sr: Частота дискретизации (целое число) + Возвращает: + sr: Исправленная частота дискретизации + """ + format_info = self.sample_rates.get(f.lower()) + + if not format_info: + return None # Формат не найден + + if "supported" in format_info: + # Для форматов с конкретным списком + supported_rates = format_info["supported"] + if sr in supported_rates: + return sr + + # Находим ближайшую поддерживаемую частоту + return min(supported_rates, key=lambda x: abs(x - sr)) + + elif "min" in format_info and "max" in format_info: + # Для форматов с диапазоном - обрезаем до границ + min_rate = format_info["min"] + max_rate = format_info["max"] + + if sr < min_rate: + return min_rate + elif sr > max_rate: + return max_rate + else: + return sr + + return None + + def fit_br( + self, + f: str | Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] = "mp3", + br: int = 320 + ) -> int: + """ + Исправляет значение битрейта выходного файла + + Параметры: + f: Формат вывода + br: Битрейт (целое число) + Возвращает: + br: Исправленный битрейт + """ + if f not in self.bitrate_limit: + raise NotSupportedFormat(f"Формат {f} не поддерживается") + + limits = self.bitrate_limit[f] + + if br < limits["min"]: + return limits["min"] + elif br > limits["max"]: + return limits["max"] + else: + return br + + def get_info( + self, + i: str | os.PathLike | Callable | None = None, + ) -> dict[int, dict[int, float]]: + """ + Получает информацию о аудио потоках из файла напрямую через FFMPEG + + Параметры: + i: Путь к выходному файлу + Возвращает: + audio_info: Словарь с информацией о аудиопотоках вида: + + {Номер потока: + { + "sample_rate": Частота дисректизации (является целым числом), + "duration": Длительность аудиопотока (является числом с плавающей точкой) + } + } + """ + audio_info = {} + if i: + if isinstance(i, Path): + i = str(i) + if os.path.exists(i): + cmd = [self.ffprobe_path, "-i", i, "-v", "quiet", "-hide_banner", + "-show_entries", "stream=index,sample_rate,duration", "-select_streams", "a", "-of", "json"] + + process = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + stdout, stderr = process.communicate() + + if process.returncode != 0: + print(f"STDERR: {stderr.decode('utf-8')}") + print(f"STDOUT: {stdout.decode('utf-8')}") + + json_output = json.loads(stdout) + streams = json_output["streams"] + if not streams: + pass + + else: + for a, stream in enumerate(streams): + audio_info[a] = { + "sample_rate": int(stream.get("sample_rate", 0)), + "duration": float(stream.get("duration", 0)) + } + + return audio_info + + else: + raise FileExistsError("Указанного файла не существует") + + else: + raise NotInputFileSpecified("Не указан путь к файлу") + + def check( + self, + i: str | os.PathLike | Callable | None = None + ) -> bool: + """ + Проверяет, является ли файл аудио или видео файлом, поддерживаемым ffmpeg + + Параметры: + i: Путь к выходному файлу + Возвращает: + is_audio_video: Булево значение, является ли файл аудио или видео файлом + """ + if i: + if isinstance(i, Path): + i = str(i) + if os.path.exists(i): + info = self.get_info(i=i) + if info: + list_streams = list(info.keys()) + if len(list_streams) > 0: + if info[0].get("sample_rate") > 0: + return True + else: + return False + else: + return False + else: + return False + else: + raise FileExistsError("Указанного файла не существует") + else: + raise NotInputFileSpecified("Не указан путь к файлу") + + def read( + self, + i: str | os.PathLike | Callable | None = None, + sr: int | None = None, + mono: bool = False, + dtype: DTypeLike = np.float32, + s: int = 0 + ) -> tuple[np.ndarray, int, float]: + """ + Читает аудио-файл, преобразовывая его в массив с аудио данными напрямую через FFMPEG + Является заменой soundfile.read() и librosa.load() + + Параметры: + i: Путь к выходному файлу + sr: Целевая частота дискретизации (Если не указана, то используется частота дискретизации входного файла) + mono: Конвертация в моно (по умолчанию отключена) + dtype: Тип данных (поддерживаются типы: int16, int32, float32, float64; по умолчанию - float32) + s: Номер аудиопотока (по умолчанию 0) + Возвращает: + audio_array: Массив с аудио данными + sr: Частота дискретизации массива + duration: Длительность аудио (количество сэмплов / частота дискретизации) + """ + output_format = self.dtypes_dict.get(dtype, None) + if not output_format: + raise NotSupportedDataType(f"Этот тип данных не поддерживается {dtype}") + if i: + if isinstance(i, Path): + i = str(i) + if os.path.exists(i): + audio_info = self.get_info(i=i) + list_streams = list(audio_info.keys()) + if audio_info.get(s, False): + stream = s + else: + if len(list_streams) > 0: + stream = 0 + else: + raise FileIsNotAudio("В входном файле нет аудио потоков") + + sample_rate_input = audio_info[stream]["sample_rate"] + if sample_rate_input == 0: + raise FileIsNotAudio("В входном файле нет аудио потоков") + + cmd = [ + self.ffmpeg_path, + "-i", i, + "-map", f"0:a:{stream}", "-vn", + "-f", output_format, + "-ac", "1" if mono else "2", + ] + + if sr: + cmd.extend(["-ar", str(sr)]) + else: + sr = sample_rate_input + + cmd.append("pipe:1") + + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + bufsize=10**8 + ) + + try: + + raw_audio, stderr = process.communicate(timeout=300) + + if process.returncode != 0: + raise ErrorDecode(f"FFmpeg error: {stderr.decode()}") + + except subprocess.TimeoutExpired: + process.kill() + raise ErrorDecode("FFmpeg timeout при чтении файла") + + audio_array = np.frombuffer(raw_audio, dtype=dtype) + + channels = 1 if mono else 2 + audio_array = audio_array.reshape((-1, channels)).T + if audio_array.ndim > 1 and channels == 1: + audio_array = np.mean(audio_array, axis=tuple(range(audio_array.ndim - 1))) + + len_samples = float(audio_array.shape[-1]) + + duration = len_samples / sr + + print(f"Частота дискретизации: {sr}") + + return audio_array, sr, duration + else: + raise FileExistsError("Указанного файла не существует") + + else: + raise NotInputFileSpecified("Не указан путь к файлу") + + def write( + self, + o: str | os.PathLike | Callable | None = None, + array: np.ndarray = np.array([], dtype=np.float32), + sr: int = 44100, + of: str | Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] | None = None, + br: str | int | None = None + ) -> str: + """ + Записывает numpy-массив с аудио данными в файл напрямую через ffmpeg. + Является заменой soundfile.write() + + Параметры: + o: Путь к выходному файлу + array: Массив с аудио данными (поддерживаются типы: int16, int32, float32, float64) + sr: Частота дискретизации массива + of: Формат вывода (по умолчанию mp3) + br: Битрейт для кодеков, сжимающих аудио с потерями + Возвращает: + o: Путь к выходному файлу + """ + if isinstance(array, np.ndarray): + + if len(array.shape) == 1: + array = array.reshape(-1, 1) + elif len(array.shape) == 2: + if array.shape[0] == 2: + array = array.T + else: + raise ValueError("numpy-массив должен быть либо одномерным, либо двухмерным") + + if array.dtype == np.int16: + input_format = "s16le" + elif array.dtype == np.int32: + input_format = "s32le" + elif array.dtype == np.float32: + input_format = "f32le" + elif array.dtype == np.float64: + input_format = "f64le" + else: + raise NotSupportedDataType(f"Этот тип данных не поддерживается {array.dtype}") + + if array.shape[1] == 1: + audio_bytes = array.tobytes() + + channels = 1 + + elif array.shape[1] == 2: + audio_bytes = array.tobytes() + + channels = 2 + else: + raise ValueError("numpy-массив должен содержать 1 или 2 канала") + + else: + raise ValueError("Вход должен быть numpy-массивом") + + if o: + if isinstance(o, Path): + o = str(o) + output_dir = os.path.dirname(o) + output_base = os.path.basename(o) + output_name, output_ext = os.path.splitext(output_base) + if output_dir != "": + os.makedirs(output_dir, exist_ok=True) + if output_ext == "": + if of: + o += f".{of}" + else: + o += f".mp3" + elif output_ext == ".": + if of: + o += f"{of}" + else: + o += f"mp3" + else: + raise NotOutputFileSpecified("Не указан путь к выходному файлу") + + if of: + if of in self.output_formats: + output_name, output_ext = os.path.splitext(o) + if output_ext == f".{of}": + pass + else: + o = f"{os.path.join(output_dir, output_name)}.{of}" + else: + raise NotSupportedFormat(f"Неподдерживаемый формат: {of}") + else: + of = os.path.splitext(o)[1].strip(".") + if of in self.output_formats: + pass + else: + raise NotSupportedFormat(f"Неподдерживаемый формат: {of}") + + if sr: + if isinstance(sr, int): + sample_rate_fixed = self.fit_sr(f=of, sr=sr) + elif isinstance(sr, float): + sr = int(sr) + sample_rate_fixed = self.fit_sr(f=of, sr=sr) + else: + raise SampleRateError(f"Частота дискретизации должна быть числом\n\nЗначение: {sr}\nТип: {type(sr)}") + else: + raise SampleRateError("Не указана частота дискретизации") + + bitrate_fixed = "320k" + + if of not in ["wav", "flac", "aiff"]: + if br: + if isinstance(br, int): + bitrate_fixed = self.fit_br(f=of, br=br) + elif isinstance(br, float): + bitrate_fixed = self.fit_br(f=of, br=int(br)) + elif isinstance(br, str): + bitrate_fixed = self.fit_br(f=of, br=int(br.strip("k").strip("K"))) + else: + bitrate_fixed = self.fit_br(f=of, br=320) + else: + bitrate_fixed = self.fit_br(of, 320) + + format_settings = { + "wav": [ + "-c:a", + "pcm_f32le", + "-sample_fmt", + "flt", + ], + "aiff": [ + "-c:a", + "pcm_f32le", + "-sample_fmt", + "flt", + ], + "flac": [ + "-c:a", + "flac", + "-compression_level", + "12", + "-sample_fmt", + "s32", + ], + "mp3": [ + "-c:a", + "libmp3lame", + "-b:a", + f"{bitrate_fixed}k", + ], + "ogg": [ + "-c:a", + "libvorbis", + "-b:a", + f"{bitrate_fixed}k", + ], + "opus": [ + "-c:a", + "libopus", + "-b:a", + f"{bitrate_fixed}k", + ], + "m4a": [ + "-c:a", + "aac", + "-b:a", + f"{bitrate_fixed}k", + ], + "aac": [ + "-c:a", + "aac", + "-b:a", + f"{bitrate_fixed}k", + ], + "ac3": [ + "-c:a", + "ac3", + "-b:a", + f"{bitrate_fixed}k", + ], + } + + cmd = [ + self.ffmpeg_path, + "-y", + "-f", + input_format, + "-ar", + str(sr), + "-ac", + str(channels), + "-i", + "pipe:0", + "-ac", + str(channels), + ] + + cmd.extend(["-ar", str(sample_rate_fixed)]) + cmd.extend(format_settings[of]) + o_dir, o_base = os.path.split(o) + o_base_n, o_base_ext = os.path.splitext(o_base) + o_base_n = self.sanitize(o_base_n) + o_base_n = self.short(o_base_n) + o = os.path.join(o_dir, f"{o_base_n}{o_base_ext}") + o = self.iter(o) + cmd.append(o) + + process = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + try: + stdout, stderr = process.communicate(input=audio_bytes, timeout=300) + except subprocess.TimeoutExpired: + process.kill() + raise ErrorEncode("FFmpeg timeout: операция заняла слишком много времени") + + if process.returncode != 0: + raise ErrorEncode(f"FFmpeg завершился с ошибкой (код: {process.returncode})") + + return os.path.abspath(o) + +class Inverter(Audio): + def __init__(self): + super().__init__() + self.test = "test" + self.w_types = [ + "boxcar", # Прямоугольное окно + "triang", # Треугольное окно + "blackman", # Окно Блэкмана + "hamming", # Окно Хэмминга + "hann", # Окно Ханна + "bartlett", # Окно Бартлетта + "flattop", # Окно с плоской вершиной + "parzen", # Окно Парзена + "bohman", # Окно Бохмана + "blackmanharris", # Окно Блэкмана-Харриса + "nuttall", # Окно Нуттала + "barthann", # Окно Бартлетта-Ханна + "cosine", # Косинусное окно + "exponential", # Экспоненциальное окно + "tukey", # Окно Туки + "taylor", # Окно Тейлора + "lanczos", # Окно Ланцоша + ] + + def load_audio(self, filepath): + """Загрузка аудиофайла с помощью librosa""" + try: + y, sr, _ = self.read(i=filepath, sr=None, mono=False) + return y, sr + except Exception as e: + print(f"Ошибка загрузки аудио: {e}") + return None, None + + def process_channel(self, y1_ch, y2_ch, sr, method, w_size=2048, overlap=2, w_type="hann"): + """Обработка одного аудиоканала""" + HOP_LENGTH = w_size // overlap + if method == "waveform": + return y1_ch - y2_ch + + elif method == "spectrogram": + # Вычисляем спектрограммы + S1 = librosa.stft( + y1_ch, n_fft=w_size, hop_length=HOP_LENGTH, win_length=w_size + ) + S2 = librosa.stft( + y2_ch, n_fft=w_size, hop_length=HOP_LENGTH, win_length=w_size + ) + + # Амплитудные спектрограммы + mag1 = np.abs(S1) + mag2 = np.abs(S2) + + # Спектральное вычитание + mag_result = np.maximum(mag1 - mag2, 0) + + # Сохраняем фазовую информацию исходного сигнала + phase = np.angle(S1) + + # Комбинируем амплитуду результата с фазой + S_result = mag_result * np.exp(1j * phase) + + # Обратное преобразование + return librosa.istft( + S_result, + n_fft=w_size, + hop_length=HOP_LENGTH, + win_length=w_size, + length=len(y1_ch), + ) + + def process_audio(self, audio1_path, audio2_path, out_format, method, output_path="./inverted.mp3", w_size=2048, overlap=2, w_type="hann"): + # Загрузка аудиофайлов + y1, sr1 = self.load_audio(audio1_path) + y2, sr2 = self.load_audio(audio2_path) + + if sr1 is None or sr2 is None: + raise Exception("Произошла ошибка при чтении файлов") + + # Определяем количество каналов + channels1 = 1 if y1.ndim == 1 else y1.shape[0] + channels2 = 1 if y2.ndim == 1 else y2.shape[0] + + # Преобразование в форму (samples, channels) + if channels1 > 1: + y1 = y1.T # (channels, samples) -> (samples, channels) + else: + y1 = y1.reshape(-1, 1) + + if channels2 > 1: + y2 = y2.T # (channels, samples) -> (samples, channels) + else: + y2 = y2.reshape(-1, 1) + + if sr1 != sr2: + if channels2 > 1: + # Ресемплинг для каждого канала отдельно + y2_resampled_list = [] + for c in range(channels2): + channel_resampled = librosa.resample( + y2[:, c], orig_sr=sr2, target_sr=sr1 + ) + y2_resampled_list.append(channel_resampled) + + # Находим минимальную длину среди всех каналов + min_channel_length = min(len(ch) for ch in y2_resampled_list) + + # Обрезаем все каналы до одинаковой длины и собираем в массив + y2_resampled = np.zeros((min_channel_length, channels2), dtype=np.float32) + for c, channel in enumerate(y2_resampled_list): + y2_resampled[:, c] = channel[:min_channel_length] + + y2 = y2_resampled + else: + y2 = librosa.resample(y2[:, 0], orig_sr=sr2, target_sr=sr1) + y2 = y2.reshape(-1, 1) + sr2 = sr1 + + # Приводим к одинаковой длине + min_len = min(len(y1), len(y2)) + y1 = y1[:min_len] + y2 = y2[:min_len] + + # Обрабатываем каждый канал отдельно + result_channels = [] + + # Если основной сигнал моно, а удаляемый стерео - преобразуем удаляемый в моно + if channels1 == 1 and channels2 > 1: + y2 = y2.mean(axis=1, keepdims=True) + channels2 = 1 + + for c in range(channels1): + # Выбираем канал для основного сигнала + y1_ch = y1[:, c] + + # Выбираем канал для удаляемого сигнала + if channels2 == 1: + y2_ch = y2[:, 0] + else: + # Если каналов удаляемого сигнала больше, используем соответствующий канал + y2_ch = y2[:, min(c, channels2 - 1)] + + # Обрабатываем канал + result_ch = self.process_channel(y1_ch, y2_ch, sr1, method, w_size=w_size, overlap=overlap, w_type=w_type) + result_channels.append(result_ch) + + # Собираем каналы в один массив + if len(result_channels) > 1: + result = np.column_stack(result_channels) + else: + result = np.array(result_channels[0]) + + # Нормализация (предотвращение клиппинга) + if result.ndim > 1: + # Для многоканального аудио нормализуем каждый канал отдельно + for c in range(result.shape[1]): + channel = result[:, c] + max_val = np.max(np.abs(channel)) + if max_val > 0: + result[:, c] = channel * 0.9 / max_val + else: + max_val = np.max(np.abs(result)) + if max_val > 0: + result = result * 0.9 / max_val + + inverted = self.write(o=output_path, array=result.T, sr=sr1, of=out_format, br="320k") + return inverted diff --git a/mvsepless/downloader.py b/mvsepless/downloader.py new file mode 100644 index 0000000000000000000000000000000000000000..d5eeb1ff7d4c518c36271ae1b7b87586bb83471c --- /dev/null +++ b/mvsepless/downloader.py @@ -0,0 +1,92 @@ +import os +import yt_dlp +import time +from tqdm import tqdm +import urllib.request + +DOWNLOAD_DIR = os.environ.get( + "MVSEPLESS_DOWNLOAD_DIR", os.path.join(os.getcwd(), "downloaded") +) + +def dw_file(url_model: str, local_path: str, retries: int = 30): + dir_name = os.path.dirname(local_path) + if dir_name != "": + os.makedirs(dir_name, exist_ok=True) + + class TqdmUpTo(tqdm): + def update_to(self, b=1, bsize=1, tsize=None): + if tsize is not None: + self.total = tsize + self.update(b * bsize - self.n) + + for attempt in range(retries): + try: + with TqdmUpTo( + unit="B", + unit_scale=True, + unit_divisor=1024, + miniters=1, + desc=os.path.basename(local_path), + ) as t: + urllib.request.urlretrieve( + url_model, local_path, reporthook=t.update_to + ) + # Если дошли сюда - загрузка успешна, прерываем цикл + break + except Exception as e: + print(f"Попытка {attempt + 1}/{retries} не удалась. Ошибка: {e}") + if attempt < retries - 1: # Если это не последняя попытка + print("Повторная попытка...") + time.sleep(2) # Небольшая задержка перед повторной попыткой + else: + print("Все попытки загрузки завершились неудачно") + raise # Пробрасываем исключение дальше + +def dw_yt_dlp( + url, + output_dir=None, + cookie=None, + output_format="mp3", + output_bitrate="320", + title=None, +): + # Подготовка шаблона имени файла + outtmpl = "%(title)s.%(ext)s" if title is None else f"{title}.%(ext)s" + + ydl_opts = { + "format": "bestaudio/best", + "outtmpl": os.path.join(DOWNLOAD_DIR if not output_dir else output_dir, outtmpl), + "postprocessors": [ + { + "key": "FFmpegExtractAudio", + "preferredcodec": output_format, + "preferredquality": output_bitrate, + } + ], + "noplaylist": True, # Скачивать только одно видео, не плейлист + "quiet": True, # Отключить вывод в консоль + "no_warnings": True, # Скрыть предупреждения + } + + # Добавляем cookies если указаны + if cookie and os.path.exists(cookie): + ydl_opts["cookiefile"] = cookie + + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + try: + info = ydl.extract_info(url, download=True) + if "_type" in info and info["_type"] == "playlist": + # Для плейлистов берем первое видео + entry = info["entries"][0] + filename = ydl.prepare_filename(entry) + else: + # Для одиночного видео + filename = ydl.prepare_filename(info) + + # Заменяем оригинальное расширение на выбранный формат + base, _ = os.path.splitext(filename) + audio_file = base + f".{output_format}" + + return os.path.join(DOWNLOAD_DIR, audio_file) + except Exception as e: + return None \ No newline at end of file diff --git a/mvsepless/ensemble.py b/mvsepless/ensemble.py new file mode 100644 index 0000000000000000000000000000000000000000..c9545704eec617954188d2d74bb88eb30d9a8776 --- /dev/null +++ b/mvsepless/ensemble.py @@ -0,0 +1,224 @@ +# coding: utf-8 +__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/" + +import os +import sys +import librosa +import tempfile +import numpy as np +import argparse +if not __package__: + from audio import Audio + from namer import Namer +else: + from .audio import Audio + from .namer import Namer + +audio = Audio() + +def stft(wave, nfft, hl): + wave_left = np.asfortranarray(wave[0]) + wave_right = np.asfortranarray(wave[1]) + spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl) + spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl) + spec = np.asfortranarray([spec_left, spec_right]) + return spec + + +def istft(spec, hl, length): + spec_left = np.asfortranarray(spec[0]) + spec_right = np.asfortranarray(spec[1]) + wave_left = librosa.istft(spec_left, hop_length=hl, length=length) + wave_right = librosa.istft(spec_right, hop_length=hl, length=length) + wave = np.asfortranarray([wave_left, wave_right]) + return wave + + +def absmax(a, *, axis): + dims = list(a.shape) + dims.pop(axis) + indices = np.ogrid[tuple(slice(0, d) for d in dims)] + argmax = np.abs(a).argmax(axis=axis) + # Convert indices to list before insertion + indices = list(indices) + indices.insert(axis % len(a.shape), argmax) + return a[tuple(indices)] + + +def absmin(a, *, axis): + dims = list(a.shape) + dims.pop(axis) + indices = np.ogrid[tuple(slice(0, d) for d in dims)] + argmax = np.abs(a).argmin(axis=axis) + indices.insert((len(a.shape) + axis) % len(a.shape), argmax) + return a[tuple(indices)] + + +def lambda_max(arr, axis=None, key=None, keepdims=False): + idxs = np.argmax(key(arr), axis) + if axis is not None: + idxs = np.expand_dims(idxs, axis) + result = np.take_along_axis(arr, idxs, axis) + if not keepdims: + result = np.squeeze(result, axis=axis) + return result + else: + return arr.flatten()[idxs] + + +def lambda_min(arr, axis=None, key=None, keepdims=False): + idxs = np.argmin(key(arr), axis) + if axis is not None: + idxs = np.expand_dims(idxs, axis) + result = np.take_along_axis(arr, idxs, axis) + if not keepdims: + result = np.squeeze(result, axis=axis) + return result + else: + return arr.flatten()[idxs] + + +def average_waveforms(pred_track, weights, algorithm): + """ + :param pred_track: shape = (num, channels, length) + :param weights: shape = (num, ) + :param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft + :return: averaged waveform in shape (channels, length) + """ + + pred_track = np.array(pred_track) + final_length = pred_track.shape[-1] + + mod_track = [] + for i in range(pred_track.shape[0]): + if algorithm == "avg_wave": + mod_track.append(pred_track[i] * weights[i]) + elif algorithm in ["median_wave", "min_wave", "max_wave"]: + mod_track.append(pred_track[i]) + elif algorithm in ["avg_fft", "min_fft", "max_fft", "median_fft"]: + spec = stft(pred_track[i], nfft=2048, hl=1024) + if algorithm in ["avg_fft"]: + mod_track.append(spec * weights[i]) + else: + mod_track.append(spec) + pred_track = np.array(mod_track) + + if algorithm in ["avg_wave"]: + pred_track = pred_track.sum(axis=0) + pred_track /= np.array(weights).sum().T + elif algorithm in ["median_wave"]: + pred_track = np.median(pred_track, axis=0) + elif algorithm in ["min_wave"]: + pred_track = np.array(pred_track) + pred_track = lambda_min(pred_track, axis=0, key=np.abs) + elif algorithm in ["max_wave"]: + pred_track = np.array(pred_track) + pred_track = lambda_max(pred_track, axis=0, key=np.abs) + elif algorithm in ["avg_fft"]: + pred_track = pred_track.sum(axis=0) + pred_track /= np.array(weights).sum() + pred_track = istft(pred_track, 1024, final_length) + elif algorithm in ["min_fft"]: + pred_track = np.array(pred_track) + pred_track = lambda_min(pred_track, axis=0, key=np.abs) + pred_track = istft(pred_track, 1024, final_length) + elif algorithm in ["max_fft"]: + pred_track = np.array(pred_track) + pred_track = absmax(pred_track, axis=0) + pred_track = istft(pred_track, 1024, final_length) + elif algorithm in ["median_fft"]: + pred_track = np.array(pred_track) + pred_track = np.median(pred_track, axis=0) + pred_track = istft(pred_track, 1024, final_length) + return pred_track + + +def ensemble_audio_files( + files, output="res.wav", ensemble_type="avg_wave", weights=None, out_format="wav", add_wav=False +) -> str | tuple[str, str]: + """ + Основная функция для объединения аудиофайлов + + :param files: список путей к аудиофайлам + :param output: путь для сохранения результата + :param ensemble_type: метод объединения (avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft) + :param weights: список весов для каждого файла (None для равных весов) + :return: None + """ + print("Алгоритм склеивания: {}".format(ensemble_type)) + print("Количество входных файлов: {}".format(len(files))) + if weights is not None: + weights = np.array(weights) + else: + weights = np.ones(len(files)) + print("Весы: {}".format(weights)) + print("Имя выходного файла: {}".format(output)) + + data = [] + sr = None + max_length = 0 + max_channels = 0 + + # Первый проход: определяем максимальную длину и количество каналов + for f in files: + if not os.path.isfile(f): + print("Не удается найти файл: {}. Check paths.".format(f)) + exit() + print("Читается файл: {}".format(f)) + wav, current_sr, _ = audio.read(i=f, sr=None, mono=False) + if sr is None: + sr = current_sr + elif sr != current_sr: + print("Частота дискретизации на всех файлах должна быть одинаковой") + exit() + + # Определяем количество каналов + if wav.ndim == 1: + channels = 1 + length = len(wav) + else: + channels = wav.shape[0] + length = wav.shape[1] + + max_length = max(max_length, length) + max_channels = max(max_channels, channels) + print("Форма сигнала: {} частота дискретизации: {}".format(wav.shape, sr)) + + # Второй проход: обработка и выравнивание файлов + for f in files: + wav, current_sr, _ = audio.read(i=f, sr=None, mono=False) + + # Обработка каналов + if wav.ndim == 1: + # Моно -> стерео + wav = np.vstack([wav, wav]) + elif wav.shape[0] == 1: + # Один канал -> стерео + wav = np.vstack([wav[0], wav[0]]) + elif wav.shape[0] > 2: + # Более 2 каналов -> берем первые два + wav = wav[:2, :] + + # Выравнивание длины + if wav.shape[1] < max_length: + pad_width = ((0, 0), (0, max_length - wav.shape[1])) + wav = np.pad(wav, pad_width, mode="constant") + elif wav.shape[1] > max_length: + wav = wav[:, :max_length] + + data.append(wav) + + data = np.array(data) + res = average_waveforms(data, weights, ensemble_type) + print("Форма результата: {}".format(res.shape)) + + output_wav = f"{output}_orig.wav" + output = f"{output}.{out_format}" + + output = audio.write(o=output, array=res.T, sr=sr, of=out_format, br="320k") + if add_wav: + output_wav = audio.write(o=output_wav, array=res.T, sr=sr, of="wav") + return output, output_wav + + else: + return output \ No newline at end of file diff --git a/mvsepless/infer.py b/mvsepless/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..be215b92020d0fb23238451583a0f08afda43139 --- /dev/null +++ b/mvsepless/infer.py @@ -0,0 +1,623 @@ +import os +import sys +import json +import argparse +import time +from datetime import datetime +import gc +import glob +import yaml +import torch +import numpy as np +import soundfile as sf +import torch.nn as nn + +from typing import Literal + +from audio import Audio +from namer import Namer + +namer = Namer() +audio = Audio() + +from infer_utils import ( + prefer_target_instrument, + demix, + get_model_from_config +) + + +def normalize_peak(audio, peak): + current_peak = np.max(np.abs(audio)) + if current_peak == 0: + return audio # избегаем деления на ноль + scale_factor = peak / current_peak + return audio * scale_factor + + +gc.enable() + + +def cleanup_model(model): + try: + if isinstance(model, torch.nn.DataParallel): + model = model.module + + model.to("cpu") + + for name, param in list(model.named_parameters()): + del param + for name, buf in list(model.named_buffers()): + del buf + + del model + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + gc.collect() + sys.stdout.write(json.dumps({"cleanup": "Модель выгружена из памяти"}, ensure_ascii=False) + '\n') + sys.stdout.flush() + except Exception as e: + sys.stdout.write(json.dumps({"error": f"Ошибка при выгрузке модели: {str(e)}"}, ensure_ascii=False) + '\n') + sys.stdout.flush() + + +def once_inference( + path: str = None, + model: any = None, + config: any = None, + device: any = None, + model_type: str = None, + extract_instrumental: bool = False, + detailed_pbar: bool = False, + output_format: Literal[ + "mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "aiff" + ] = "mp3", + output_bitrate: str = "320k", + use_tta: bool = False, + verbose: bool = False, + model_name: str = None, + sample_rate: int = 44100, + instruments: list = [], + store_dir: str = None, + template: str = None, + selected_instruments: list = [], + model_id: int = 0, +): + results = [] + sys.stdout.write(json.dumps({"reading": path}, ensure_ascii=False) + '\n') + sys.stdout.flush() + sys.stdout.write(json.dumps({"selected_stems": selected_instruments}, ensure_ascii=False) + '\n') + sys.stdout.flush() + sys.stdout.write(json.dumps({"stems": list(instruments)}, ensure_ascii=False) + '\n') + sys.stdout.flush() + + if config.training.target_instrument is not None: + sys.stdout.write(json.dumps({"target_instrument": config.training.target_instrument}, ensure_ascii=False) + '\n') + sys.stdout.flush() + + try: + mix, sr, _ = audio.read(i=path, sr=sample_rate, mono=False) + except Exception as e: + error_msg = f"Не удалось прочитать аудио: {path}\nОшибка: {e}" + sys.stdout.write(json.dumps({"error": error_msg}, ensure_ascii=False) + '\n') + sys.stdout.flush() + return results + + mix_orig = mix.copy() + + mean = std = None + if config.inference.get("normalize", False): + mono = mix.mean(0) + mean = mono.mean() + std = mono.std() + mix = (mix - mean) / std + + if use_tta: + track_proc_list = [mix.copy(), mix[::-1].copy(), -1.0 * mix.copy()] + else: + track_proc_list = [mix.copy()] + full_result = [] + for m in track_proc_list: + try: + waveforms = demix( + config, model, m, device, pbar=detailed_pbar, model_type=model_type + ) + + full_result.append(waveforms) + except Exception as e: + sys.stdout.write(json.dumps({"error": f"Ошибка при демиксе: {e}"}, ensure_ascii=False) + '\n') + sys.stdout.flush() + del m + gc.collect() + + if not full_result: + sys.stdout.write(json.dumps({"error": "Пустой результат демикса."}, ensure_ascii=False) + '\n') + sys.stdout.flush() + return results + + waveforms = full_result[0] + for i in range(1, len(full_result)): + d = full_result[i] + for el in d: + if i == 2: + waveforms[el] += -1.0 * d[el] + elif i == 1: + waveforms[el] += d[el][::-1].copy() + else: + waveforms[el] += d[el] + for el in waveforms: + waveforms[el] /= len(full_result) + + if ( + extract_instrumental and config.training.target_instrument is not None + ): # Если включен "Extract Instrumental / Извлечь инструментал" и найден целевой инструмент + second_stem = [ + s + for s in config.training.instruments + if s != config.training.target_instrument + ] + if second_stem: + second_stem_key = second_stem[0] + if second_stem_key not in instruments: + instruments.append(second_stem_key) + waveforms[second_stem_key] = mix_orig - waveforms[instruments[0]] + + elif ( + extract_instrumental + and selected_instruments + and config.training.target_instrument is None + ): # Если включен "Extract Instrumental / Извлечь инструментал" и выбраны инструменты, то создаются стемы "inverted -" и "inverted +" (если не найден целевого инструмент) + + + all_instruments = config.training.instruments + if len(all_instruments) > 2: + + waveforms["inverted -"] = mix_orig.copy() + for instr in instruments: + if instr in waveforms: + waveforms["inverted -"] -= waveforms[ + instr + ] # стем "inverted -": вычитание выбранного стема из оригинального сигнала (не всегда хорошо) + + if "inverted -" not in instruments: + instruments.append("inverted -") + + unselected_stems = [s for s in all_instruments if s not in selected_instruments] + if unselected_stems: + waveforms["inverted +"] = np.zeros_like(mix_orig) + for stem in unselected_stems: + if stem in waveforms: + waveforms["inverted +"] += waveforms[ + stem + ] # стем "inverted +": сложение не выбранных инструментов в один стем + if "inverted +" not in instruments: + instruments.append("inverted +") + + peak = np.max(np.abs(waveforms["inverted -"])) + waveforms["inverted +"] = normalize_peak(waveforms["inverted +"], peak) + + elif ( + extract_instrumental + and not selected_instruments + and config.training.target_instrument is None + and ( + all( + instr in config.training.instruments + for instr in ["bass", "drums", "other", "vocals"] + ) + or all( + instr in config.training.instruments + for instr in ["bass", "drums", "other", "vocals", "piano", "guitar"] + ) + ) + ): + + waveforms["instrumental -"] = mix_orig.copy() + waveforms["instrumental -"] -= waveforms[ + "vocals" + ] # стем "inverted -": вычитание выбранного стема из оригинального сигнала (не всегда хорошо) + + if "instrumental -" not in instruments: + instruments.append("instrumental -") + + all_instruments = config.training.instruments + non_vocal_stems = [s for s in all_instruments if s not in ["vocals"]] + if non_vocal_stems: + waveforms["instrumental +"] = np.zeros_like(mix_orig) + for stem in non_vocal_stems: + if stem in waveforms: + waveforms["instrumental +"] += waveforms[ + stem + ] # стем "inverted +": сложение не выбранных инструментов в один стем + if "instrumental +" not in instruments: + instruments.append("instrumental +") + + peak = np.max(np.abs(waveforms["instrumental -"])) + waveforms["instrumental +"] = normalize_peak(waveforms["instrumental +"], peak) + + template = namer.sanitize(template) + template = namer.dedup_template(template, keys=["NAME", "MODEL", "STEM", "ID"]) + template = namer.short(template, length=40) + + for instr in instruments: + try: + estimates = waveforms[instr].T + if mean is not None and std is not None: + estimates = estimates * std + mean + + file_name = os.path.splitext(os.path.basename(path))[0] + file_name_shorted = namer.short_input_name_template(template, STEM=instr, MODEL=model_name, ID=model_id, NAME=file_name) + custom_name = namer.template( + template, STEM=instr, MODEL=model_name, ID=model_id, NAME=file_name_shorted + ) + output_path = os.path.join(store_dir, f"{custom_name}.{output_format}") + + sys.stdout.write(json.dumps({"writing": output_path}, ensure_ascii=False) + '\n') + sys.stdout.flush() + + output_path = audio.write( + o=output_path, array=estimates, sr=sr, of=output_format, br=output_bitrate + ) # запись стема в аудио файл с помощью универсальной функции + + results.append( + (instr, output_path) + ) # запись информации о разделении: (название стема, путь к файлу) + del estimates + except Exception as e: + sys.stdout.write(json.dumps({"error": f"Ошибка при обработке {instr}: {e}"}, ensure_ascii=False) + '\n') + sys.stdout.flush() + gc.collect() + + del mix, mix_orig, waveforms, full_result + gc.collect() + + return results + + +def run_inference( + model: any = None, + config: any = None, + input_path: str = None, + store_dir: str = None, + device: any = None, + model_type: str = None, + extract_instrumental: bool = False, + disable_detailed_pbar: bool = False, + output_format: Literal[ + "mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "aiff" + ] = "mp3", + output_bitrate: str = "320k", + use_tta: bool = False, + verbose: bool = False, + model_name: str = None, + template: str = "NAME_STEM", + selected_instruments: list = [], + model_id: int = 0, +): + start_time = time.time() + if model_type != "vr": + model.eval() + sample_rate = 44100 + if "sample_rate" in config.audio: + sample_rate = config.audio["sample_rate"] + + instruments = prefer_target_instrument(config) + + if config.training.target_instrument is not None: + sys.stdout.write(json.dumps({"info": "Целевой инструмент найден в конфигурации модели. Выбранные стемы будут проигнорированы."}, ensure_ascii=False) + '\n') + sys.stdout.flush() + else: + if selected_instruments is not None and selected_instruments != []: + instruments = [ + instr for instr in instruments if instr in selected_instruments + ] + if verbose: + sys.stdout.write(json.dumps({"selected_stems": instruments}, ensure_ascii=False) + '\n') + sys.stdout.flush() + + os.makedirs(store_dir, exist_ok=True) + + detailed_pbar = not disable_detailed_pbar + + results = once_inference( + path=input_path, + model=model, + config=config, + device=device, + model_type=model_type, + extract_instrumental=extract_instrumental, + detailed_pbar=detailed_pbar, + output_format=output_format, + output_bitrate=output_bitrate, + use_tta=use_tta, + verbose=verbose, + model_name=model_name, + sample_rate=sample_rate, + instruments=instruments, + store_dir=store_dir, + template=template, + selected_instruments=selected_instruments, + model_id=model_id, + ) + + time.sleep(1) + time_taken = time.time() - start_time + sys.stdout.write(json.dumps({"time": f"{time_taken:.2f} сек."}, ensure_ascii=False) + '\n') + sys.stdout.flush() + sys.stdout.write(json.dumps({"done": results}, ensure_ascii=False) + '\n') + sys.stdout.flush() + return results + + +def load_model(model_type, config_path, start_check_point, device_ids, force_cpu=False): + device = "cpu" + if force_cpu: + device = "cpu" + elif torch.cuda.is_available(): + sys.stdout.write(json.dumps({"info": "Разделение выполняется на ядрах CUDA. Для выполнения на процессоре установите force_cpu=True."}, ensure_ascii=False) + '\n') + sys.stdout.flush() + device = "cuda" + + if device_ids is None: + device = "cuda:0" + elif isinstance(device_ids, (list, tuple)): + device = f"cuda:{device_ids[0]}" if device_ids else "cuda:0" + elif isinstance(device_ids, bool): + device = "cuda:0" + else: + device = f"cuda:{int(device_ids)}" + elif torch.backends.mps.is_available(): + device = "mps" + + sys.stdout.write(json.dumps({"device": device}, ensure_ascii=False) + '\n') + sys.stdout.flush() + + model_load_start_time = time.time() + torch.backends.cudnn.benchmark = True + + model, config = get_model_from_config(model_type, config_path) + + if model_type == "vr": + model.load_checkpoint(start_check_point, device) + model.settings(enable_post_process=False, + post_process_threshold=config.inference.post_process_threshold, + batch_size=config.inference.batch_size, + window_size=config.inference.window_size, + high_end_process=config.inference.high_end_process, + primary_stem=config.training.instruments[0], + secondary_stem=config.training.instruments[1] + ) + return model, config, device + + elif model_type == "mdxnet": + if start_check_point != "": + sys.stdout.write(json.dumps({"checkpoint": start_check_point}) + '\n') + sys.stdout.flush() + model.init_onnx_session(start_check_point, device) + + return model, config, device + + else: + if start_check_point != "": + sys.stdout.write(json.dumps({"checkpoint": start_check_point}) + '\n') + sys.stdout.flush() + + if model_type in ["htdemucs", "apollo"]: + state_dict = torch.load( + start_check_point, map_location=device, weights_only=False + ) + if "state" in state_dict: + state_dict = state_dict["state"] + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + else: + try: + state_dict = torch.load( + start_check_point, map_location=device, weights_only=True + ) + except torch.serialization.pickle.UnpicklingError: + with torch.serialization.safe_globals([torch._C._nn.gelu]): + state_dict = torch.load( + start_check_point, map_location=device, weights_only=True + ) + try: + model.load_state_dict(state_dict) + except RuntimeError: + model.load_state_dict(state_dict, strict=False) + + sys.stdout.write(json.dumps({"stems": list(config.training.instruments)}, ensure_ascii=False) + '\n') + sys.stdout.flush() + + if ( + isinstance(device_ids, (list, tuple)) + and len(device_ids) > 1 + and not force_cpu + and torch.cuda.is_available() + ): + model = nn.DataParallel(model, device_ids=[int(d) for d in device_ids]) + + model = model.to(device) + + load_time = time.time() - model_load_start_time + + sys.stdout.write(json.dumps({"model_load_time": f"{load_time:.2f} сек."}, ensure_ascii=False) + '\n') + sys.stdout.flush() + + return model, config, device + + +def mvsep_offline( + input_path, + store_dir, + model_type, + config_path, + start_check_point, + extract_instrumental, + output_format, + output_bitrate, + model_name, + template, + device_ids=None, + disable_detailed_pbar=False, + use_tta=False, + force_cpu=False, + verbose=False, + selected_instruments=None, + model_id=0, +): + model, config, device = load_model( + model_type, config_path, start_check_point, device_ids, force_cpu + ) + + results = run_inference( + model=model, + config=config, + input_path=input_path, + store_dir=store_dir, + device=device, + model_type=model_type, + extract_instrumental=extract_instrumental, + disable_detailed_pbar=disable_detailed_pbar, + output_format=output_format, + output_bitrate=output_bitrate, + use_tta=use_tta, + verbose=verbose, + model_name=model_name, + template=template, + selected_instruments=selected_instruments, + model_id=model_id, + ) + + if model_type != "vr": + cleanup_model(model) + del config + gc.collect() + return results + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Модифицированный Music-Source-Separation-Training для разделения аудио на источники" + ) + + # Обязательные аргументы + parser.add_argument("--input", type=str, help="Путь к входному файлу или папке") + parser.add_argument( + "--store_dir", type=str, required=True, help="Путь для сохранения результатов" + ) + + # Основные параметры модели + parser.add_argument( + "--model_type", + type=str, + default="htdemucs", + choices=[ + "mel_band_roformer", + "bs_roformer", + "mdx23c", + "scnet", + "htdemucs", + "bandit", + "bandit_v2", + "mdxnet", + "vr" + ], + help="Тип модели (по умолчанию: htdemucs)", + ) + parser.add_argument( + "--config_path", + type=str, + required=True, + help="Путь к конфигурационному файлу модели", + ) + parser.add_argument( + "--start_check_point", type=str, required=True, help="Путь к чекпоинту модели" + ) + + # Параметры вывода + parser.add_argument( + "--output_format", + type=str, + default="wav", + choices=audio.output_formats, + help="Формат выходных файлов", + ) + parser.add_argument( + "--output_bitrate", type=str, required=True, help="Битрейт выходного файла" + ) + + parser.add_argument( + "--selected_instruments", + nargs="+", + help="Список стемов для сохранения (например: vocals drums)", + ) + parser.add_argument( + "--extract_instrumental", + action="store_true", + help="Извлечь инструментальную версию", + ) + parser.add_argument( + "--template", + type=str, + default="NAME_STEM", + help="Шаблон для имен выходных файлов", + ) + parser.add_argument( + "--model_name", + type=str, + default="model", + help="Имя модели для шаблона имен файлов", + ) + parser.add_argument("-m_id", "--model_id", type=int, required=True, help="Model ID") + parser.add_argument( + "--device_ids", nargs="+", help="ID GPU устройств для использования" + ) + parser.add_argument( + "--force_cpu", action="store_true", help="Принудительно использовать CPU" + ) + parser.add_argument( + "--use_tta", action="store_true", help="Использовать тестовую аугментацию" + ) + parser.add_argument( + "--disable_detailed_pbar", + action="store_true", + help="Отключить детальный прогресс-бар", + ) + parser.add_argument("--verbose", action="store_true", help="Подробный вывод") + + return parser.parse_args() + + +def main(): + args = parse_args() + + device_ids = None + if args.device_ids: + device_ids = [int(x) for x in args.device_ids] + + results = mvsep_offline( + input_path=args.input, + store_dir=args.store_dir, + model_type=args.model_type, + config_path=args.config_path, + start_check_point=args.start_check_point, + extract_instrumental=args.extract_instrumental, + output_format=args.output_format, + output_bitrate=args.output_bitrate, + model_name=args.model_name, + template=args.template, + device_ids=device_ids, + disable_detailed_pbar=args.disable_detailed_pbar, + use_tta=args.use_tta, + force_cpu=args.force_cpu, + verbose=args.verbose, + selected_instruments=args.selected_instruments, + model_id=args.model_id, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/mvsepless/infer_utils.py b/mvsepless/infer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..57b2bb76b9765b809cd29595b6d8bb81b861c86f --- /dev/null +++ b/mvsepless/infer_utils.py @@ -0,0 +1,382 @@ +# coding: utf-8 +__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/" + +import sys +import json +import numpy as np +import torch +import torch.nn as nn +import yaml +import librosa +import torch.nn.functional as F +from ml_collections import ConfigDict +from omegaconf import OmegaConf +from typing import Dict, List, Tuple, Any, List, Optional + + +def load_config(model_type: str, config_path: str) -> Any: + """ + Load the configuration from the specified path based on the model type. + """ + try: + with open(config_path, "r") as f: + if model_type == "htdemucs": + config = OmegaConf.load(config_path) + else: + config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader)) + return config + except FileNotFoundError: + raise FileNotFoundError(f"Configuration file not found at {config_path}") + except Exception as e: + raise ValueError(f"Error loading configuration: {e}") + + +def get_model_from_config(model_type: str, config_path: str) -> Tuple: + """ + Load the model specified by the model type and configuration file. + """ + config = load_config(model_type, config_path) + + if model_type == "mdx23c": + from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net + + model = TFC_TDF_net(config) + + elif model_type == "mdxnet": + from models.mdx_net import MDXNet + + model = MDXNet(**dict(config.model)) + + # В функции get_model_from_config добавьте: + + elif model_type == "vr": + from models.vr_arch import VRNet + # Передаем instruments из config.training в модель + model = VRNet(**dict(config.model)) + + elif model_type == "htdemucs": + from models.demucs4ht import get_model + + model = get_model(config) + + elif model_type == "mel_band_roformer": + if hasattr(config, "windowed"): # Это не нарушает совместимость со обычными моделями на Mel-Band Roformer + from models.windowed_roformer.model import MelBandRoformerWSA + + model = MelBandRoformerWSA(**dict(config.model)) + + else: + from models.bs_roformer import MelBandRoformer + + model = MelBandRoformer(**dict(config.model)) + + elif model_type == "bs_roformer": + if hasattr(config.model, "use_shared_bias"): + from models.bs_roformer import BSRoformer_SW + + model = BSRoformer_SW(**dict(config.model)) + elif hasattr(config.model, "fno"): + from models.bs_roformer import BSRoformer_FNO + + model = BSRoformer_FNO(**dict(config.model)) + else: + from models.bs_roformer import BSRoformer + + model = BSRoformer(**dict(config.model)) + + elif model_type == "bandit": + from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple + + model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model) + + elif model_type == "bandit_v2": + from models.bandit_v2.bandit import Bandit + + model = Bandit(**config.kwargs) + elif model_type == "scnet_unofficial": + from models.scnet_unofficial import SCNet + + model = SCNet(**config.model) + elif model_type == "scnet": + from models.scnet import SCNet + + model = SCNet(**config.model) + else: + raise ValueError(f"Unknown model type: {model_type}") + + return model, config + +def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor: + """ + Generate a windowing array with a linear fade-in at the beginning and a fade-out at the end. + """ + fadein = torch.linspace(0, 1, fade_size) + fadeout = torch.linspace(1, 0, fade_size) + + window = torch.ones(window_size) + window[-fade_size:] = fadeout + window[:fade_size] = fadein + return window + +def demix_mdxnet( + config: Any, + model: Any, + mix: np.ndarray, + device: torch.device, + pbar: bool = False, +) -> Dict[str, np.ndarray]: + """ + MDX-Net specific demixing function с поддержкой overlap + """ + mix_tensor = torch.tensor(mix, dtype=torch.float32) + inv_mix_tensor = torch.tensor(-mix, dtype=torch.float32) + + num_overlap = config.inference.num_overlap + denoise = config.inference.denoise + stem_name = model.primary_stem + if denoise: + processed_wav = model.process_wave(mix_tensor, device, num_overlap, pbar=pbar) + inv_processed_wav = model.process_wave(inv_mix_tensor, device, num_overlap, pbar=pbar) + result = processed_wav.cpu().numpy() + inv_result = inv_processed_wav.cpu().numpy() + result_separation = (result + -inv_result) * 0.5 + else: + processed_wav = model.process_wave(mix_tensor, device, num_overlap, pbar=pbar) + result_separation = processed_wav.cpu().numpy() + + result_separation = np.nan_to_num(result_separation, nan=0.0, posinf=0.0, neginf=0.0) + + return {stem_name: result_separation} # Перемещаем на CPU для возврата + +def demix_vr( + config: Any, + model: Any, + mix: np.ndarray, + device: torch.device, + pbar: bool = False, +) -> Dict[str, np.ndarray]: + """ + VR-specific demixing function that processes the entire audio at once + since VR architecture doesn't support chunk-based processing + """ + # Convert to tensor and add batch dimension + return model.demix(mix, config.audio.sample_rate, device, config.inference.aggression) + +def demix_demucs(config, model, mix, device, pbar=False): + mix = torch.tensor(mix, dtype=torch.float32) + chunk_size = config.training.samplerate * config.training.segment + num_instruments = len(config.training.instruments) + num_overlap = config.inference.num_overlap + step = chunk_size // num_overlap + fade_size = chunk_size // 10 # Добавляем fade_size для оконной функции + windowing_array = _getWindowingArray(chunk_size, fade_size) # Создаём окно + + batch_size = config.inference.batch_size + use_amp = getattr(config.training, "use_amp", True) + + with torch.cuda.amp.autocast(enabled=use_amp): + with torch.inference_mode(): + req_shape = (num_instruments,) + mix.shape + result = torch.zeros(req_shape, dtype=torch.float32) + counter = torch.zeros(req_shape, dtype=torch.float32) + + i = 0 + batch_data = [] + batch_locations = [] + + while i < mix.shape[1]: + part = mix[:, i : i + chunk_size].to(device) + chunk_len = part.shape[-1] + pad_mode = "reflect" if chunk_len > chunk_size // 2 else "constant" + part = nn.functional.pad( + part, (0, chunk_size - chunk_len), mode=pad_mode, value=0 + ) + + batch_data.append(part) + batch_locations.append((i, chunk_len)) + i += step + + if len(batch_data) >= batch_size or i >= mix.shape[1]: + arr = torch.stack(batch_data, dim=0) + x = model(arr) + + window = windowing_array.clone() + if i - step == 0: # Первый чанк, без fade-in + window[:fade_size] = 1 + elif i >= mix.shape[1]: # Последний чанк, без fade-out + window[-fade_size:] = 1 + + for j, (start, seg_len) in enumerate(batch_locations): + result[..., start : start + seg_len] += ( + x[j, ..., :seg_len].cpu() * window[..., :seg_len] + ) + counter[..., start : start + seg_len] += window[..., :seg_len] + + # Output progress + processed = min(i, mix.shape[1]) + total = mix.shape[1] + sys.stdout.write(json.dumps({"processing": {"processed": processed, "total": total}}) + '\n') + sys.stdout.flush() + + batch_data.clear() + batch_locations.clear() + + estimated_sources = result / counter + estimated_sources = estimated_sources.cpu().numpy() + np.nan_to_num(estimated_sources, copy=False, nan=0.0) + + if num_instruments <= 1: + return estimated_sources + else: + instruments = config.training.instruments + return {k: v for k, v in zip(instruments, estimated_sources)} + +def demix_generic( + config: ConfigDict, + model: torch.nn.Module, + mix: torch.Tensor, + device: torch.device, + pbar: bool = False, +) -> Dict[str, np.ndarray]: + """ + Generic demixing function for models that support chunk-based processing + """ + mix = torch.tensor(mix, dtype=torch.float32) + chunk_size = config.audio.chunk_size + instruments = prefer_target_instrument(config) + num_instruments = len(instruments) + num_overlap = config.inference.num_overlap + + fade_size = chunk_size // 10 + step = chunk_size // num_overlap + border = chunk_size - step + length_init = mix.shape[-1] + windowing_array = _getWindowingArray(chunk_size, fade_size) + + # Add padding to handle edge artifacts + if length_init > 2 * border and border > 0: + mix = nn.functional.pad(mix, (border, border), mode="reflect") + + batch_size = config.inference.batch_size + use_amp = getattr(config.training, "use_amp", True) + + with torch.cuda.amp.autocast(enabled=use_amp): + with torch.inference_mode(): + # Initialize result and counter tensors + req_shape = (num_instruments,) + mix.shape + result = torch.zeros(req_shape, dtype=torch.float32) + counter = torch.zeros(req_shape, dtype=torch.float32) + + i = 0 + batch_data = [] + batch_locations = [] + + while i < mix.shape[1]: + # Extract chunk and apply padding if necessary + part = mix[:, i : i + chunk_size].to(device) + chunk_len = part.shape[-1] + + pad_mode = "reflect" if chunk_len > chunk_size // 2 else "constant" + part = nn.functional.pad( + part, (0, chunk_size - chunk_len), mode=pad_mode, value=0 + ) + + batch_data.append(part) + batch_locations.append((i, chunk_len)) + i += step + + # Process batch if it's full or the end is reached + if len(batch_data) >= batch_size or i >= mix.shape[1]: + arr = torch.stack(batch_data, dim=0) + x = model(arr) + + window = windowing_array.clone() + if i - step == 0: # First audio chunk, no fadein + window[:fade_size] = 1 + elif i >= mix.shape[1]: # Last audio chunk, no fadeout + window[-fade_size:] = 1 + + for j, (start, seg_len) in enumerate(batch_locations): + result[..., start : start + seg_len] += ( + x[j, ..., :seg_len].cpu() * window[..., :seg_len] + ) + counter[..., start : start + seg_len] += window[..., :seg_len] + + # Output progress + processed = min(i, mix.shape[1]) + total = mix.shape[1] + sys.stdout.write(json.dumps({"processing": {"processed": processed, "total": total}}, ensure_ascii=False) + '\n') + sys.stdout.flush() + + batch_data.clear() + batch_locations.clear() + + # Compute final estimated sources + estimated_sources = result / counter + estimated_sources = estimated_sources.cpu().numpy() + np.nan_to_num(estimated_sources, copy=False, nan=0.0) + + # Remove padding + if length_init > 2 * border and border > 0: + estimated_sources = estimated_sources[..., border:-border] + + # Return the result as a dictionary + return {k: v for k, v in zip(instruments, estimated_sources)} + +def demix( + config: ConfigDict, + model: torch.nn.Module, + mix: np.ndarray, + device: torch.device, + model_type: str, + pbar: bool = False, +) -> Dict[str, np.ndarray]: + """ + Unified function for audio source separation with support for multiple processing modes. + """ + # Handle different model types + if model_type == "vr": + return demix_vr(config, model, mix, device, pbar) + elif model_type == "mdxnet": + return demix_mdxnet(config, model, mix, device, pbar) + elif model_type == "htdemucs": + # HTDemucs uses its own processing + return demix_demucs(config, model, mix, device, pbar) + else: + # Generic processing for other models + return demix_generic(config, model, mix, device, pbar) + + +def prefer_target_instrument(config: ConfigDict) -> List[str]: + """ + Return the list of target instruments based on the configuration. + If a specific target instrument is specified in the configuration, + it returns a list with that instrument. Otherwise, it returns the list of instruments. + """ + if config.training.get("target_instrument"): + return [config.training.target_instrument] + else: + return config.training.instruments + + +def prefer_target_instrument_test( + config: ConfigDict, selected_instruments: Optional[List[str]] = None +) -> List[str]: + """ + Return the list of target instruments based on the configuration and selected instruments. + If selected_instruments is specified, returns the intersection with available instruments. + Otherwise, if a target instrument is specified, returns it, else returns all instruments. + """ + available_instruments = config.training.instruments + + if selected_instruments is not None: + # Return only selected instruments that are available + return [ + instr for instr in selected_instruments if instr in available_instruments + ] + elif config.training.get("target_instrument"): + # Default behavior if no selection - return target instrument + return [config.training.target_instrument] + else: + # If no target and no selection, return all instruments + return available_instruments \ No newline at end of file diff --git a/mvsepless/model_manager.py b/mvsepless/model_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..0492e403eee6eedb9e35db7a7ae1d5d4a35b9b5f --- /dev/null +++ b/mvsepless/model_manager.py @@ -0,0 +1,540 @@ +import os +import sys +import json +import yaml +from tabulate import tabulate +import shutil +from tqdm import tqdm +import urllib.request +import gdown +import requests +import zipfile +import tempfile +import secrets +import string +import argparse +from typing import Dict, Any +script_dir = os.path.dirname(os.path.abspath(__file__)) +if not __package__: + from downloader import dw_file +else: + from .downloader import dw_file + +def generate_secure_random(length=10): + """Генерирует криптографически безопасную случайную строку""" + characters = string.ascii_letters + string.digits + return ''.join(secrets.choice(characters) for _ in range(length)) + +class MvseplessModelManager: + def __init__( + self, + models_info_path=os.path.join(script_dir, "models.json"), + cache_dir=os.path.join(script_dir, "mvsepless_models_cache"), + ): + self.models_cache_dir = cache_dir + self.models_info_path = models_info_path + with open(self.models_info_path, "r", encoding="utf-8") as f: + models_info = json.load(f) + self.models_info = models_info + + def get_mt(self): + return list(self.models_info.keys()) + + def get_mn(self, model_type): + try: + mt = self.models_info.get(model_type, None) + if mt: + return list(self.models_info[model_type].keys()) + return [] + except (KeyError, TypeError): + return [] + + def get_stems(self, model_type, model_name): + try: + mt = self.models_info.get(model_type, None) + if mt: + mn = self.models_info[model_type].get(model_name, None) + if mn and "stems" in self.models_info[model_type][model_name]: + return self.models_info[model_type][model_name]["stems"] + return [] + except (KeyError, TypeError): + return [] + + def get_id(self, model_type, model_name): + try: + mt = self.models_info.get(model_type, None) + if mt: + mn = self.models_info[model_type].get(model_name, None) + if mn and "id" in self.models_info[model_type][model_name]: + return self.models_info[model_type][model_name]["id"] + return 0 + except (KeyError, TypeError): + return 0 + + def get_tgt_inst(self, model_type, model_name): + try: + mt = self.models_info.get(model_type, None) + if mt: + mn = self.models_info[model_type].get(model_name, None) + if mn and "target_instrument" in self.models_info[model_type][model_name]: + return self.models_info[model_type][model_name]["target_instrument"] + return None + except (KeyError, TypeError): + return None + + def display_models_info(self, filter: str = None): + # Собираем данные для таблицы + table_data = [] + headers = [ + "Тип модели", + "ID", + "Имя модели", + "Стемы", + "Целевой инструмент", + ] + + for model_type, models in self.models_info.items(): + for model_name, model_info in models.items(): + try: + stems_list = model_info.get("stems", []) + id = model_info.get("id", "н/д") + # Применяем фильтр (регистронезависимо) + if filter: + filter_lower = filter.lower() + if not any(filter_lower == s.lower() for s in stems_list): + continue + + # Подготавливаем данные для строки таблицы + row = [ + model_type, + id, + model_name, + ", ".join(stems_list) or "н/д", + model_info.get("target_instrument", "н/д"), + ] + table_data.append(row) + except (KeyError, TypeError, AttributeError) as e: + print(f"Ошибка при обработке модели {model_type}/{model_name}: {e}") + continue + + # Выводим результат + if table_data: + print(tabulate(table_data, headers=headers, tablefmt="grid")) + else: + print("Нет моделей, которые содержат указанный стем") + + def download_model( + self, model_paths, model_name, model_type, ckpt_url, conf_url + ): + model_dir = os.path.join(model_paths, model_type) + os.makedirs(model_dir, exist_ok=True) + + config_path = None + checkpoint_path = None + + if model_type == "mel_band_roformer": + config_path = os.path.join(model_dir, f"{model_name}_config.yaml") + checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt") + + elif model_type == "vr": + config_path = os.path.join(model_dir, f"{model_name}.yaml") + checkpoint_path = os.path.join(model_dir, f"{model_name}.pth") + + elif model_type == "mdxnet": + config_path = os.path.join(model_dir, f"{model_name}.yaml") + checkpoint_path = os.path.join(model_dir, f"{model_name}.onnx") + + elif model_type == "bs_roformer": + config_path = os.path.join(model_dir, f"{model_name}_config.yaml") + checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt") + + elif model_type == "mdx23c": + config_path = os.path.join(model_dir, f"{model_name}_config.yaml") + checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt") + + elif model_type == "scnet": + config_path = os.path.join(model_dir, f"{model_name}_config.yaml") + checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt") + + elif model_type == "bandit": + config_path = os.path.join(model_dir, f"{model_name}_config.yaml") + checkpoint_path = os.path.join(model_dir, f"{model_name}.chpt") + + elif model_type == "bandit_v2": + config_path = os.path.join(model_dir, f"{model_name}_config.yaml") + checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt") + + elif model_type == "htdemucs": + config_path = os.path.join(model_dir, f"{model_name}_config.yaml") + checkpoint_path = os.path.join(model_dir, f"{model_name}.th") + + else: + raise ValueError( + f"{self.I18N_helper.t('error_unsupported_model_type')}: {model_type}" + ) + + # Проверяем, что пути заданы (на всякий случай) + if config_path is None or checkpoint_path is None: + raise RuntimeError() + + # Если файлы уже есть — пропускаем загрузку + if os.path.exists(checkpoint_path) and os.path.exists(config_path): + if os.path.getsize(checkpoint_path) == 0 or os.path.getsize(checkpoint_path) == 0: + for local_path, url_model in [ + (checkpoint_path, ckpt_url), + (config_path, conf_url), + ]: + if not os.path.exists(local_path): + + dw_file(url_model, local_path) + else: + pass + else: + for local_path, url_model in [ + (checkpoint_path, ckpt_url), + (config_path, conf_url), + ]: + if not os.path.exists(local_path): + + dw_file(url_model, local_path) + + return config_path, checkpoint_path + + def conf_editor(self, config_path, mdx_denoise, vr_aggr, model_type): + + class IndentDumper(yaml.Dumper): + def increase_indent(self, flow=False, indentless=False): + return super(IndentDumper, self).increase_indent(flow, False) + + def tuple_constructor(loader, node): + # Load the sequence of values from the YAML node + values = loader.construct_sequence(node) + # Return a tuple constructed from the sequence + return tuple(values) + + # Register the constructor with PyYAML + yaml.SafeLoader.add_constructor( + "tag:yaml.org,2002:python/tuple", tuple_constructor + ) + + def conf_edit(config_path, mdx_denoise, vr_aggr, model_type): + with open(config_path, "r") as f: + data = yaml.load(f, Loader=yaml.SafeLoader) + + # handle cases where 'use_amp' is missing from config: + if "use_amp" not in data.keys(): + data["training"]["use_amp"] = True + + if model_type != "vr": + if data["inference"]["num_overlap"] != 2: + data["inference"]["num_overlap"] = 2 + + if data["inference"]["batch_size"] != 1: + data["inference"]["batch_size"] = 1 + + if model_type == "mdxnet": + data["inference"]["denoise"] = mdx_denoise + + elif model_type == "vr": + data["inference"]["aggression"] = vr_aggr + + with open(config_path, "w") as f: + yaml.dump( + data, + f, + default_flow_style=False, + sort_keys=False, + Dumper=IndentDumper, + allow_unicode=True, + ) + + conf_edit(config_path, mdx_denoise, vr_aggr, model_type) + +class VbachModelManager: + def __init__(self): + self.rmvpe_path = os.path.join(script_dir, "predictors", "rmvpe.pt") + self.fcpe_path = os.path.join(script_dir, "predictors", "fcpe.pt") + self.hubert_path = os.path.join(script_dir, "embedders", "hubert_base.pt") + self.requirements = [["https://huggingface.co/Politrees/RVC_resources/resolve/main/predictors/rmvpe.pt", self.rmvpe_path], ["https://huggingface.co/Politrees/RVC_resources/resolve/main/predictors/fcpe.pt", self.fcpe_path], ["https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/hubert_base.pt", self.hubert_path]] + self.voicemodels_dir = os.path.join(script_dir, "vbach_models_cache") + os.makedirs(self.voicemodels_dir, exist_ok=True) + self.voicemodels_info = os.path.join(self.voicemodels_dir, "vbach_models.json") + self.voicemodels: Dict[str, Dict[str, str]] = {} + self.download_requirements() + self.check_and_load() + pass + + def write_voicemodels_info(self): + with open(self.voicemodels_info, "w") as f: + json.dump(self.voicemodels, f, indent=4) + + def load_voicemodels_info(self): + with open(self.voicemodels_info, "r") as f: + return json.load(f) + + def add_voice_model( + self, + name, + pth_path, + index_path, + ): + self.voicemodels[name] = {"pth": pth_path, "index": index_path} + self.write_voicemodels_info() + + def del_voice_model( + self, name + ): + if name in self.parse_voice_models(): + pth = self.voicemodels[name].get("pth", None) + index = self.voicemodels[name].get("index", None) + if index: + os.remove(index) + if pth: + os.remove(pth) + del self.voicemodels[name] + self.write_voicemodels_info() + return f"Модель {name} удалена" + else: + return f"Модель не была удалена, как так её не существует" + + def parse_voice_models(self): + list_models = list(self.voicemodels.keys()) + return list_models + + def parse_pth_and_index(self, name): + pth = self.voicemodels[name].get("pth", None) + index = self.voicemodels[name].get("index", None) + return pth, index + + def check_and_load(self): + if os.path.exists(self.voicemodels_info): + self.voicemodels = self.load_voicemodels_info() + else: + self.write_voicemodels_info() + + def clear_voicemodels_info(self): + self.voicemodels: Dict[str, Dict[str, str]] = {} + self.write_voicemodels_info() + + def download_file(self, url_model, local_path): + dir_name = os.path.dirname(local_path) + if dir_name != "": + os.makedirs(dir_name, exist_ok=True) + class TqdmUpTo(tqdm): + def update_to(self, b=1, bsize=1, tsize=None): + if tsize is not None: + self.total = tsize + self.update(b * bsize - self.n) + + with TqdmUpTo( + unit="B", + unit_scale=True, + unit_divisor=1024, + miniters=1, + desc=os.path.basename(local_path), + ) as t: + urllib.request.urlretrieve( + url_model, local_path, reporthook=t.update_to + ) + + def download_requirements(self): + for url, file in self.requirements: + if not os.path.exists(file): + self.download_file(url_model=url, local_path=file) + + def download_voice_model_file(self, url, zip_name): + try: + if "drive.google.com" in url: + self.download_from_google_drive(url, zip_name) + elif "pixeldrain.com" in url: + self.download_from_pixeldrain(url, zip_name) + elif "disk.yandex.ru" in url or "yadi.sk" in url: + self.download_from_yandex(url, zip_name) + else: + self.download_file(url, zip_name) + except Exception as e: + print(e) + + def download_from_google_drive(self, url, zip_name): + file_id = ( + url.split("file/d/")[1].split("/")[0] + if "file/d/" in url + else url.split("id=")[1].split("&")[0] + ) + gdown.download(id=file_id, output=str(zip_name), quiet=False) + + def download_from_pixeldrain(self, url, zip_name): + file_id = url.split("pixeldrain.com/u/")[1] + response = requests.get(f"https://pixeldrain.com/api/file/{file_id}") + with open(zip_name, "wb") as f: + f.write(response.content) + + def download_from_yandex(self, url, zip_name): + yandex_public_key = f"download?public_key={url}" + yandex_api_url = f"https://cloud-api.yandex.net/v1/disk/public/resources/{yandex_public_key}" + response = requests.get(yandex_api_url) + if response.status_code == 200: + download_link = response.json().get("href") + urllib.request.urlretrieve(download_link, zip_name) + else: + print(response.status_code) + + def extract_zip(self, zip_name, model_name): + model_dir = os.path.join(self.voicemodels_dir, f"{model_name}_{generate_secure_random(17)}") + os.makedirs(model_dir, exist_ok=True) + try: + with zipfile.ZipFile(zip_name, "r") as zip_ref: + zip_ref.extractall(model_dir) + os.remove(zip_name) + + added_voice_models = [] + + index_filepath, model_filepaths = None, [] + for root, _, files in os.walk(model_dir): + for name in files: + file_path = os.path.join(root, name) + if name.endswith(".index") and os.stat(file_path).st_size > 1024 * 100: + index_filepath = file_path + if name.endswith(".pth") and os.stat(file_path).st_size > 1024 * 1024 * 20: + model_filepaths.append(file_path) + + if len(model_filepaths) == 1: + self.add_voice_model(model_name, model_filepaths[0], index_filepath) + added_voice_models.append(model_name) + else: + for i, pth in enumerate(model_filepaths): + self.add_voice_model(f"{model_name}_{i + 1}", pth, index_filepath) + added_voice_models.append(f"{model_name}_{i + 1}") + list_models_str = '\n'.join(added_voice_models) + return f"Добавленные модели:\n{list_models_str}" + except Exception as e: + return f"Произошла ошибка при загрузке модели: {e}" + + def install_model_zip(self, zip, model_name, mode="url"): + if model_name in self.parse_voice_models(): + print("Эта модель уже есть в списке установленных моделей. Она будут перезаписана") + if mode == "url": + with tempfile.TemporaryDirectory(prefix="vbach_temp_model", ignore_cleanup_errors=True) as tmp: + zip_path = os.path.join(tmp, "model.zip") + self.download_voice_model_file(zip, zip_path) + status = self.extract_zip(zip_path, model_name) + if mode == "local": + status = self.extract_zip(zip, model_name) + return status + + def install_model_files(self, index, pth, model_name, mode="url"): + if model_name in self.parse_voice_models(): + print("Эта модель уже есть в списке установленных моделей. Она будут перезаписана") + model_dir = os.path.join(self.voicemodels_dir, f"{model_name}_{generate_secure_random(17)}") + os.makedirs(model_dir, exist_ok=True) + local_index_path = None + local_pth_path = None + try: + if mode == "url": + if index: + local_index_path = os.path.join(model_dir, "model.index") + self.download_voice_model_file(index, local_index_path) + if pth: + local_pth_path = os.path.join(model_dir, "model.pth") + self.download_voice_model_file(pth, local_pth_path) + + if mode == "local": + if index: + if os.path.exists(index): + local_index_path = os.path.join(model_dir, os.path.basename(index)) + shutil.copy(index, local_index_path) + if pth: + if os.path.exists(pth): + local_pth_path = os.path.join(model_dir, os.path.basename(pth)) + shutil.copy(pth, local_pth_path) + + self.add_voice_model(model_name, local_pth_path, local_index_path) + return f"Модель {model_name} добавлена" + except Exception as e: + return f"Произошла ошибка при загрузке модели: {e}" + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Менеджер моделей") + subparsers = parser.add_subparsers(title="subcommands", dest="command", required=True) + + # Mvsepless subcommand + mvsepless_parser = subparsers.add_parser("mvsepless", help="Скачивание моделей в MVSepLess") + mvsepless_parser.add_argument("--model_type", required=True, help="Тип модели") + mvsepless_parser.add_argument("--model_name", required=True, help="Имя модели") + + # Vbach subcommand + vbach_parser = subparsers.add_parser("vbach", help="Установка голосовых моделей в Vbach") + vbach_subparsers = vbach_parser.add_subparsers(title="vbach_commands", dest="vbach_command", required=True) + + # Vbach install_local + install_local_parser = vbach_subparsers.add_parser("install_local", help="Установка голосовой модели по локальным файлам") + install_local_parser.add_argument("--model_name", required=True, help="Имя голосовой модели") + install_local_parser.add_argument("--pth", required=True, help="Путь к *.pth файлу") + install_local_parser.add_argument("--index", required=False, help="Путь к *.index файлу") + + # Vbach install_url_zip + install_url_zip_parser = vbach_subparsers.add_parser("install_url_zip", help="Установка голосовой модели по URL (архив с файлами)") + install_url_zip_parser.add_argument("--model_name", required=True, help="Имя голосовой модели") + install_url_zip_parser.add_argument("--url", required=True, help="URL *.zip файла") + + # Vbach install_url_files + install_url_files_parser = vbach_subparsers.add_parser("install_url_files", help="Установка голосовой модели по URL (отдельные файлы)") + install_url_files_parser.add_argument("--model_name", required=True, help="Имя голосовой модели") + install_url_files_parser.add_argument("--pth_url", required=True, help="URL *.pth файла") + install_url_files_parser.add_argument("--index_url", required=False, help="URL *.index файла") + + # Vbach list + list_parser = vbach_subparsers.add_parser("list", help="List installed voice models") + + args = parser.parse_args() + + if args.command == "mvsepless": + + _model_manager = MvseplessModelManager() + info = _model_manager.models_info[args.model_type].get(args.model_name, None) + if not info: + raise ValueError(f"Модель {args.model_name} не найдена для типа {args.model_type}") + conf, ckpt = _model_manager.download_model( + _model_manager.models_cache_dir, + args.model_name, + args.model_type, + info["checkpoint_url"], + info["config_url"], + ) + + elif args.command == "vbach": + model_manager = VbachModelManager() + + if args.vbach_command == "install_local": + status = model_manager.install_model_files( + args.index, args.pth, args.model_name, mode="local" + ) + print(status) + + elif args.vbach_command == "install_url_zip": + status = model_manager.install_model_zip( + args.url, args.model_name, mode="url" + ) + print(status) + + elif args.vbach_command == "install_url_files": + status = model_manager.install_model_files( + args.index_url, args.pth_url, args.model_name, mode="url" + ) + print(status) + + elif args.vbach_command == "list": + models = model_manager.parse_voice_models() + if models: + print("Установленные модели:") + for model in models: + print(f" - {model}") + else: + print("Нет установленных моделей") + + + + + + diff --git a/mvsepless/models.json b/mvsepless/models.json new file mode 100644 index 0000000000000000000000000000000000000000..c28d8e7b9db1735ab51795fd67f74e8b91336c17 --- /dev/null +++ b/mvsepless/models.json @@ -0,0 +1,2859 @@ +{ + "mel_band_roformer": { + "mbr_vocals_kim": { + "category": "Вокал", + "id": 1000, + "full_name": "Mel-Band Roformer Vocals by Kimberley Jensen", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/KimberleyJSN/melbandroformer/resolve/main/MelBandRoformer.ckpt?download=true", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/KimberleyJensen/config_vocals_mel_band_roformer_kj.yaml" + }, + "mbr_wsa": { + "category": "Вокал", + "id": 1910, + "full_name": "Windowed Sink Attention Mel-Band Roformer Vocals by Smule Labs", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/smulelabs/windowed-roformer/resolve/main/mbr-win10-sink8.ckpt?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/config_windowed_roformer_by_smulelabs_wsa.yaml?download=true" + }, + "mbr_instvoc_duality1_unwa": { + "category": "Инструментал и вокал", + "id": 1010, + "full_name": "Mel-Band Roformer InstVoc Duality v1 by Unwa", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-InstVoc-Duality/resolve/main/melband_roformer_instvoc_duality_v1.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-InstVoc-Duality/resolve/main/config_melbandroformer_instvoc_duality.yaml?download=true" + }, + "mbr_instvoc_duality2_unwa": { + "category": "Инструментал и вокал", + "id": 1011, + "full_name": "Mel-Band Roformer InstVoc Duality v2 by Unwa", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-InstVoc-Duality/resolve/main/melband_roformer_instvox_duality_v2.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-InstVoc-Duality/resolve/main/config_melbandroformer_instvoc_duality.yaml?download=true" + }, + "mbr_kimft1_unwa": { + "category": "Вокал", + "id": 1020, + "full_name": "Mel-Band Roformer Kim FT v1 by Unwa", + "stems": [ + "Vocals", + "other" + ], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/kimmel_unwa_ft.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/config_kimmel_unwa_ft.yaml?download=true" + }, + "mbr_kimft2_unwa": { + "category": "Вокал", + "id": 1021, + "full_name": "Mel-Band Roformer Kim FT v2 by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/kimmel_unwa_ft2.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/config_kimmel_unwa_ft.yaml?download=true" + }, + "mbr_kimft2b_unwa": { + "category": "Вокал", + "id": 1022, + "full_name": "Mel-Band Roformer Kim FT v2 Bleedless by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/kimmel_unwa_ft2_bleedless.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/config_kimmel_unwa_ft.yaml?download=true" + }, + "mbr_kimft3_prev_unwa": { + "category": "Вокал", + "id": 1023, + "full_name": "Mel-Band Roformer Kim FT v3 preview by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/kimmel_unwa_ft3_prev.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/config_kimmel_unwa_ft.yaml?download=true" + }, + "mbr_bigbeta1_unwa": { + "category": "Вокал", + "id": 1030, + "full_name": "Mel-Band Roformer Big Beta v1 by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/melband_roformer_big_beta1.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/config_melbandroformer_big.yaml?download=true" + }, + "mbr_bigbeta2_unwa": { + "category": "Вокал", + "id": 1031, + "full_name": "Mel-Band Roformer Big Beta v2 by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/melband_roformer_big_beta2.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/config_melbandroformer_big.yaml?download=true" + }, + "mbr_bigbeta3_unwa": { + "category": "Вокал", + "id": 1032, + "full_name": "Mel-Band Roformer Big Beta v3 by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/melband_roformer_big_beta3.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/config_melbandroformer_big.yaml?download=true" + }, + "mbr_bigbeta4_unwa": { + "category": "Вокал", + "id": 1033, + "full_name": "Mel-Band Roformer Big Beta v4 by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/melband_roformer_big_beta4.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/config_melbandroformer_big_beta4.yaml?download=true" + }, + "mbr_bigbeta5e_unwa": { + "category": "Вокал", + "id": 1034, + "full_name": "Mel-Band Roformer Vocals Big Beta v5e by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta5e.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta5e.yaml?download=true" + }, + "mbr_bigbeta6_unwa": { + "category": "Вокал", + "id": 1035, + "full_name": "Mel-Band Roformer Big Beta v6 by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta6.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta6.yaml?download=true" + }, + "mbr_bigbeta6x_unwa": { + "category": "Вокал", + "id": 1036, + "full_name": "Mel-Band Roformer Big Beta v6x by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta6x.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta6x.yaml?download=true" + }, + "mbr_inst1_unwa": { + "category": "Инструментал", + "id": 1040, + "full_name": "Mel-Band Roformer Instrumental v1 by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "other", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/melband_roformer_inst_v1.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/config_melbandroformer_inst.yaml?download=true" + }, + "mbr_inst1+_unwa": { + "category": "Инструментал", + "id": 1041, + "full_name": "Mel-Band Roformer Instrumental v1+ by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "other", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/inst_v1_plus_test.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/config_melbandroformer_inst.yaml?download=true" + }, + "mbr_inst1e_unwa": { + "category": "Инструментал", + "id": 1042, + "full_name": "Mel-Band Roformer Instrumental v1e by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "other", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/inst_v1e.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/config_melbandroformer_inst.yaml?download=true" + }, + "mbr_inst1e+_unwa": { + "category": "Инструментал", + "id": 1043, + "full_name": "Mel-Band Roformer Instrumental v1e Plus by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "other", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/inst_v1e_plus.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/config_melbandroformer_inst.yaml?download=true" + }, + "mbr_inst2_unwa": { + "category": "Инструментал", + "id": 1044, + "full_name": "Mel-Band Roformer Instrumental v2 by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "other", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/melband_roformer_inst_v2.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/config_melbandroformer_inst_v2.yaml?download=true" + }, + "mbr_small_unwa": { + "category": "Вокал", + "id": 1050, + "full_name": "Mel-Band Roformer Small v1 by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-small/resolve/main/melband_roformer_small_v1.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-small/resolve/main/config_melbandroformer_small.yaml?download=true" + }, + "mbr_bleed_supressor_unwa_97chris": { + "category": "Шум", + "id": 1051, + "full_name": "Mel-Band Roformer Bleed Suppressor v1 by Unwa / 97chris", + "stems": [ + "Instrumental", + "Bleed" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/jarredou/bleed_suppressor_melband_rofo_by_unwa_97chris/resolve/main/bleed_suppressor_v1.ckpt?download=true", + "config_url": "https://huggingface.co/jarredou/bleed_suppressor_melband_rofo_by_unwa_97chris/resolve/main/config_bleed_suppressor_v1.yaml?download=true" + }, + "mbr_inst_becruily": { + "category": "Инструментал", + "id": 1060, + "full_name": "Mel-Band Roformer Instrumental by Becruily", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/becruily/mel-band-roformer-instrumental/resolve/main/mel_band_roformer_instrumental_becruily.ckpt?download=true", + "config_url": "https://huggingface.co/becruily/mel-band-roformer-instrumental/resolve/main/config_instrumental_becruily.yaml?download=true" + }, + "mbr_guitar_becruily": { + "category": "Гитара", + "id": 1061, + "full_name": "Mel-Band Roformer Instrumental by Becruily", + "stems": [ + "Guitar", + "Other" + ], + "target_instrument": "Guitar", + "checkpoint_url": "https://huggingface.co/becruily/mel-band-roformer-guitar/resolve/main/becruily_guitar.ckpt?download=true", + "config_url": "https://huggingface.co/becruily/mel-band-roformer-guitar/resolve/main/config_guitar_becruily.yaml?download=true" + }, + "mbr_karaoke_becruily": { + "category": "Караоке", + "id": 1062, + "full_name": "Mel-Band Roformer Karaoke by Becruily", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/becruily/mel-band-roformer-karaoke/resolve/main/mel_band_roformer_karaoke_becruily.ckpt?download=true", + "config_url": "https://huggingface.co/becruily/mel-band-roformer-karaoke/resolve/main/config_karaoke_becruily.yaml?download=true" + }, + "mbr_vocals_becruily": { + "category": "Вокал", + "id": 1063, + "full_name": "Mel-Band Roformer Vocals by Becruily", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/becruily/mel-band-roformer-vocals/resolve/main/mel_band_roformer_vocals_becruily.ckpt?download=true", + "config_url": "https://huggingface.co/becruily/mel-band-roformer-vocals/resolve/main/config_vocals_becruily.yaml?download=true" + }, + "mbr_syhft1": { + "category": "Вокал", + "id": 1070, + "full_name": "Mel-Band Roformer SYHFT v1 by SYH99999", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFT/resolve/main/MelBandRoformerSYHFT.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFT/resolve/main/config_vocals_mel_band_roformer_ft.yaml?download=true" + }, + "mbr_syhft2": { + "category": "Вокал", + "id": 1071, + "full_name": "Mel-Band Roformer SYHFT v2 by SYH99999", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTV2/resolve/main/MelBandRoformerSYHFTV2.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFT/resolve/main/config_vocals_mel_band_roformer_ft.yaml?download=true" + }, + "mbr_syhft2.5": { + "category": "Вокал", + "id": 1072, + "full_name": "Mel-Band Roformer SYHFT v2.5 by SYH99999", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTV2.5/resolve/main/MelBandRoformerSYHFTV2.5.ckpt/MelBandRoformerSYHFTV2.5.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFT/resolve/main/config_vocals_mel_band_roformer_ft.yaml?download=true" + }, + "mbr_syhft3": { + "category": "Вокал", + "id": 1073, + "full_name": "Mel-Band Roformer SYHFT v3 by SYH99999", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTV3Epsilon/resolve/main/MelBandRoformerSYHFTV3Epsilon.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFT/resolve/main/config_vocals_mel_band_roformer_ft.yaml?download=true" + }, + "mbr_bigsyhft1fast": { + "category": "Вокал", + "id": 1074, + "full_name": "Mel-Band Roformer Big SYHFT v1 Fast by SYH99999", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerBigSYHFTV1Fast/resolve/main/MelBandRoformerBigSYHFTV1.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformerBigSYHFTV1Fast/resolve/main/config.yaml?download=true" + }, + "mbr_syhftbeta1": { + "category": "Вокал", + "id": 1075, + "full_name": "Mel-Band Roformer Merged Beta v1 by SYH99999", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerMergedSYHFTBeta1/resolve/main/merge_syhft.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFT/resolve/main/config_vocals_mel_band_roformer_ft.yaml?download=true" + }, + "mbr_syhftB1_1": { + "category": "Вокал", + "id": 1076, + "full_name": "Mel-Band Roformer SYHFT B1 1 by SYH99999", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/model.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/config.yaml?download=true" + }, + "mbr_syhftB1_2": { + "category": "Вокал", + "id": 1077, + "full_name": "Mel-Band Roformer SYHFT B1 2 by SYH99999", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/model2.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/config.yaml?download=true" + }, + "mbr_syhftB1_3": { + "category": "Вокал", + "id": 1078, + "full_name": "Mel-Band Roformer SYHFT B1 3 by SYH99999", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/model3.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/config.yaml?download=true" + }, + "mbr_syhft_4stem": { + "category": "4 стема", + "id": 1079, + "full_name": "Mel-Band Roformer 4 Stems FT Large v1 by SYH99999", + "stems": [ + "vocals", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/MelBandRoformer4StemFTLarge.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/config.yaml?download=true" + }, + "mbr_syhft_4stem2": { + "category": "4 стема", + "id": 1080, + "full_name": "Mel-Band Roformer 4 Stems FT Large v2 by SYH99999", + "stems": [ + "vocals", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/ver2.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/config.yaml?download=true" + }, + "mbr_inst_1652_essid": { + "category": "Инструментал", + "id": 1085, + "full_name": "Mel-Band Roformer Instrumental by Essid (sdr 16.52)", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/Essid/Essid-MelBandRoformer/resolve/3960860f7895c87a12707ca6b378df2b3c68e2c0/model_mel_band_roformer_ep_17_sdr_16.5244.ckpt?download=true", + "config_url": "https://huggingface.co/Essid/Essid-MelBandRoformer/resolve/4768859bd59bc699d33f4567e82082993dde7eb9/config_vocals_mel_band_roformer_essid.yaml?download=true" + }, + "mbr_inst_1681_essid": { + "category": "Инструментал", + "id": 1086, + "full_name": "Mel-Band Roformer Instrumental by Essid (sdr 16.81)", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/Essid/Essid-MelBandRoformer/resolve/main/essid_mel_inst_old.ckpt?download=true", + "config_url": "https://huggingface.co/Essid/Essid-MelBandRoformer/resolve/4768859bd59bc699d33f4567e82082993dde7eb9/config_vocals_mel_band_roformer_essid.yaml?download=true" + }, + "mbr_instfv1_gabox": { + "category": "Инструментал", + "id": 1100, + "full_name": "Mel-Band Roformer Instrumental Fv1 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxFv1.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_instfv2_gabox": { + "category": "Инструментал", + "id": 1101, + "full_name": "Mel-Band Roformer Instrumental Fv2 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxFv2.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_instfv3_gabox": { + "category": "Инструментал", + "id": 1102, + "full_name": "Mel-Band Roformer Instrumental Fv3 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxFv3.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_instfv4_gabox": { + "category": "Инструментал", + "id": 1117, + "full_name": "Mel-Band Roformer Instrumental Fv4 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_Fv4.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_instfv4n_gabox": { + "category": "Инструментал", + "id": 1103, + "full_name": "Mel-Band Roformer Instrumental Fv4 Noise by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_Fv4Noise.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_instfv5_gabox": { + "category": "Инструментал", + "id": 1104, + "full_name": "Mel-Band Roformer Instrumental Fv5 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV5.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_instfv5n_gabox": { + "category": "Инструментал", + "id": 1105, + "full_name": "Mel-Band Roformer Instrumental Fv5 Noise by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV5N.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_instfv6_gabox": { + "category": "Инструментал", + "id": 1106, + "full_name": "Mel-Band Roformer Instrumental Fv6 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV6.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_instfv6n_gabox": { + "category": "Инструментал", + "id": 1107, + "full_name": "Mel-Band Roformer Instrumental Fv6 Noise by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV6N.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_instfv7_gabox": { + "category": "Инструментал", + "id": 1108, + "full_name": "Mel-Band Roformer Instrumental Fv7 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/Inst_GaboxV7.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_instfv7n_gabox": { + "category": "Инструментал", + "id": 1109, + "full_name": "Mel-Band Roformer Instrumental Fv7 Noise by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV7N.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_instfv7+_gabox": { + "category": "Инструментал", + "id": 1110, + "full_name": "Mel-Band Roformer Instrumental Fv7+ by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/5fba9605d4b6bc1a31c04c50d08d757c5107d23f/melbandroformers/experimental/instv7plus.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_instfv7z_gabox": { + "category": "Инструментал", + "id": 1111, + "full_name": "Mel-Band Roformer Instrumental Fv7z by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/Inst_GaboxFv7z.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_instfv8_gabox": { + "category": "Инструментал", + "id": 1112, + "full_name": "Mel-Band Roformer Instrumental Fv8 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/Inst_GaboxFv8.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_instfv8b_gabox": { + "category": "Инструментал", + "id": 1113, + "full_name": "Mel-Band Roformer Instrumental Fv8b by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/experimental/Inst_FV8b.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_instfv9_gabox": { + "category": "Инструментал", + "id": 1114, + "full_name": "Mel-Band Roformer Instrumental Fv9 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/experimental/Inst_Fv9.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_instfv10_gabox": { + "category": "Инструментал", + "id": 1115, + "full_name": "Mel-Band Roformer Instrumental Fv10 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/experimental/INSTV10.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_instfvx_gabox": { + "category": "Инструментал", + "id": 1116, + "full_name": "Mel-Band Roformer Instrumental FvX by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/Inst_GaboxFVX.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_instbv1_gabox": { + "category": "Инструментал", + "id": 1120, + "full_name": "Mel-Band Roformer Instrumental Bv1 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxBv1.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_instbv2_gabox": { + "category": "Инструментал", + "id": 1121, + "full_name": "Mel-Band Roformer Instrumental Bv2 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxBv2.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_instbv3_gabox": { + "category": "Инструментал", + "id": 1122, + "full_name": "Mel-Band Roformer Instrumental Bv3 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxBv3.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_vocalsfv1_gabox": { + "category": "Вокал", + "id": 1130, + "full_name": "Mel-Band Roformer Vocals Fv1 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gaboxFv1.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml?download=true" + }, + "mbr_vocalsfv2_gabox": { + "category": "Вокал", + "id": 1131, + "full_name": "Mel-Band Roformer Vocals Fv2 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gaboxFv2.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml?download=true" + }, + "mbr_vocalsfv3_gabox": { + "category": "Вокал", + "id": 1132, + "full_name": "Mel-Band Roformer Vocals Fv3 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_Fv3.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml?download=true" + }, + "mbr_vocalsfv4_gabox": { + "category": "Вокал", + "id": 1133, + "full_name": "Mel-Band Roformer Vocals Fv4 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_fv4.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml?download=true" + }, + "mbr_vocalsfv5_gabox": { + "category": "Вокал", + "id": 1134, + "full_name": "Mel-Band Roformer Vocals Fv5 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_fv5.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml?download=true" + }, + "mbr_vocalsfv6_gabox": { + "category": "Вокал", + "id": 1135, + "full_name": "Mel-Band Roformer Vocals Fv6 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_fv6.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml?download=true" + }, + "mbr_karaoke25022025_gabox": { + "category": "Караоке", + "id": 1140, + "full_name": "Mel-Band Roformer Karaoke 25-02-2025 by GaboxR67", + "stems": [ + "karaoke", + "other" + ], + "target_instrument": "karaoke", + "checkpoint_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/gabox_karaoke_25_02_2025.ckpt?download=true", + "config_url": "https://huggingface.co/jarredou/aufr33-viperx-karaoke-melroformer-model/resolve/main/config_mel_band_roformer_karaoke.yaml?download=true" + }, + "mbr_karaoke28022025_gabox": { + "category": "Караоке", + "id": 1141, + "full_name": "Mel-Band Roformer Karaoke 28-02-2025 by GaboxR67", + "stems": [ + "karaoke", + "other" + ], + "target_instrument": "karaoke", + "checkpoint_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/gabox_karaoke_28_02_2025.ckpt?download=true", + "config_url": "https://huggingface.co/jarredou/aufr33-viperx-karaoke-melroformer-model/resolve/main/config_mel_band_roformer_karaoke.yaml?download=true" + }, + "mbr_karaoke1_gabox": { + "category": "Караоке", + "id": 1142, + "full_name": "Mel-Band Roformer Karaoke v1 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/27e73ca2beec0ab7daa46e366159753a166612e1/melbandroformers/karaoke/Karaoke_GaboxV1.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/27e73ca2beec0ab7daa46e366159753a166612e1/melbandroformers/karaoke/karaokegabox_1750911344.yaml?download=true" + }, + "mbr_karaoke2_gabox": { + "category": "Караоке", + "id": 1143, + "full_name": "Mel-Band Roformer Karaoke v2 by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/experimental/Karaoke_GaboxV2.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/27e73ca2beec0ab7daa46e366159753a166612e1/melbandroformers/karaoke/karaokegabox_1750911344.yaml?download=true" + }, + "mbr_leadvoc_dereverb_gabox": { + "category": "Реверб", + "id": 1144, + "full_name": "Mel-Band Roformer Lead Vocals DeReverb by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/experimental/Lead_VocalDereverb.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/27e73ca2beec0ab7daa46e366159753a166612e1/melbandroformers/karaoke/karaokegabox_1750911344.yaml?download=true" + }, + "mbr_denoise_debleed_gabox": { + "category": "Шум", + "id": 1150, + "full_name": "Mel-Band Roformer Denoise DeBleed by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/denoisedebleed.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + "mbr_karaoke_fusion_gonzaluigi": { + "category": "Караоке", + "id": 1160, + "full_name": "Mel-Band Roformer Karaoke Fusion by Gonzaluigi", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Karaoke-Fusion/resolve/main/mel_band_karaoke_fusion_standard.ckpt?download=true", + "config_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Karaoke-Fusion/resolve/main/melband_karaokefusion_gonza.yaml?download=true" + }, + "mbr_karaoke_fusion_aggr_gonzaluigi": { + "category": "Караоке", + "id": 1161, + "full_name": "Mel-Band Roformer Karaoke Fusion Aggressive by Gonzaluigi", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Karaoke-Fusion/resolve/main/mel_band_karaoke_fusion_aggressive.ckpt?download=true", + "config_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Karaoke-Fusion/resolve/main/melband_karaokefusion_gonza.yaml?download=true" + }, + "mbr_bve_gonzaluigi": { + "category": "Караоке", + "id": 1162, + "full_name": "Mel-Band Roformer BVE by Gonzaluigi", + "stems": [ + "Lead", + "Back" + ], + "target_instrument": "Lead", + "checkpoint_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Roformer-BVE-Gonzaluigi/resolve/main/mel_band_roformer_bve_gonza.ckpt?download=true", + "config_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Roformer-BVE-Gonzaluigi/resolve/main/config_bve_gonza.yaml?download=true" + }, + "mbr_karaoke_fusion2_aggr_gonzaluigi": { + "category": "Караоке", + "id": 1163, + "full_name": "Mel-Band Roformer Karaoke Fusion Aggressive by Gonzaluigi", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Karaoke-Fusion/resolve/main/mel_band_karaoke_fusion_agg_v2.ckpt?download=true", + "config_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Karaoke-Fusion/resolve/main/karaoke_fusion_v2_config.yaml?download=true" + }, + "mbr_karaoke_fusion_total_aggr_gonzaluigi": { + "category": "Караоке", + "id": 1164, + "full_name": "Mel-Band Roformer Karaoke Fusion Total by Gonzaluigi", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Karaoke-Fusion/resolve/main/mel_band_karaoke_fusion_total.ckpt?download=true", + "config_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Karaoke-Fusion/resolve/main/karaoke_fusion_total_config.yaml?download=true" + }, + "mbr_dereverb_anvuew": { + "category": "Реверб", + "id": 1170, + "full_name": "Mel-Band Roformer DeReverb by Anvuew", + "stems": [ + "reverb", + "noreverb" + ], + "target_instrument": "noreverb", + "checkpoint_url": "https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt?download=true", + "config_url": "https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew.yaml?download=true" + }, + "mbr_dereverb_less_aggr_anvuew": { + "category": "Реверб", + "id": 1171, + "full_name": "Mel-Band Roformer DeReverb Less Aggressive by Anvuew", + "stems": [ + "reverb", + "noreverb" + ], + "target_instrument": "noreverb", + "checkpoint_url": "https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_less_aggressive_anvuew_sdr_18.8050.ckpt?download=true", + "config_url": "https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew.yaml?download=true" + }, + "mbr_dereverb_mono_anvuew": { + "category": "Реверб", + "id": 1172, + "full_name": "Mel-Band Roformer DeReverb Mono by Anvuew", + "stems": [ + "reverb", + "noreverb" + ], + "target_instrument": "noreverb", + "checkpoint_url": "https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_mono_anvuew_sdr_20.4029.ckpt?download=true", + "config_url": "https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew.yaml?download=true" + }, + "mbr_aspiration_sucial": { + "category": "Дыхание", + "id": 1180, + "full_name": "Mel-Band Roformer Aspiration by Sucial", + "stems": [ + "aspiration", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Sucial/Aspiration_Mel_Band_Roformer/resolve/main/aspiration_mel_band_roformer_less_aggr_sdr_18.1201.ckpt?download=true", + "config_url": "https://huggingface.co/Sucial/Aspiration_Mel_Band_Roformer/resolve/main/config_aspiration_mel_band_roformer.yaml?download=true" + }, + "mbr_derverb_echo1_sucial": { + "category": "Реверб и эхо", + "id": 1181, + "full_name": "Mel-Band Roformer DeReverb-Echo by Sucial", + "stems": [ + "dry", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/dereverb-echo_mel_band_roformer_sdr_10.0169.ckpt?download=true", + "config_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/config_dereverb-echo_mel_band_roformer.yaml?download=true" + }, + "mbr_debigreverb_sucial": { + "category": "Реверб", + "id": 1182, + "full_name": "Mel-Band Roformer DeBigReverb by Sucial", + "stems": [ + "dry", + "other" + ], + "target_instrument": "dry", + "checkpoint_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/de_big_reverb_mbr_ep_362.ckpt?download=true", + "config_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/config_dereverb_echo_mbr_v2.yaml?download=true" + }, + "mbr_desuperbigreverb_sucial": { + "category": "Реверб", + "id": 1183, + "full_name": "Mel-Band Roformer Super Big DeReverb by Sucial", + "stems": [ + "dry", + "other" + ], + "target_instrument": "dry", + "checkpoint_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/de_super_big_reverb_mbr_ep_346.ckpt?download=true", + "config_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/config_dereverb_echo_mbr_v2.yaml?download=true" + }, + "mbr_dereverb-echo_fused_sucial": { + "category": "Реверб и эхо", + "id": 1184, + "full_name": "Mel-Band Roformer DeReverb-Echo Fused by Sucial", + "stems": [ + "dry", + "other" + ], + "target_instrument": "dry", + "checkpoint_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/dereverb_echo_mbr_fused_0.5_v2_0.25_big_0.25_super.ckpt?download=true", + "config_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/config_dereverb_echo_mbr_v2.yaml?download=true" + }, + "mbr_dereverb-echo2_sucial": { + "category": "Реверб и эхо", + "id": 1185, + "full_name": "Mel-Band Roformer DeReverb-Echo v2 by Sucial", + "stems": [ + "dry", + "other" + ], + "target_instrument": "dry", + "checkpoint_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/dereverb_echo_mbr_v2_sdr_dry_13.4843.ckpt?download=true", + "config_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/config_dereverb_echo_mbr_v2.yaml?download=true" + }, + "mbr_karaoke_aufr33_viperx": { + "category": "Караоке", + "id": 1190, + "full_name": "Mel-Band Roformer Karaoke by Aufr33 & ViperX", + "stems": [ + "karaoke", + "other" + ], + "target_instrument": "karaoke", + "checkpoint_url": "https://huggingface.co/jarredou/aufr33-viperx-karaoke-melroformer-model/resolve/main/mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956.ckpt?download=true", + "config_url": "https://huggingface.co/jarredou/aufr33-viperx-karaoke-melroformer-model/resolve/main/config_mel_band_roformer_karaoke.yaml?download=true" + }, + "mbr_denoise_aufr33": { + "category": "Шум", + "id": 1191, + "full_name": "Mel-Band Roformer DeNoise by Aufr33", + "stems": [ + "dry", + "other" + ], + "target_instrument": "dry", + "checkpoint_url": "https://huggingface.co/jarredou/aufr33_MelBand_Denoise/resolve/main/denoise_mel_band_roformer_aufr33_aggr_sdr_27.9768.ckpt?download=true", + "config_url": "https://huggingface.co/jarredou/aufr33_MelBand_Denoise/resolve/main/model_mel_band_roformer_denoise.yaml?download=true" + }, + "mbr_denoise_aggr_aufr33": { + "category": "Шум", + "id": 1192, + "full_name": "Mel-Band Roformer DeNoise Aggressive by Aufr33", + "stems": [ + "dry", + "other" + ], + "target_instrument": "dry", + "checkpoint_url": "https://huggingface.co/jarredou/aufr33_MelBand_Denoise/resolve/main/denoise_mel_band_roformer_aufr33_aggr_sdr_27.9768.ckpt?download=true", + "config_url": "https://huggingface.co/jarredou/aufr33_MelBand_Denoise/resolve/main/model_mel_band_roformer_denoise.yaml?download=true" + }, + "mbr_crowd_aufr33_viperx": { + "category": "Звуки толпы", + "id": 1193, + "full_name": "Mel-Band Roformer Crowd by Aufr33 & ViperX", + "stems": [ + "crowd", + "other" + ], + "target_instrument": "crowd", + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.4/mel_band_roformer_crowd_aufr33_viperx_sdr_8.7144.ckpt", + "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.4/model_mel_band_roformer_crowd.yaml" + }, + "mbr_vocals_viperx": { + "category": "Вокал", + "id": 1194, + "full_name": "Mel-Band Roformer Vocals by ViperX", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/viperx/model_mel_band_roformer_ep_3005_sdr_11.4360.yaml" + }, + "mbr_vocalsf_aname": { + "category": "Вокал", + "id": 1200, + "full_name": "Mel-Band Roformer Vocals Fullness by Aname", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/FullnessVocalModel.ckpt?download=true", + "config_url": "https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/config.yaml?download=true" + }, + "mbr_kinft1_aname": { + "category": "Вокал", + "id": 1201, + "full_name": "Mel-Band Roformer Kim FT v1 by Aname", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Aname-Tommy/Test/resolve/main/model_kim.ckpt?download=true", + "config_url": "https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/config.yaml?download=true" + }, + "mbr_kinft2_aname": { + "category": "Вокал", + "id": 1202, + "full_name": "Mel-Band Roformer Kim FT v2 by Aname", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Aname-Tommy/Test/resolve/main/model_kim_2.ckpt?download=true", + "config_url": "https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/config.yaml?download=true" + }, + "mbr_kinft2f_aname": { + "category": "Вокал", + "id": 1203, + "full_name": "Mel-Band Roformer Kim FT v2 Fullness by Aname", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Aname-Tommy/Test/resolve/main/model_kim_2_fullness.ckpt?download=true", + "config_url": "https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/config.yaml?download=true" + }, + "mbr_kinft3_aname": { + "category": "Вокал", + "id": 1204, + "full_name": "Mel-Band Roformer Kim FT v3 by Aname", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Aname-Tommy/Test/resolve/main/model_kim_3.ckpt?download=true", + "config_url": "https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/config.yaml?download=true" + }, + "mbr_small_aname": { + "category": "Вокал", + "id": 1205, + "full_name": "Mel-Band Roformer Small by Aname", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Aname-Tommy/Mel_Band_Roformer_small/resolve/main/mel_band_roformer_small.ckpt?download=true", + "config_url": "https://huggingface.co/Aname-Tommy/Mel_Band_Roformer_small/resolve/main/config.yaml?download=true" + }, + "mbr_duality1_aname": { + "category": "Вокал", + "id": 1209, + "full_name": "Mel-Band Roformer Duality v1 by Aname", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Aname-Tommy/Mel-Band-Roformer_Duality/resolve/main/duality_v1.ckpt?download=true", + "config_url": "https://huggingface.co/Aname-Tommy/Mel-Band-Roformer_Duality/resolve/main/config_v1.yaml?download=true" + }, + "mbr_4stemlarge1_aname": { + "category": "4 стема", + "id": 1206, + "full_name": "Mel-Band Roformer 4 Stems Large by Aname", + "stems": [ + "vocals", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Aname-Tommy/melbandroformer4stems/resolve/main/mel_band_roformer_4stems_large_ver1.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/config.yaml?download=true" + }, + "mbr_4stemlarge2_aname": { + "category": "4 стема", + "id": 1207, + "full_name": "Mel-Band Roformer 4 Stems v2 Large by Aname", + "stems": [ + "vocals", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Aname-Tommy/Test/resolve/main/4stemsver2.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/config.yaml?download=true" + }, + "mbr_4stemxl1_aname": { + "category": "4 стема", + "id": 1208, + "full_name": "Mel-Band Roformer 4 Stems XL by Aname", + "stems": [ + "vocals", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Aname-Tommy/melbandroformer4stems/resolve/main/mel_band_roformer_4stems_xl_ver1.ckpt?download=true", + "config_url": "https://huggingface.co/Aname-Tommy/melbandroformer4stems/resolve/main/config_xl.yaml?download=true" + }, + "mbr_percussion_yolkispaliks": { + "category": "Ударные", + "id": 1900, + "full_name": "Mel-Band Roformer Drums Experimental by yolkispalkis", + "stems": [ + "percussions", + "other" + ], + "target_instrument": "percussions", + "checkpoint_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/model_mel_band_roformer_ep_11_sdr_7.6853.ckpt?download=true", + "config_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/config_drums_musdb18_moises_mel_band_roformer.yaml?download=true" + }, + "mbr_inst_metal_prev_meskvlla33": { + "category": "Инструментал", + "id": 1901, + "full_name": "Mel-Band Roformer Metal Inst Preview by Mesk", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "other", + "checkpoint_url": "https://huggingface.co/meskvlla33/metal_roformer_preview/resolve/main/metal_roformer_inst_mesk_preview.ckpt?download=true", + "config_url": "https://huggingface.co/meskvlla33/metal_roformer_preview/resolve/main/config_inst_metal_roformer_mesk.yaml?download=true" + }, + "mbr_neo_inst_vfx": { + "category": "Инструментал", + "id": 1902, + "full_name": "Mel-Band Roformer NEO Inst VFX by natanworkspace", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/natanworkspace/melband_roformer/resolve/main/Neo_InstVFX.ckpt?download=true", + "config_url": "https://huggingface.co/natanworkspace/melband_roformer/resolve/main/config_neo_inst.yaml?download=true" + } + }, + "bs_roformer": { + "bs_drums_beatloo_labs": { + "category": "Ударные", + "id": 200, + "full_name": "BS Roformer Drums Experimental by BeatLoo Labs", + "stems": [ + "drums", + "other" + ], + "target_instrument": "drums", + "checkpoint_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/model_drums_bs_roformer_ep_12_sdr_7.2279.ckpt?download=true", + "config_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/BS-Roformer_Drums_beatloo_labs_config.yaml?download=true" + }, + "bs_bass_beatloo_labs": { + "category": "Басс", + "id": 201, + "full_name": "BS Roformer Bass Experimental by BeatLoo Labs", + "stems": [ + "bass", + "other" + ], + "target_instrument": "bass", + "checkpoint_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/model_bass_bs_roformer_ep_10_sdr_5.7862.ckpt?download=true", + "config_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/BS-Roformer_Bass_beatloo_labs_config.yaml?download=true" + }, + "bs_vocals_1296_viperx": { + "category": "Вокал", + "id": 202, + "full_name": "BS Roformer Vocals (sdr 12.96) by ViperX", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Vocals", + "checkpoint_url": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_bs_roformer_ep_368_sdr_12.9628.ckpt", + "config_url": "https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/mdx_c_configs/model_bs_roformer_ep_368_sdr_12.9628.yaml" + }, + "bs_other_viperx": { + "category": "Прочее", + "id": 203, + "full_name": "BS Roformer Other by ViperX", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "other", + "checkpoint_url": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_bs_roformer_ep_937_sdr_10.5309.ckpt", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/viperx/model_bs_roformer_ep_937_sdr_10.5309.yaml" + }, + "bs_revive1_unwa": { + "category": "Вокал", + "id": 210, + "full_name": "BS Roformer Vocals Revive v1 by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/BS-Roformer-Revive/resolve/main/bs_roformer_revive.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/BS-Roformer-Revive/resolve/main/config.yaml?download=true" + }, + "bs_revive2_unwa": { + "category": "Вокал", + "id": 211, + "full_name": "BS Roformer Vocals Revive v2 by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/BS-Roformer-Revive/resolve/main/bs_roformer_revive2.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/BS-Roformer-Revive/resolve/main/config.yaml?download=true" + }, + "bs_revive3e_unwa": { + "category": "Вокал", + "id": 212, + "full_name": "BS Roformer Vocals Revive v3e by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/BS-Roformer-Revive/resolve/main/bs_roformer_revive3e.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/BS-Roformer-Revive/resolve/main/config.yaml?download=true" + }, + "bs_resurrection_unwa": { + "category": "Вокал", + "id": 213, + "full_name": "BS Roformer Vocals Resurrection by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/BS-Roformer-Resurrection/resolve/main/BS-Roformer-Resurrection.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/BS-Roformer-Resurrection/resolve/main/BS-Roformer-Resurrection-Config.yaml?download=true" + }, + "bs_resurrection_inst_unwa": { + "category": "Инструментал", + "id": 214, + "full_name": "BS Roformer Instrumental Resurrection by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "other", + "checkpoint_url": "https://huggingface.co/pcunwa/BS-Roformer-Resurrection/resolve/main/BS-Roformer-Resurrection-Inst.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/BS-Roformer-Resurrection/resolve/main/BS-Roformer-Resurrection-Inst-Config.yaml?download=true" + }, + "bs_inst_fno_unwa": { + "category": "Инструментал", + "id": 226, + "full_name": "BS Roformer Instrumental FNO by Unwa", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "other", + "checkpoint_url": "https://huggingface.co/pcunwa/BS-Roformer-Inst-FNO/resolve/main/bs_roformer_fno.ckpt?download=true", + "config_url": "https://raw.githubusercontent.com/noblebarkrr/mvsepless/refs/heads/beta/fixed_configs/BS-Roformer_Inst_FNO_unwa_config.yaml" + }, + "bs_karaoke_becruily": { + "category": "Караоке", + "id": 215, + "full_name": "BS Roformer Karaoke by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/becruily/bs-roformer-karaoke/resolve/main/bs_roformer_karaoke_frazer_becruily.ckpt?download=true", + "config_url": "https://huggingface.co/becruily/bs-roformer-karaoke/resolve/main/config_karaoke_frazer_becruily.yaml?download=true" + }, + "bs_voctest_gabox": { + "category": "Вокал", + "id": 216, + "full_name": "BS Roformer Vocals by GaboxR67", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/GaboxR67/BSRoformerVocTest/resolve/main/voc_gaboxBSR.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/BSRoformerVocTest/resolve/main/voc_gaboxBSroformer.yaml?download=true" + }, + "bs_karaoke_gabox": { + "category": "Караоке", + "id": 229, + "full_name": "BS Roformer Karaoke by GaboxR67", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/aa0a94097a222fcc8ebdd691b1435145eda31b46/bsroformers/bs_karaoke_gabox_IS.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/aa0a94097a222fcc8ebdd691b1435145eda31b46/bsroformers/karaoke_bs_roformer.yaml?download=true" + }, + "bs_6stem": { + "category": "6 стемов", + "id": 217, + "full_name": "BS Roformer SW", + "stems": [ + "bass", + "drums", + "other", + "piano", + "guitar", + "vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://github.com/undef13/splifft/releases/download/v0.0.1/roformer-fp16.pt", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/BS-Roformer_SW_config.yaml?download=true" + }, + "bs_6stem_fixed": { + "category": "6 стемов", + "id": 227, + "full_name": "BS Roformer SW Fixed by jarredou", + "stems": [ + "bass", + "drums", + "other", + "piano", + "guitar", + "vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/jarredou/BS-ROFO-SW-Fixed/resolve/main/BS-Rofo-SW-Fixed.ckpt?download=true", + "config_url": "https://huggingface.co/jarredou/BS-ROFO-SW-Fixed/resolve/main/BS-Rofo-SW-Fixed.yaml?download=true" + }, + "bs_4stem_zfturbo": { + "category": "4 стема", + "id": 218, + "full_name": "BS Roformer 4 Stems by ZFTurbo", + "stems": [ + "vocals", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.12/model_bs_roformer_ep_17_sdr_9.6568.ckpt", + "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.12/config_bs_roformer_384_8_2_485100.yaml" + }, + "bs_4stemft_syh99999": { + "category": "4 стема", + "id": 219, + "full_name": "BS Roformer 4 Stems FT by SYH99999", + "stems": [ + "vocals", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/SYH99999/bs_roformer_4stems_ft/resolve/main/bs_roformer_4stems_ft.pth?download=true", + "config_url": "https://huggingface.co/SYH99999/bs_roformer_4stems_ft/resolve/main/config.yaml?download=true" + }, + "bs_male_female_146_sucial": { + "category": "Мужской/Женский вокал", + "id": 220, + "full_name": "BS Roformer Male-Female (ep 146) by Sucial", + "stems": [ + "male", + "female" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Sucial/Chorus_Male_Female_BS_Roformer/resolve/main/model_chorus_bs_roformer_ep_146_sdr_23.8613.ckpt?download=true", + "config_url": "https://huggingface.co/Sucial/Chorus_Male_Female_BS_Roformer/resolve/main/config_chorus_male_female_bs_roformer.yaml?download=true" + }, + "bs_male_female_267_sucial": { + "category": "Мужской/Женский вокал", + "id": 221, + "full_name": "BS Roformer Male-Female (ep 267) by Sucial", + "stems": [ + "male", + "female" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Sucial/Chorus_Male_Female_BS_Roformer/resolve/main/model_chorus_bs_roformer_ep_267_sdr_24.1275.ckpt?download=true", + "config_url": "https://huggingface.co/Sucial/Chorus_Male_Female_BS_Roformer/resolve/main/config_chorus_male_female_bs_roformer.yaml?download=true" + }, + "bs_male_female_aufr33": { + "category": "Мужской/Женский вокал", + "id": 222, + "full_name": "BS Roformer Male-Female by Aufr33", + "stems": [ + "male", + "female" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/RareSirMix/AIModelRehosting/resolve/main/bs_roformer_male_female_by_aufr33_sdr_7.2889.ckpt", + "config_url": "https://huggingface.co/Sucial/Chorus_Male_Female_BS_Roformer/resolve/main/config_chorus_male_female_bs_roformer.yaml" + }, + "bs_deverb_256_8_anvuew": { + "category": "Реверб", + "id": 223, + "full_name": "BS Roformer Deverb 256-8 by Anvuew", + "stems": [ + "reverb", + "noreverb" + ], + "target_instrument": "noreverb", + "checkpoint_url": "https://huggingface.co/anvuew/deverb_bs_roformer/resolve/main/deverb_bs_roformer_8_256dim_8depth.ckpt?download=true", + "config_url": "https://huggingface.co/anvuew/deverb_bs_roformer/resolve/main/deverb_bs_roformer_8_256dim_8depth.yaml?download=true" + }, + "bs_deverb_384_10_anvuew": { + "category": "Реверб", + "id": 224, + "full_name": "BS Roformer Deverb 384-10 by Anvuew", + "stems": [ + "reverb", + "noreverb" + ], + "target_instrument": "noreverb", + "checkpoint_url": "https://huggingface.co/anvuew/deverb_bs_roformer/resolve/main/deverb_bs_roformer_8_384dim_10depth.ckpt?download=true", + "config_url": "https://huggingface.co/anvuew/deverb_bs_roformer/resolve/main/deverb_bs_roformer_8_384dim_10depth.yaml?download=true" + }, + "bs_karaoke_anvuew": { + "category": "Караоке", + "id": 228, + "full_name": "BS Roformer Karaoke by Anvuew", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/anvuew/karaoke_bs_roformer/resolve/main/karaoke_bs_roformer_anvuew.ckpt?download=true", + "config_url": "https://huggingface.co/anvuew/karaoke_bs_roformer/resolve/main/karaoke_bs_roformer_anvuew.yaml?download=true" + }, + "bs_vocals_anvuew": { + "category": "Вокал", + "id": 230, + "full_name": "BS Roformer Vocals by Anvuew", + "stems": [ + "vocals", + "instrument" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/anvuew/BS-RoFormer/resolve/main/bs_roformer_anvuew_sdr_12.45.ckpt?download=true", + "config_url": "https://huggingface.co/anvuew/BS-RoFormer/resolve/main/config.yaml?download=true" + }, + "bs_4stem_aname": { + "category": "4 стема", + "id": 225, + "full_name": "BS Roformer 4 stems by Aname", + "stems": [ + "vocals", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/RareSirMix/AIModelRehosting/resolve/main/Amane4stem_bs_roformer.ckpt", + "config_url": "https://huggingface.co/RareSirMix/AIModelRehosting/resolve/main/Amane4stem_bs_roformer.yaml" + } + }, + "mdx23c": { + "mdx23c_instvoc_zfturbo": { + "category": "Инструментал и вокал", + "id": 300, + "full_name": "MDX23C Inst-Voc HQ by ZFTurbo", + "stems": [ + "vocals", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.0/model_vocals_mdx23c_sdr_10.17.ckpt", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_vocals_mdx23c.yaml" + }, + "mdx23c_instvoc_hq1": { + "category": "Инструментал и вокал", + "id": 301, + "full_name": "MDX23C 8k FFT Inst-Voc HQ v1", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDX23C/MDX23C-8KFFT-InstVoc_HQ.ckpt?download=true", + "config_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDX23C/model_2_stem_full_band_8k.yaml?download=true" + }, + "mdx23c_instvoc_hq2": { + "category": "Инструментал и вокал", + "id": 302, + "full_name": "MDX23C 8k FFT Inst-Voc HQ v2", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDX23C/MDX23C-8KFFT-InstVoc_HQ_2.ckpt?download=true", + "config_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDX23C/model_2_stem_full_band_8k.yaml?download=true" + }, + "mdx23c_d1581": { + "category": "Инструментал и вокал", + "id": 303, + "full_name": "MDX23C D1581", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDX23C/MDX23C_D1581.ckpt?download=true", + "config_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDX23C/model_2_stem_061321.yaml?download=true" + }, + "mdx23c_drumsep_6stem_aufr33_jarredou": { + "category": "DrumSep", + "id": 304, + "full_name": "MDX23C DrumSep 6 stems by Aufr33 & Jarredou", + "stems": [ + "kick", + "snare", + "toms", + "hh", + "ride", + "crash" + ], + "target_instrument": null, + "checkpoint_url": "https://github.com/jarredou/models/releases/download/aufr33-jarredou_MDX23C_DrumSep_model_v0.1/aufr33-jarredou_DrumSep_model_mdx23c_ep_141_sdr_10.8059.ckpt", + "config_url": "https://github.com/jarredou/models/releases/download/aufr33-jarredou_MDX23C_DrumSep_model_v0.1/aufr33-jarredou_DrumSep_model_mdx23c_ep_141_sdr_10.8059.yaml" + }, + "mdx23c_drumsep_5stem_aufr33_jarredou": { + "category": "DrumSep", + "id": 309, + "full_name": "MDX23C DrumSep 5 stems by Aufr33 & Jarredou", + "stems": [ + "kick", + "snare", + "toms", + "hh", + "cymbals" + ], + "target_instrument": null, + "checkpoint_url": "https://github.com/jarredou/models/releases/download/DrumSep/drumsep_5stems_mdx23c_jarredou.ckpt", + "config_url": "https://github.com/jarredou/models/releases/download/DrumSep/config_mdx23c.yaml" + }, + "mdx23c_derverb_aufr33_jarredou": { + "category": "Реверб", + "id": 305, + "full_name": "MDX23C DeReverb by Aufr33 & Jarredou", + "stems": [ + "dry", + "other" + ], + "target_instrument": "dry", + "checkpoint_url": "https://huggingface.co/jarredou/aufr33_jarredou_MDXv3_DeReverb/resolve/main/dereverb_mdx23c_sdr_6.9096.ckpt", + "config_url": "https://huggingface.co/jarredou/aufr33_jarredou_MDXv3_DeReverb/resolve/main/config_dereverb_mdx23c.yaml" + }, + "mdx23c_mid_side_wesleyr36": { + "category": "Фантомный центр", + "id": 306, + "full_name": "MDX23C Mid-Side by WesleyR36", + "stems": [ + "similarity", + "difference" + ], + "target_instrument": "similarity", + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.10/model_mdx23c_ep_271_l1_freq_72.2383.ckpt", + "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.10/config_mdx23c_similarity.yaml" + }, + "mdx23c_4stem_zfturbo": { + "category": "4 стема", + "id": 307, + "full_name": "MDX23C 4 Stems by ZFTurbo", + "stems": [ + "vocals", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.1/model_mdx23c_ep_168_sdr_7.0207.ckpt", + "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.1/config_musdb18_mdx23c.yaml" + }, + "mdx23c_orch_verosment": { + "category": "Оркестр", + "id": 308, + "full_name": "MDX23C Orchestra Experimental by Verosment", + "stems": [ + "inst", + "orch" + ], + "target_instrument": "orch", + "checkpoint_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/model_mdx23c_ep_120_sdr_4.4174.ckpt?download=true", + "config_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/config_orchestra_mdx23c.yaml?download=true" + } + }, + "mdxnet": { + "mdx_kim_inst": { + "category": "Инструментал", + "id": 600, + "full_name": "MDX-Net Model: UVR-MDX-NET Kim Inst", + "stems": [ + "instrumental", + "vocals" + ], + "target_instrument": "instrumental", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/Kim_Inst.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/Kim_Inst.yaml?download=true" + }, + "mdx_kim_vocal1": { + "category": "Вокал", + "id": 601, + "full_name": "MDX-Net Model: UVR-MDX-NET Kim Vocal 1", + "stems": [ + "vocals", + "instrumental" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/Kim_Vocal_1.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/Kim_Vocal_1.yaml?download=true" + }, + "mdx_kim_vocal2": { + "category": "Вокал", + "id": 602, + "full_name": "MDX-Net Model: UVR-MDX-NET Kim Vocal 2", + "stems": [ + "vocals", + "instrumental" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/Kim_Vocal_2.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/Kim_Vocal_2.yaml?download=true" + }, + "mdx_kuielab_a_bass": { + "category": "Басс", + "id": 603, + "full_name": "MDX-Net Model: UVR-MDX-NET Kuielab A Bass", + "stems": [ + "bass", + "other" + ], + "target_instrument": "bass", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/kuielab_a_bass.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/kuielab_a_bass.yaml?download=true" + }, + "mdx_kuielab_a_drums": { + "category": "Ударные", + "id": 604, + "full_name": "MDX-Net Model: UVR-MDX-NET Kuielab A Drums", + "stems": [ + "drums", + "other" + ], + "target_instrument": "drums", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/kuielab_a_drums.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/kuielab_a_drums.yaml?download=true" + }, + "mdx_kuielab_a_other": { + "category": "Прочее", + "id": 605, + "full_name": "MDX-Net Model: UVR-MDX-NET Kuielab A Other", + "stems": [ + "other", + "no_other" + ], + "target_instrument": "other", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/kuielab_a_other.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/kuielab_a_other.yaml?download=true" + }, + "mdx_kuielab_a_vocals": { + "category": "Вокал", + "id": 606, + "full_name": "MDX-Net Model: UVR-MDX-NET Kuielab A Vocals", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/kuielab_a_vocals.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/kuielab_a_vocals.yaml?download=true" + }, + "mdx_kuielab_b_bass": { + "category": "Басс", + "id": 607, + "full_name": "MDX-Net Model: UVR-MDX-NET Kuielab B Bass", + "stems": [ + "bass", + "other" + ], + "target_instrument": "bass", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/kuielab_b_bass.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/kuielab_b_bass.yaml?download=true" + }, + "mdx_kuielab_b_drums": { + "category": "Ударные", + "id": 608, + "full_name": "MDX-Net Model: UVR-MDX-NET Kuielab B Drums", + "stems": [ + "drums", + "other" + ], + "target_instrument": "drums", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/kuielab_b_drums.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/kuielab_b_drums.yaml?download=true" + }, + "mdx_kuielab_b_other": { + "category": "Прочее", + "id": 609, + "full_name": "MDX-Net Model: UVR-MDX-NET Kuielab B Other", + "stems": [ + "other", + "no_other" + ], + "target_instrument": "other", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/kuielab_b_other.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/kuielab_b_other.yaml?download=true" + }, + "mdx_kuielab_b_vocals": { + "category": "Вокал", + "id": 610, + "full_name": "MDX-Net Model: UVR-MDX-NET Kuielab B Vocals", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/kuielab_b_vocals.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/kuielab_b_vocals.yaml?download=true" + }, + "mdx_reverb_hq_foxjoy": { + "category": "Реверб", + "id": 611, + "full_name": "MDX-Net Model: UVR-MDX-NET Reverb HQ FoxJoy", + "stems": [ + "reverb", + "other" + ], + "target_instrument": "reverb", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/Reverb_HQ_By_FoxJoy.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/Reverb_HQ_By_FoxJoy.yaml?download=true" + }, + "mdx_inst1": { + "category": "Инструментал", + "id": 612, + "full_name": "MDX-Net Model: UVR-MDX-NET Inst 1", + "stems": [ + "instrumental", + "vocals" + ], + "target_instrument": "instrumental", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Inst_1.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Inst_1.yaml?download=true" + }, + "mdx_inst2": { + "category": "Инструментал", + "id": 613, + "full_name": "MDX-Net Model: UVR-MDX-NET Inst 2", + "stems": [ + "instrumental", + "vocals" + ], + "target_instrument": "instrumental", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Inst_2.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Inst_2.yaml?download=true" + }, + "mdx_inst3": { + "category": "Инструментал", + "id": 614, + "full_name": "MDX-Net Model: UVR-MDX-NET Inst 3", + "stems": [ + "instrumental", + "vocals" + ], + "target_instrument": "instrumental", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Inst_3.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Inst_3.yaml?download=true" + }, + "mdx_inst_full_292": { + "category": "Инструментал", + "id": 615, + "full_name": "MDX-Net Model: UVR-MDX-NET Inst Full 292", + "stems": [ + "instrumental", + "vocals" + ], + "target_instrument": "instrumental", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Inst_full_292.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Inst_full_292.yaml?download=true" + }, + "mdx_inst_hq1": { + "category": "Инструментал", + "id": 616, + "full_name": "MDX-Net Model: UVR-MDX-NET Inst HQ 1", + "stems": [ + "instrumental", + "vocals" + ], + "target_instrument": "instrumental", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Inst_HQ_1.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Inst_HQ_1.yaml?download=true" + }, + "mdx_inst_hq2": { + "category": "Инструментал", + "id": 617, + "full_name": "MDX-Net Model: UVR-MDX-NET Inst HQ 2", + "stems": [ + "instrumental", + "vocals" + ], + "target_instrument": "instrumental", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Inst_HQ_2.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Inst_HQ_2.yaml?download=true" + }, + "mdx_inst_hq3": { + "category": "Инструментал", + "id": 618, + "full_name": "MDX-Net Model: UVR-MDX-NET Inst HQ 3", + "stems": [ + "instrumental", + "vocals" + ], + "target_instrument": "instrumental", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Inst_HQ_3.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Inst_HQ_3.yaml?download=true" + }, + "mdx_inst_hq4": { + "category": "Инструментал", + "id": 619, + "full_name": "MDX-Net Model: UVR-MDX-NET Inst HQ 4", + "stems": [ + "instrumental", + "vocals" + ], + "target_instrument": "instrumental", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Inst_HQ_4.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Inst_HQ_4.yaml?download=true" + }, + "mdx_inst_hq5": { + "category": "Инструментал", + "id": 620, + "full_name": "MDX-Net Model: UVR-MDX-NET Inst HQ 5", + "stems": [ + "instrumental", + "vocals" + ], + "target_instrument": "instrumental", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Inst_HQ_5.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Inst_HQ_5.yaml?download=true" + }, + "mdx_inst_main": { + "category": "Инструментал", + "id": 621, + "full_name": "MDX-Net Model: UVR-MDX-NET Inst Main", + "stems": [ + "instrumental", + "vocals" + ], + "target_instrument": "instrumental", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Inst_Main.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Inst_Main.yaml?download=true" + }, + "mdx_vocft": { + "category": "Вокал", + "id": 622, + "full_name": "MDX-Net Model: UVR-MDX-NET Voc FT", + "stems": [ + "vocals", + "instrumental" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET-Voc_FT.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET-Voc_FT.yaml?download=true" + }, + "mdx_crowd_hq1": { + "category": "Звуки толпы", + "id": 623, + "full_name": "MDX-Net Model: UVR-MDX-NET Crowd HQ 1", + "stems": [ + "other", + "crowd" + ], + "target_instrument": "other", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET_Crowd_HQ_1.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET_Crowd_HQ_1.yaml?download=true" + }, + "mdx_inst_187_beta": { + "category": "Инструментал", + "id": 624, + "full_name": "MDX-Net Model: UVR-MDX-NET Inst 187 beta", + "stems": [ + "instrumental", + "vocals" + ], + "target_instrument": "instrumental", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET_Inst_187_beta.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET_Inst_187_beta.yaml?download=true" + }, + "mdx_inst_82_beta": { + "category": "Инструментал", + "id": 625, + "full_name": "MDX-Net Model: UVR-MDX-NET Inst 82 beta", + "stems": [ + "instrumental", + "vocals" + ], + "target_instrument": "instrumental", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET_Inst_82_beta.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET_Inst_82_beta.yaml?download=true" + }, + "mdx_inst_90_beta": { + "category": "Инструментал", + "id": 626, + "full_name": "MDX-Net Model: UVR-MDX-NET Inst 90 beta", + "stems": [ + "instrumental", + "vocals" + ], + "target_instrument": "instrumental", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET_Inst_90_beta.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET_Inst_90_beta.yaml?download=true" + }, + "mdx_main_340": { + "category": "Вокал", + "id": 627, + "full_name": "MDX-Net Model: UVR-MDX-NET Main 340", + "stems": [ + "vocals", + "instrumental" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET_Main_340.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET_Main_340.yaml?download=true" + }, + "mdx_main_390": { + "category": "Вокал", + "id": 628, + "full_name": "MDX-Net Model: UVR-MDX-NET Main 390", + "stems": [ + "vocals", + "instrumental" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET_Main_390.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET_Main_390.yaml?download=true" + }, + "mdx_main_406": { + "category": "Вокал", + "id": 629, + "full_name": "MDX-Net Model: UVR-MDX-NET Main 406", + "stems": [ + "vocals", + "instrumental" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET_Main_406.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET_Main_406.yaml?download=true" + }, + "mdx_main_427": { + "category": "Вокал", + "id": 630, + "full_name": "MDX-Net Model: UVR-MDX-NET Main 427", + "stems": [ + "vocals", + "instrumental" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET_Main_427.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET_Main_427.yaml?download=true" + }, + "mdx_main_438": { + "category": "Вокал", + "id": 631, + "full_name": "MDX-Net Model: UVR-MDX-NET Main 438", + "stems": [ + "vocals", + "instrumental" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR-MDX-NET_Main_438.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR-MDX-NET_Main_438.yaml?download=true" + }, + "mdx_1_9703": { + "category": "Вокал", + "id": 632, + "full_name": "MDX-Net Model: UVR-MDX-NET 1 9703", + "stems": [ + "vocals", + "instrumental" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR_MDXNET_1_9703.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR_MDXNET_1_9703.yaml?download=true" + }, + "mdx_2_9682": { + "category": "Вокал", + "id": 633, + "full_name": "MDX-Net Model: UVR-MDX-NET 2 9682", + "stems": [ + "vocals", + "instrumental" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR_MDXNET_2_9682.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR_MDXNET_2_9682.yaml?download=true" + }, + "mdx_3_9662": { + "category": "Вокал", + "id": 634, + "full_name": "MDX-Net Model: UVR-MDX-NET 3 9662", + "stems": [ + "vocals", + "instrumental" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR_MDXNET_3_9662.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR_MDXNET_3_9662.yaml?download=true" + }, + "mdx_9482": { + "category": "Вокал", + "id": 635, + "full_name": "MDX-Net Model: UVR-MDX-NET 9482", + "stems": [ + "vocals", + "instrumental" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR_MDXNET_9482.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR_MDXNET_9482.yaml?download=true" + }, + "mdx_karaoke1": { + "category": "Караоке", + "id": 636, + "full_name": "MDX-Net Model: UVR-MDX-NET Karaoke 1", + "stems": [ + "vocals", + "other" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR_MDXNET_KARA.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR_MDXNET_KARA.yaml?download=true" + }, + "mdx_karaoke2": { + "category": "Караоке", + "id": 637, + "full_name": "MDX-Net Model: UVR-MDX-NET Karaoke 2", + "stems": [ + "other", + "vocals" + ], + "target_instrument": "other", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR_MDXNET_KARA_2.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR_MDXNET_KARA_2.yaml?download=true" + }, + "mdx_main": { + "category": "Вокал", + "id": 638, + "full_name": "MDX-Net Model: UVR-MDX-NET Main", + "stems": [ + "vocals", + "instrumental" + ], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDXNet/UVR_MDXNET_Main.onnx?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/mdx_configs/UVR_MDXNET_Main.yaml?download=true" + } + }, + "vr": { + "1_hp-uvr": { + "category": "Инструментал", + "id": 500, + "full_name": "VR Arch Single Model v5: 1_HP-UVR", + "stems": [ + "Instrumental", + "Vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/1_HP-UVR.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/1_HP-UVR.yaml?download=true" + }, + "2_hp-uvr": { + "category": "Инструментал", + "id": 501, + "full_name": "VR Arch Single Model v5: 2_HP-UVR", + "stems": [ + "Instrumental", + "Vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/2_HP-UVR.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/2_HP-UVR.yaml?download=true" + }, + "3_hp-vocal-uvr": { + "category": "Вокал", + "id": 502, + "full_name": "VR Arch Single Model v5: 3_HP-Vocal-UVR", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/3_HP-Vocal-UVR.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/3_HP-Vocal-UVR.yaml?download=true" + }, + "4_hp-vocal-uvr": { + "category": "Вокал", + "id": 503, + "full_name": "VR Arch Single Model v5: 4_HP-Vocal-UVR", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/4_HP-Vocal-UVR.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/4_HP-Vocal-UVR.yaml?download=true" + }, + "5_hp-karaoke-uvr": { + "category": "Караоке", + "id": 504, + "full_name": "VR Arch Single Model v5: 5_HP-Karaoke-UVR", + "stems": [ + "Instrumental", + "Vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/5_HP-Karaoke-UVR.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/5_HP-Karaoke-UVR.yaml?download=true" + }, + "6_hp-karaoke-uvr": { + "category": "Караоке", + "id": 505, + "full_name": "VR Arch Single Model v5: 6_HP-Karaoke-UVR", + "stems": [ + "Instrumental", + "Vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/6_HP-Karaoke-UVR.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/6_HP-Karaoke-UVR.yaml?download=true" + }, + "7_hp2-uvr": { + "category": "Инструментал и вокал", + "id": 506, + "full_name": "VR Arch Single Model v5: 7_HP2-UVR", + "stems": [ + "Instrumental", + "Vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/7_HP2-UVR.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/7_HP2-UVR.yaml?download=true" + }, + "8_hp2-uvr": { + "category": "Инструментал и вокал", + "id": 507, + "full_name": "VR Arch Single Model v5: 8_HP2-UVR", + "stems": [ + "Instrumental", + "Vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/8_HP2-UVR.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/8_HP2-UVR.yaml?download=true" + }, + "9_hp2-uvr": { + "category": "Инструментал и вокал", + "id": 508, + "full_name": "VR Arch Single Model v5: 9_HP2-UVR", + "stems": [ + "Instrumental", + "Vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/9_HP2-UVR.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/9_HP2-UVR.yaml?download=true" + }, + "10_sp-uvr-2b-32000-1": { + "category": "Инструментал и вокал", + "id": 509, + "full_name": "VR Arch Single Model v5: 10_SP-UVR-2B-32000-1", + "stems": [ + "Instrumental", + "Vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/10_SP-UVR-2B-32000-1.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/10_SP-UVR-2B-32000-1.yaml?download=true" + }, + "11_sp-uvr-2b-32000-2": { + "category": "Инструментал и вокал", + "id": 510, + "full_name": "VR Arch Single Model v5: 11_SP-UVR-2B-32000-2", + "stems": [ + "Instrumental", + "Vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/11_SP-UVR-2B-32000-2.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/11_SP-UVR-2B-32000-2.yaml?download=true" + }, + "12_sp-uvr-3b-44100": { + "category": "Инструментал и вокал", + "id": 511, + "full_name": "VR Arch Single Model v5: 12_SP-UVR-3B-44100", + "stems": [ + "Instrumental", + "Vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/12_SP-UVR-3B-44100.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/12_SP-UVR-3B-44100.yaml?download=true" + }, + "13_sp-uvr-4b-44100-1": { + "category": "Инструментал и вокал", + "id": 512, + "full_name": "VR Arch Single Model v5: 13_SP-UVR-4B-44100-1", + "stems": [ + "Instrumental", + "Vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/13_SP-UVR-4B-44100-1.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/13_SP-UVR-4B-44100-1.yaml?download=true" + }, + "14_sp-uvr-4b-44100-2": { + "category": "Инструментал и вокал", + "id": 513, + "full_name": "VR Arch Single Model v5: 14_SP-UVR-4B-44100-2", + "stems": [ + "Instrumental", + "Vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/14_SP-UVR-4B-44100-2.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/14_SP-UVR-4B-44100-2.yaml?download=true" + }, + "15_sp-uvr-mid-44100-1": { + "category": "Инструментал и вокал", + "id": 514, + "full_name": "VR Arch Single Model v5: 15_SP-UVR-MID-44100-1", + "stems": [ + "Instrumental", + "Vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/15_SP-UVR-MID-44100-1.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/15_SP-UVR-MID-44100-1.yaml?download=true" + }, + "16_sp-uvr-mid-44100-2": { + "category": "Инструментал и вокал", + "id": 515, + "full_name": "VR Arch Single Model v5: 16_SP-UVR-MID-44100-2", + "stems": [ + "Instrumental", + "Vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/16_SP-UVR-MID-44100-2.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/16_SP-UVR-MID-44100-2.yaml?download=true" + }, + "17_hp-wind_inst-uvr": { + "category": "Деревянные духовые", + "id": 516, + "full_name": "VR Arch Single Model v5: 17_HP-Wind_Inst-UVR", + "stems": [ + "No Woodwinds", + "Woodwinds" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/17_HP-Wind_Inst-UVR.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/17_HP-Wind_Inst-UVR.yaml?download=true" + }, + "uvr-de-echo-aggressive": { + "category": "Эхо", + "id": 517, + "full_name": "VR Arch Single Model v5: UVR-De-Echo-Aggressive by FoxJoy", + "stems": [ + "No Echo", + "Echo" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/UVR-De-Echo-Aggressive.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/UVR-De-Echo-Aggressive.yaml?download=true" + }, + "uvr-de-echo-normal": { + "category": "Эхо", + "id": 518, + "full_name": "VR Arch Single Model v5: UVR-De-Echo-Normal by FoxJoy", + "stems": [ + "No Echo", + "Echo" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/UVR-De-Echo-Normal.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/UVR-De-Echo-Normal.yaml?download=true" + }, + "uvr-deecho-dereverb": { + "category": "Реверб", + "id": 519, + "full_name": "VR Arch Single Model v5: UVR-DeEcho-DeReverb by FoxJoy", + "stems": [ + "No Reverb", + "Reverb" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/UVR-DeEcho-DeReverb.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/UVR-DeEcho-DeReverb.yaml?download=true" + }, + "uvr-denoise-lite": { + "category": "Шум", + "id": 520, + "full_name": "VR Arch Single Model v5: UVR-DeNoise-Lite by FoxJoy", + "stems": [ + "Noise", + "No Noise" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/UVR-DeNoise-Lite.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/UVR-DeNoise-Lite.yaml?download=true" + }, + "uvr-denoise": { + "category": "Шум", + "id": 521, + "full_name": "VR Arch Single Model v5: UVR-DeNoise by FoxJoy", + "stems": [ + "Noise", + "No Noise" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/UVR-DeNoise.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/UVR-DeNoise.yaml?download=true" + }, + "uvr-bve-4b_sn-44100-1": { + "category": "Караоке", + "id": 522, + "full_name": "VR Arch Single Model v5: UVR-BVE-4B_SN-44100", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/UVR-BVE-4B_SN-44100-1.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/UVR-BVE-4B_SN-44100-1.yaml?download=true" + }, + "uvr-bve-v2-4b-sn-44100": { + "category": "Караоке", + "id": 523, + "full_name": "VR Arch Single Model v4: UVR-BVE-v2-4B-SN-44100", + "stems": [ + "Vocals", + "Instrumental" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/RareSirMix/AIModelRehosting/resolve/main/UVR-5-1_4band_v4_ms_fullband_BVE_v2_by_aufr33.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/UVR-BVE-v2-4B-SN-44100.yaml?download=true" + }, + "mgm-v5-karokee-32000-beta1": { + "category": "Караоке", + "id": 524, + "full_name": "VR Arch Single Model v5: MGM-v5-KAROKEE-32000-BETA1", + "stems": [ + "Instrumental", + "Vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://github.com/lucassantilli/UVR-Colab-GUI/releases/download/m5.1/MGM-v5-KAROKEE-32000-BETA1.pth", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/MGM-v5-KAROKEE-32000-BETA1.yaml?download=true" + }, + "mgm-v5-karokee-32000-beta2-agr": { + "category": "Караоке", + "id": 525, + "full_name": "VR Arch Single Model v5: MGM-v5-KAROKEE-32000-BETA2-AGR.pth", + "stems": [ + "Instrumental", + "Vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://github.com/lucassantilli/UVR-Colab-GUI/releases/download/m5.1/MGM-v5-KAROKEE-32000-BETA2-AGR.pth", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/MGM-v5-KAROKEE-32000-BETA2-AGR.yaml?download=true" + }, + "mgm_highend_v4": { + "category": "Инструментал и вокал", + "id": 526, + "full_name": "VR Arch Single Model v4: MGM_HIGHEND_v4", + "stems": [ + "Instrumental", + "Vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/MGM_HIGHEND_v4.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/MGM_HIGHEND_v4.yaml?download=true" + }, + "mgm_lowend_a_v4": { + "category": "Инструментал и вокал", + "id": 527, + "full_name": "VR Arch Single Model v4: MGM_LOWEND_A_v4", + "stems": [ + "Instrumental", + "Vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/MGM_LOWEND_A_v4.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/MGM_LOWEND_A_v4.yaml?download=true" + }, + "mgm_lowend_b_v4": { + "category": "Инструментал и вокал", + "id": 528, + "full_name": "VR Arch Single Model v4: MGM_LOWEND_B_v4", + "stems": [ + "Instrumental", + "Vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/MGM_LOWEND_B_v4.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/MGM_LOWEND_B_v4.yaml?download=true" + }, + "mgm_main_v4": { + "category": "Инструментал и вокал", + "id": 529, + "full_name": "VR Arch Single Model v4: MGM_MAIN_v4", + "stems": [ + "Instrumental", + "Vocals" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/MGM_MAIN_v4.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/MGM_MAIN_v4.yaml?download=true" + }, + "uvr-de-reverb-aufr33-jarredou": { + "category": "Реверб", + "id": 530, + "full_name": "VR Arch Single Model v4: UVR-De-Reverb by aufr33-jarredou", + "stems": [ + "Dry", + "No Dry" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/VR_Arch/UVR-De-Reverb-aufr33-jarredou.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/UVR-De-Reverb-aufr33-jarredou.yaml?download=true" + }, + "uvr-de-breath-sucial-v1": { + "category": "Дыхание", + "id": 531, + "full_name": "VR Arch Single Model v4: UVR-De-Breath v1 by Sucial", + "stems": [ + "Breath", + "No Breath" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Sucial/De-Breathe-Models/resolve/main/UVR_De-Breathe_1band_sr44100_hl1024_Sucial_v1.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/UVR-De-Breath-sucial-v1.yaml?download=true" + }, + "uvr-de-breath-sucial-v2": { + "category": "Дыхание", + "id": 532, + "full_name": "VR Arch Single Model v4: UVR-De-Breath v2 by Sucial", + "stems": [ + "Breath", + "No Breath" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Sucial/De-Breathe-Models/resolve/main/UVR_De-Breathe_1band_sr44100_hl1024_Sucial_v2.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/UVR-De-Breath-sucial-v2.yaml?download=true" + }, + "vr_harmonic_noise_sep": { + "category": "Дыхание", + "id": 533, + "full_name": "VR Arch Single Model v5: Harmonic_Noise_Sep", + "stems": [ + "Noise", + "No Noise" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/Sucial/MSST-WebUI/resolve/main/All_Models/VR_Models/Harmonic_Noise_Separation_yxlllc.pth?download=true", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/vr_configs/VR_Harmonic_Noise_Sep.yaml?download=true" + } + }, + "scnet": { + "scnet_4stem_zfturbo": { + "category": "4 стема", + "id": 400, + "full_name": "SCNet 4 Stems by ZFTurbo", + "stems": [ + "vocals", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.9/SCNet-large_starrytong_fixed.ckpt", + "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.9/config_musdb18_scnet_large_starrytong.yaml" + }, + "scnet_xl_ihf_4stem_zfturbo": { + "category": "4 стема", + "id": 401, + "full_name": "SCNet XL IHF 4 Stems by ZFTurbo", + "stems": [ + "vocals", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.15/model_scnet_ep_36_sdr_10.0891.ckpt", + "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.15/config_musdb18_scnet_xl_more_wide_v5.yaml" + }, + "scnet_xl_4stem_starrytong": { + "category": "4 стема", + "id": 402, + "full_name": "SCNet 4 Stems XL by StarryTong", + "stems": [ + "vocals", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.13/model_scnet_ep_54_sdr_9.8051.ckpt", + "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.13/config_musdb18_scnet_xl.yaml" + }, + "scnet_xl_4stem_zftrubo": { + "category": "4 стема", + "id": 403, + "full_name": "SCNet 4 Stems XL by ZFTurbo", + "stems": [ + "vocals", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.6/scnet_checkpoint_musdb18.ckpt", + "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.6/config_musdb18_scnet.yaml" + }, + "scnet_jazz_4stem_jorisvaneyghen": { + "category": "4 стема", + "id": 404, + "full_name": "SCNet Large Jazz model by Joris Vaneyghen", + "stems": [ + "piano", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/jorisvaneyghen/SCNet/resolve/main/model_jazz_scnet_large.ckpt?download=true", + "config_url": "https://huggingface.co/spaces/jorisvaneyghen/jazz_playalong/resolve/main/configs/config_jazz_scnet_large.yaml?download=true" + }, + "scnet_xl_jazz_4stem_jorisvaneyghen": { + "category": "4 стема", + "id": 405, + "full_name": "SCNet XL Jazz model by Joris Vaneyghen", + "stems": [ + "piano", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/jorisvaneyghen/SCNet/resolve/main/model_jazz_scnet_xl.ckpt?download=true", + "config_url": "https://huggingface.co/spaces/jorisvaneyghen/jazz_playalong/resolve/main/configs/config_jazz_scnet_xl.yaml?download=true" + }, + "scnet_choirsep_exp": { + "category": "Хор", + "id": 420, + "full_name": "SCNet Choirsep by concert.isolations.business@gmail.com", + "stems": [ + "alto", + "bass", + "tenor", + "soprano" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/scnet_choirsep/model_scnet_ep_36_sdr_5.4596.ckpt?download=true", + "config_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/scnet_choirsep/config_scnet_choirsep.yaml?download=true" + } + }, + "htdemucs": { + "demucs4_mvsep_vocals": { + "category": "Вокал", + "id": 2, + "full_name": "HTDemucs4 (MVSep finetuned)", + "stems": [ + "vocals", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.0/model_vocals_htdemucs_sdr_8.78.ckpt", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_vocals_htdemucs.yaml" + }, + "demucs4_4stem": { + "category": "4 стема", + "id": 0, + "full_name": "HTDemucs4", + "stems": [ + "vocals", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/955717e8-8726e21a.th", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_htdemucs.yaml" + }, + "demucs4_6stem": { + "category": "6 стемов", + "id": 1, + "full_name": "HTDemucs4 (6 stems)", + "stems": [ + "vocals", + "drums", + "bass", + "other", + "guitar", + "piano" + ], + "target_instrument": null, + "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/5c90dfd2-34c22ccb.th", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_htdemucs_6stems.yaml" + }, + "demucs3_mmi": { + "category": "4 стема", + "id": 3, + "full_name": "Demucs3 mmi", + "stems": [ + "vocals", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/75fc33f5-1941ce65.th", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_demucs3_mmi.yaml" + }, + "demucs4_ft_bass": { + "category": "Басс", + "id": 4, + "full_name": "HTDemucs4 FT Bass", + "stems": [ + "vocals", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/d12395a8-e57c48e6.th", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_htdemucs.yaml" + }, + "demucs4_ft_drums": { + "category": "Ударные", + "id": 5, + "full_name": "HTDemucs4 FT Drums", + "stems": [ + "vocals", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/f7e0c4bc-ba3fe64a.th", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_htdemucs.yaml" + }, + "demucs4_ft_vocals": { + "category": "Вокал", + "id": 6, + "full_name": "HTDemucs4 FT Vocals", + "stems": [ + "vocals", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/04573f0d-f3cf25b2.th", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_htdemucs.yaml" + }, + "demucs4_ft_other": { + "category": "Прочее", + "id": 7, + "full_name": "HTDemucs4 FT Other", + "stems": [ + "vocals", + "drums", + "bass", + "other" + ], + "target_instrument": null, + "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/92cfc3b6-ef3bcb9c.th", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_htdemucs.yaml" + }, + "demucs_mid_side_wesleyr36": { + "category": "Фантомный центр", + "id": 8, + "full_name": "HTDemucs4 MId-Side by wesleyr36", + "stems": [ + "similarity", + "difference" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/jarredou/HTDemucs_Similarity_Extractor_by_wesleyr36/resolve/main/model_htdemucs_ep_21_sdr_13.6970.ckpt?download=true", + "config_url": "https://huggingface.co/jarredou/HTDemucs_Similarity_Extractor_by_wesleyr36/resolve/main/config_htdemucs_similarity.yaml?download=true" + }, + "demucs4_choirsep": { + "category": "Хор", + "id": 9, + "full_name": "HTDemucs4 Choirsep by concert.isolations.business@gmail.com", + "stems": [ + "alto", + "bass", + "tenor", + "soprano" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/demucs_choirsep/model_htdemucs_ep_94_sdr_5.2474.ckpt?download=true", + "config_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/demucs_choirsep/config_htdemucs_choirsep.yaml?download=true" + } + }, + "bandit": { + "bandit_plus": { + "category": "Кинематограф", + "id": 10, + "full_name": "Bandit Plus: Cinematic Bandit Plus (by kwatcharasupat)", + "stems": [ + "speech", + "music", + "effects" + ], + "target_instrument": null, + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.3/model_bandit_plus_dnr_sdr_11.47.chpt", + "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.3/config_dnr_bandit_bsrnn_multi_mus64.yaml" + } + }, + "bandit_v2": { + "bandit_v2_multi": { + "category": "Кинематограф", + "id": 11, + "full_name": "Bandit v2: Cinematic Bandit v2 Multilang (by kwatcharasupat)", + "stems": [ + "speech", + "music", + "sfx" + ], + "target_instrument": null, + "checkpoint_url": "https://huggingface.co/jarredou/banditv2_state_dicts_only/resolve/main/checkpoint-multi_state_dict.ckpt", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/refs/heads/main/configs/config_dnr_bandit_v2_mus64.yaml" + } + } +} \ No newline at end of file diff --git a/mvsepless/models/bandit/core/__init__.py b/mvsepless/models/bandit/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..662d4075fc46d93e919b5dcdd3c215bcb8f57ef0 --- /dev/null +++ b/mvsepless/models/bandit/core/__init__.py @@ -0,0 +1,691 @@ +import os.path +from collections import defaultdict +from itertools import chain, combinations +from typing import Any, Dict, Iterator, Mapping, Optional, Tuple, Type, TypedDict + +import pytorch_lightning as pl +import torch +import torchaudio as ta +import torchmetrics as tm +from asteroid import losses as asteroid_losses + +# from deepspeed.ops.adam import DeepSpeedCPUAdam +# from geoopt import optim as gooptim +from pytorch_lightning.utilities.types import STEP_OUTPUT +from torch import nn, optim +from torch.optim import lr_scheduler +from torch.optim.lr_scheduler import LRScheduler + +from . import loss, metrics as metrics_, model +from .data._types import BatchedDataDict +from .data.augmentation import BaseAugmentor, StemAugmentor +from .utils import audio as audio_ +from .utils.audio import BaseFader + +# from pandas.io.json._normalize import nested_to_record + +ConfigDict = TypedDict("ConfigDict", {"name": str, "kwargs": Dict[str, Any]}) + + +class SchedulerConfigDict(ConfigDict): + monitor: str + + +OptimizerSchedulerConfigDict = TypedDict( + "OptimizerSchedulerConfigDict", + {"optimizer": ConfigDict, "scheduler": SchedulerConfigDict}, + total=False, +) + + +class LRSchedulerReturnDict(TypedDict, total=False): + scheduler: LRScheduler + monitor: str + + +class ConfigureOptimizerReturnDict(TypedDict, total=False): + optimizer: torch.optim.Optimizer + lr_scheduler: LRSchedulerReturnDict + + +OutputType = Dict[str, Any] +MetricsType = Dict[str, torch.Tensor] + + +def get_optimizer_class(name: str) -> Type[optim.Optimizer]: + + if name == "DeepSpeedCPUAdam": + return DeepSpeedCPUAdam + + for module in [optim, gooptim]: + if name in module.__dict__: + return module.__dict__[name] + + raise NameError + + +def parse_optimizer_config( + config: OptimizerSchedulerConfigDict, parameters: Iterator[nn.Parameter] +) -> ConfigureOptimizerReturnDict: + optim_class = get_optimizer_class(config["optimizer"]["name"]) + optimizer = optim_class(parameters, **config["optimizer"]["kwargs"]) + + optim_dict: ConfigureOptimizerReturnDict = { + "optimizer": optimizer, + } + + if "scheduler" in config: + + lr_scheduler_class_ = config["scheduler"]["name"] + lr_scheduler_class = lr_scheduler.__dict__[lr_scheduler_class_] + lr_scheduler_dict: LRSchedulerReturnDict = { + "scheduler": lr_scheduler_class(optimizer, **config["scheduler"]["kwargs"]) + } + + if lr_scheduler_class_ == "ReduceLROnPlateau": + lr_scheduler_dict["monitor"] = config["scheduler"]["monitor"] + + optim_dict["lr_scheduler"] = lr_scheduler_dict + + return optim_dict + + +def parse_model_config(config: ConfigDict) -> Any: + name = config["name"] + + for module in [model]: + if name in module.__dict__: + return module.__dict__[name](**config["kwargs"]) + + raise NameError + + +_LEGACY_LOSS_NAMES = ["HybridL1Loss"] + + +def _parse_legacy_loss_config(config: ConfigDict) -> nn.Module: + name = config["name"] + + if name == "HybridL1Loss": + return loss.TimeFreqL1Loss(**config["kwargs"]) + + raise NameError + + +def parse_loss_config(config: ConfigDict) -> nn.Module: + name = config["name"] + + if name in _LEGACY_LOSS_NAMES: + return _parse_legacy_loss_config(config) + + for module in [loss, nn.modules.loss, asteroid_losses]: + if name in module.__dict__: + # print(config["kwargs"]) + return module.__dict__[name](**config["kwargs"]) + + raise NameError + + +def get_metric(config: ConfigDict) -> tm.Metric: + name = config["name"] + + for module in [tm, metrics_]: + if name in module.__dict__: + return module.__dict__[name](**config["kwargs"]) + raise NameError + + +def parse_metric_config(config: Dict[str, ConfigDict]) -> tm.MetricCollection: + metrics = {} + + for metric in config: + metrics[metric] = get_metric(config[metric]) + + return tm.MetricCollection(metrics) + + +def parse_fader_config(config: ConfigDict) -> BaseFader: + name = config["name"] + + for module in [audio_]: + if name in module.__dict__: + return module.__dict__[name](**config["kwargs"]) + + raise NameError + + +class LightningSystem(pl.LightningModule): + _VOX_STEMS = ["speech", "vocals"] + _BG_STEMS = ["background", "effects", "mne"] + + def __init__( + self, config: Dict, loss_adjustment: float = 1.0, attach_fader: bool = False + ) -> None: + super().__init__() + self.optimizer_config = config["optimizer"] + self.model = parse_model_config(config["model"]) + self.loss = parse_loss_config(config["loss"]) + self.metrics = nn.ModuleDict( + { + stem: parse_metric_config(config["metrics"]["dev"]) + for stem in self.model.stems + } + ) + + self.metrics.disallow_fsdp = True + + self.test_metrics = nn.ModuleDict( + { + stem: parse_metric_config(config["metrics"]["test"]) + for stem in self.model.stems + } + ) + + self.test_metrics.disallow_fsdp = True + + self.fs = config["model"]["kwargs"]["fs"] + + self.fader_config = config["inference"]["fader"] + if attach_fader: + self.fader = parse_fader_config(config["inference"]["fader"]) + else: + self.fader = None + + self.augmentation: Optional[BaseAugmentor] + if config.get("augmentation", None) is not None: + self.augmentation = StemAugmentor(**config["augmentation"]) + else: + self.augmentation = None + + self.predict_output_path: Optional[str] = None + self.loss_adjustment = loss_adjustment + + self.val_prefix = None + self.test_prefix = None + + def configure_optimizers(self) -> Any: + return parse_optimizer_config( + self.optimizer_config, self.trainer.model.parameters() + ) + + def compute_loss( + self, batch: BatchedDataDict, output: OutputType + ) -> Dict[str, torch.Tensor]: + return {"loss": self.loss(output, batch)} + + def update_metrics( + self, batch: BatchedDataDict, output: OutputType, mode: str + ) -> None: + + if mode == "test": + metrics = self.test_metrics + else: + metrics = self.metrics + + for stem, metric in metrics.items(): + + if stem == "mne:+": + stem = "mne" + + # print(f"matching for {stem}") + if mode == "train": + metric.update( + output["audio"][stem], # .cpu(), + batch["audio"][stem], # .cpu() + ) + else: + if stem not in batch["audio"]: + matched = False + if stem in self._VOX_STEMS: + for bstem in self._VOX_STEMS: + if bstem in batch["audio"]: + batch["audio"][stem] = batch["audio"][bstem] + matched = True + break + elif stem in self._BG_STEMS: + for bstem in self._BG_STEMS: + if bstem in batch["audio"]: + batch["audio"][stem] = batch["audio"][bstem] + matched = True + break + else: + matched = True + + # print(batch["audio"].keys()) + + if matched: + # print(f"matched {stem}!") + if stem == "mne" and "mne" not in output["audio"]: + output["audio"]["mne"] = ( + output["audio"]["music"] + output["audio"]["effects"] + ) + + metric.update( + output["audio"][stem], # .cpu(), + batch["audio"][stem], # .cpu(), + ) + + # print(metric.compute()) + + def compute_metrics(self, mode: str = "dev") -> Dict[str, torch.Tensor]: + + if mode == "test": + metrics = self.test_metrics + else: + metrics = self.metrics + + metric_dict = {} + + for stem, metric in metrics.items(): + md = metric.compute() + metric_dict.update({f"{stem}/{k}": v for k, v in md.items()}) + + self.log_dict(metric_dict, prog_bar=True, logger=False) + + return metric_dict + + def reset_metrics(self, test_mode: bool = False) -> None: + + if test_mode: + metrics = self.test_metrics + else: + metrics = self.metrics + + for _, metric in metrics.items(): + metric.reset() + + def forward(self, batch: BatchedDataDict) -> Any: + batch, output = self.model(batch) + + return batch, output + + def common_step(self, batch: BatchedDataDict, mode: str) -> Any: + batch, output = self.forward(batch) + # print(batch) + # print(output) + loss_dict = self.compute_loss(batch, output) + + with torch.no_grad(): + self.update_metrics(batch, output, mode=mode) + + if mode == "train": + self.log("loss", loss_dict["loss"], prog_bar=True) + + return output, loss_dict + + def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]: + + if self.augmentation is not None: + with torch.no_grad(): + batch = self.augmentation(batch) + + _, loss_dict = self.common_step(batch, mode="train") + + with torch.inference_mode(): + self.log_dict_with_prefix( + loss_dict, "train", batch_size=batch["audio"]["mixture"].shape[0] + ) + + loss_dict["loss"] *= self.loss_adjustment + + return loss_dict + + def on_train_batch_end( + self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int + ) -> None: + + metric_dict = self.compute_metrics() + self.log_dict_with_prefix(metric_dict, "train") + self.reset_metrics() + + def validation_step( + self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0 + ) -> Dict[str, Any]: + + with torch.inference_mode(): + curr_val_prefix = f"val{dataloader_idx}" if dataloader_idx > 0 else "val" + + if curr_val_prefix != self.val_prefix: + # print(f"Switching to validation dataloader {dataloader_idx}") + if self.val_prefix is not None: + self._on_validation_epoch_end() + self.val_prefix = curr_val_prefix + _, loss_dict = self.common_step(batch, mode="val") + + self.log_dict_with_prefix( + loss_dict, + self.val_prefix, + batch_size=batch["audio"]["mixture"].shape[0], + prog_bar=True, + add_dataloader_idx=False, + ) + + return loss_dict + + def on_validation_epoch_end(self) -> None: + self._on_validation_epoch_end() + + def _on_validation_epoch_end(self) -> None: + metric_dict = self.compute_metrics() + self.log_dict_with_prefix( + metric_dict, self.val_prefix, prog_bar=True, add_dataloader_idx=False + ) + # self.logger.save() + # print(self.val_prefix, "Validation metrics:", metric_dict) + self.reset_metrics() + + def old_predtest_step( + self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0 + ) -> Tuple[BatchedDataDict, OutputType]: + + audio_batch = batch["audio"]["mixture"] + track_batch = batch.get("track", ["" for _ in range(len(audio_batch))]) + + output_list_of_dicts = [ + self.fader(audio[None, ...], lambda a: self.test_forward(a, track)) + for audio, track in zip(audio_batch, track_batch) + ] + + output_dict_of_lists = defaultdict(list) + + for output_dict in output_list_of_dicts: + for stem, audio in output_dict.items(): + output_dict_of_lists[stem].append(audio) + + output = { + "audio": { + stem: torch.concat(output_list, dim=0) + for stem, output_list in output_dict_of_lists.items() + } + } + + return batch, output + + def predtest_step( + self, batch: BatchedDataDict, batch_idx: int = -1, dataloader_idx: int = 0 + ) -> Tuple[BatchedDataDict, OutputType]: + + if getattr(self.model, "bypass_fader", False): + batch, output = self.model(batch) + else: + audio_batch = batch["audio"]["mixture"] + output = self.fader( + audio_batch, lambda a: self.test_forward(a, "", batch=batch) + ) + + return batch, output + + def test_forward( + self, audio: torch.Tensor, track: str = "", batch: BatchedDataDict = None + ) -> torch.Tensor: + + if self.fader is None: + self.attach_fader() + + cond = batch.get("condition", None) + + if cond is not None and cond.shape[0] == 1: + cond = cond.repeat(audio.shape[0], 1) + + _, output = self.forward( + { + "audio": {"mixture": audio}, + "track": track, + "condition": cond, + } + ) # TODO: support track properly + + return output["audio"] + + def on_test_epoch_start(self) -> None: + self.attach_fader(force_reattach=True) + + def test_step( + self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0 + ) -> Any: + curr_test_prefix = f"test{dataloader_idx}" + + # print(batch["audio"].keys()) + + if curr_test_prefix != self.test_prefix: + # print(f"Switching to test dataloader {dataloader_idx}") + if self.test_prefix is not None: + self._on_test_epoch_end() + self.test_prefix = curr_test_prefix + + with torch.inference_mode(): + _, output = self.predtest_step(batch, batch_idx, dataloader_idx) + # print(output) + self.update_metrics(batch, output, mode="test") + + return output + + def on_test_epoch_end(self) -> None: + self._on_test_epoch_end() + + def _on_test_epoch_end(self) -> None: + metric_dict = self.compute_metrics(mode="test") + self.log_dict_with_prefix( + metric_dict, self.test_prefix, prog_bar=True, add_dataloader_idx=False + ) + # self.logger.save() + # print(self.test_prefix, "Test metrics:", metric_dict) + self.reset_metrics() + + def predict_step( + self, + batch: BatchedDataDict, + batch_idx: int = 0, + dataloader_idx: int = 0, + include_track_name: Optional[bool] = None, + get_no_vox_combinations: bool = True, + get_residual: bool = False, + treat_batch_as_channels: bool = False, + fs: Optional[int] = None, + ) -> Any: + assert self.predict_output_path is not None + + batch_size = batch["audio"]["mixture"].shape[0] + + if include_track_name is None: + include_track_name = batch_size > 1 + + with torch.inference_mode(): + batch, output = self.predtest_step(batch, batch_idx, dataloader_idx) + print("Pred test finished...") + torch.cuda.empty_cache() + metric_dict = {} + + if get_residual: + mixture = batch["audio"]["mixture"] + extracted = sum([output["audio"][stem] for stem in output["audio"]]) + residual = mixture - extracted + print(extracted.shape, mixture.shape, residual.shape) + + output["audio"]["residual"] = residual + + if get_no_vox_combinations: + no_vox_stems = [ + stem for stem in output["audio"] if stem not in self._VOX_STEMS + ] + no_vox_combinations = chain.from_iterable( + combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1) + ) + + for combination in no_vox_combinations: + combination_ = list(combination) + output["audio"]["+".join(combination_)] = sum( + [output["audio"][stem] for stem in combination_] + ) + + if treat_batch_as_channels: + for stem in output["audio"]: + output["audio"][stem] = output["audio"][stem].reshape( + 1, -1, output["audio"][stem].shape[-1] + ) + batch_size = 1 + + for b in range(batch_size): + print("!!", b) + for stem in output["audio"]: + print(f"Saving audio for {stem} to {self.predict_output_path}") + track_name = batch["track"][b].split("/")[-1] + + if batch.get("audio", {}).get(stem, None) is not None: + self.test_metrics[stem].reset() + metrics = self.test_metrics[stem]( + batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...] + ) + snr = metrics["snr"] + sisnr = metrics["sisnr"] + sdr = metrics["sdr"] + metric_dict[stem] = metrics + print( + track_name, + f"snr={snr:2.2f} dB", + f"sisnr={sisnr:2.2f}", + f"sdr={sdr:2.2f} dB", + ) + filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav" + else: + filename = f"{stem}.wav" + + if include_track_name: + output_dir = os.path.join(self.predict_output_path, track_name) + else: + output_dir = self.predict_output_path + + os.makedirs(output_dir, exist_ok=True) + + if fs is None: + fs = self.fs + + ta.save( + os.path.join(output_dir, filename), + output["audio"][stem][b, ...].cpu(), + fs, + ) + + return metric_dict + + def get_stems( + self, + batch: BatchedDataDict, + batch_idx: int = 0, + dataloader_idx: int = 0, + include_track_name: Optional[bool] = None, + get_no_vox_combinations: bool = True, + get_residual: bool = False, + treat_batch_as_channels: bool = False, + fs: Optional[int] = None, + ) -> Any: + assert self.predict_output_path is not None + + batch_size = batch["audio"]["mixture"].shape[0] + + if include_track_name is None: + include_track_name = batch_size > 1 + + with torch.inference_mode(): + batch, output = self.predtest_step(batch, batch_idx, dataloader_idx) + torch.cuda.empty_cache() + metric_dict = {} + + if get_residual: + mixture = batch["audio"]["mixture"] + extracted = sum([output["audio"][stem] for stem in output["audio"]]) + residual = mixture - extracted + # print(extracted.shape, mixture.shape, residual.shape) + + output["audio"]["residual"] = residual + + if get_no_vox_combinations: + no_vox_stems = [ + stem for stem in output["audio"] if stem not in self._VOX_STEMS + ] + no_vox_combinations = chain.from_iterable( + combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1) + ) + + for combination in no_vox_combinations: + combination_ = list(combination) + output["audio"]["+".join(combination_)] = sum( + [output["audio"][stem] for stem in combination_] + ) + + if treat_batch_as_channels: + for stem in output["audio"]: + output["audio"][stem] = output["audio"][stem].reshape( + 1, -1, output["audio"][stem].shape[-1] + ) + batch_size = 1 + + result = {} + for b in range(batch_size): + for stem in output["audio"]: + track_name = batch["track"][b].split("/")[-1] + + if batch.get("audio", {}).get(stem, None) is not None: + self.test_metrics[stem].reset() + metrics = self.test_metrics[stem]( + batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...] + ) + snr = metrics["snr"] + sisnr = metrics["sisnr"] + sdr = metrics["sdr"] + metric_dict[stem] = metrics + print( + track_name, + f"snr={snr:2.2f} dB", + f"sisnr={sisnr:2.2f}", + f"sdr={sdr:2.2f} dB", + ) + filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav" + else: + filename = f"{stem}.wav" + + if include_track_name: + output_dir = os.path.join(self.predict_output_path, track_name) + else: + output_dir = self.predict_output_path + + os.makedirs(output_dir, exist_ok=True) + + if fs is None: + fs = self.fs + + result[stem] = output["audio"][stem][b, ...].cpu().numpy() + + return result + + def load_state_dict( + self, state_dict: Mapping[str, Any], strict: bool = False + ) -> Any: + + return super().load_state_dict(state_dict, strict=False) + + def set_predict_output_path(self, path: str) -> None: + self.predict_output_path = path + os.makedirs(self.predict_output_path, exist_ok=True) + + self.attach_fader() + + def attach_fader(self, force_reattach=False) -> None: + if self.fader is None or force_reattach: + self.fader = parse_fader_config(self.fader_config) + self.fader.to(self.device) + + def log_dict_with_prefix( + self, + dict_: Dict[str, torch.Tensor], + prefix: str, + batch_size: Optional[int] = None, + **kwargs: Any, + ) -> None: + self.log_dict( + {f"{prefix}/{k}": v for k, v in dict_.items()}, + batch_size=batch_size, + logger=True, + sync_dist=True, + **kwargs, + ) diff --git a/mvsepless/models/bandit/core/data/__init__.py b/mvsepless/models/bandit/core/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9d4d672bd3b6ad90a26e19ee6c26e02ee3be84c --- /dev/null +++ b/mvsepless/models/bandit/core/data/__init__.py @@ -0,0 +1,2 @@ +from .dnr.datamodule import DivideAndRemasterDataModule +from .musdb.datamodule import MUSDB18DataModule diff --git a/mvsepless/models/bandit/core/data/_types.py b/mvsepless/models/bandit/core/data/_types.py new file mode 100644 index 0000000000000000000000000000000000000000..65e4607a558e6b6a65ee68de883b69e282f8fcf4 --- /dev/null +++ b/mvsepless/models/bandit/core/data/_types.py @@ -0,0 +1,17 @@ +from typing import Dict, Sequence, TypedDict + +import torch + +AudioDict = Dict[str, torch.Tensor] + +DataDict = TypedDict("DataDict", {"audio": AudioDict, "track": str}) + +BatchedDataDict = TypedDict( + "BatchedDataDict", {"audio": AudioDict, "track": Sequence[str]} +) + + +class DataDictWithLanguage(TypedDict): + audio: AudioDict + track: str + language: str diff --git a/mvsepless/models/bandit/core/data/augmentation.py b/mvsepless/models/bandit/core/data/augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..cc2ec18ef256cf266a81d7efae9a5fb902ae1d5a --- /dev/null +++ b/mvsepless/models/bandit/core/data/augmentation.py @@ -0,0 +1,102 @@ +from abc import ABC +from typing import Any, Dict, Union + +import torch +import torch_audiomentations as tam +from torch import nn + +from ._types import BatchedDataDict, DataDict + + +class BaseAugmentor(nn.Module, ABC): + def forward( + self, item: Union[DataDict, BatchedDataDict] + ) -> Union[DataDict, BatchedDataDict]: + raise NotImplementedError + + +class StemAugmentor(BaseAugmentor): + def __init__( + self, + audiomentations: Dict[str, Dict[str, Any]], + fix_clipping: bool = True, + scaler_margin: float = 0.5, + apply_both_default_and_common: bool = False, + ) -> None: + super().__init__() + + augmentations = {} + + self.has_default = "[default]" in audiomentations + self.has_common = "[common]" in audiomentations + self.apply_both_default_and_common = apply_both_default_and_common + + for stem in audiomentations: + if audiomentations[stem]["name"] == "Compose": + augmentations[stem] = getattr(tam, audiomentations[stem]["name"])( + [ + getattr(tam, aug["name"])(**aug["kwargs"]) + for aug in audiomentations[stem]["kwargs"]["transforms"] + ], + **audiomentations[stem]["kwargs"]["kwargs"], + ) + else: + augmentations[stem] = getattr(tam, audiomentations[stem]["name"])( + **audiomentations[stem]["kwargs"] + ) + + self.augmentations = nn.ModuleDict(augmentations) + self.fix_clipping = fix_clipping + self.scaler_margin = scaler_margin + + def check_and_fix_clipping( + self, item: Union[DataDict, BatchedDataDict] + ) -> Union[DataDict, BatchedDataDict]: + max_abs = [] + + for stem in item["audio"]: + max_abs.append(item["audio"][stem].abs().max().item()) + + if max(max_abs) > 1.0: + scaler = 1.0 / ( + max(max_abs) + + torch.rand((1,), device=item["audio"]["mixture"].device) + * self.scaler_margin + ) + + for stem in item["audio"]: + item["audio"][stem] *= scaler + + return item + + def forward( + self, item: Union[DataDict, BatchedDataDict] + ) -> Union[DataDict, BatchedDataDict]: + + for stem in item["audio"]: + if stem == "mixture": + continue + + if self.has_common: + item["audio"][stem] = self.augmentations["[common]"]( + item["audio"][stem] + ).samples + + if stem in self.augmentations: + item["audio"][stem] = self.augmentations[stem]( + item["audio"][stem] + ).samples + elif self.has_default: + if not self.has_common or self.apply_both_default_and_common: + item["audio"][stem] = self.augmentations["[default]"]( + item["audio"][stem] + ).samples + + item["audio"]["mixture"] = sum( + [item["audio"][stem] for stem in item["audio"] if stem != "mixture"] + ) # type: ignore[call-overload, assignment] + + if self.fix_clipping: + item = self.check_and_fix_clipping(item) + + return item diff --git a/mvsepless/models/bandit/core/data/augmented.py b/mvsepless/models/bandit/core/data/augmented.py new file mode 100644 index 0000000000000000000000000000000000000000..3c0524409bf99b009605989eba5d5f46f0560f2e --- /dev/null +++ b/mvsepless/models/bandit/core/data/augmented.py @@ -0,0 +1,34 @@ +import warnings +from typing import Dict, Optional, Union + +import torch +from torch import nn +from torch.utils import data + + +class AugmentedDataset(data.Dataset): + def __init__( + self, + dataset: data.Dataset, + augmentation: nn.Module = nn.Identity(), + target_length: Optional[int] = None, + ) -> None: + warnings.warn( + "This class is no longer used. Attach augmentation to " + "the LightningSystem instead.", + DeprecationWarning, + ) + + self.dataset = dataset + self.augmentation = augmentation + + self.ds_length: int = len(dataset) # type: ignore[arg-type] + self.length = target_length if target_length is not None else self.ds_length + + def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str, torch.Tensor]]]: + item = self.dataset[index % self.ds_length] + item = self.augmentation(item) + return item + + def __len__(self) -> int: + return self.length diff --git a/mvsepless/models/bandit/core/data/base.py b/mvsepless/models/bandit/core/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..3c85194f613d2043a26406624da5cb83cb04033a --- /dev/null +++ b/mvsepless/models/bandit/core/data/base.py @@ -0,0 +1,60 @@ +import os +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +import numpy as np +import pedalboard as pb +import torch +import torchaudio as ta +from torch.utils import data + +from ._types import AudioDict, DataDict + + +class BaseSourceSeparationDataset(data.Dataset, ABC): + def __init__( + self, + split: str, + stems: List[str], + files: List[str], + data_path: str, + fs: int, + npy_memmap: bool, + recompute_mixture: bool, + ): + self.split = split + self.stems = stems + self.stems_no_mixture = [s for s in stems if s != "mixture"] + self.files = files + self.data_path = data_path + self.fs = fs + self.npy_memmap = npy_memmap + self.recompute_mixture = recompute_mixture + + @abstractmethod + def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor: + raise NotImplementedError + + def _get_audio(self, stems, identifier: Dict[str, Any]): + audio = {} + for stem in stems: + audio[stem] = self.get_stem(stem=stem, identifier=identifier) + + return audio + + def get_audio(self, identifier: Dict[str, Any]) -> AudioDict: + + if self.recompute_mixture: + audio = self._get_audio(self.stems_no_mixture, identifier=identifier) + audio["mixture"] = self.compute_mixture(audio) + return audio + else: + return self._get_audio(self.stems, identifier=identifier) + + @abstractmethod + def get_identifier(self, index: int) -> Dict[str, Any]: + pass + + def compute_mixture(self, audio: AudioDict) -> torch.Tensor: + + return sum(audio[stem] for stem in audio if stem != "mixture") diff --git a/mvsepless/models/bandit/core/data/dnr/__init__.py b/mvsepless/models/bandit/core/data/dnr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mvsepless/models/bandit/core/data/dnr/datamodule.py b/mvsepless/models/bandit/core/data/dnr/datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..2971d419d433e335668f9e52cf54afda55a48f88 --- /dev/null +++ b/mvsepless/models/bandit/core/data/dnr/datamodule.py @@ -0,0 +1,68 @@ +import os +from typing import Mapping, Optional + +import pytorch_lightning as pl + +from .dataset import ( + DivideAndRemasterDataset, + DivideAndRemasterDeterministicChunkDataset, + DivideAndRemasterRandomChunkDataset, + DivideAndRemasterRandomChunkDatasetWithSpeechReverb, +) + + +def DivideAndRemasterDataModule( + data_root: str = "$DATA_ROOT/DnR/v2", + batch_size: int = 2, + num_workers: int = 8, + train_kwargs: Optional[Mapping] = None, + val_kwargs: Optional[Mapping] = None, + test_kwargs: Optional[Mapping] = None, + datamodule_kwargs: Optional[Mapping] = None, + use_speech_reverb: bool = False, + # augmentor=None +) -> pl.LightningDataModule: + if train_kwargs is None: + train_kwargs = {} + + if val_kwargs is None: + val_kwargs = {} + + if test_kwargs is None: + test_kwargs = {} + + if datamodule_kwargs is None: + datamodule_kwargs = {} + + if num_workers is None: + num_workers = os.cpu_count() + + if num_workers is None: + num_workers = 32 + + num_workers = min(num_workers, 64) + + if use_speech_reverb: + train_cls = DivideAndRemasterRandomChunkDatasetWithSpeechReverb + else: + train_cls = DivideAndRemasterRandomChunkDataset + + train_dataset = train_cls(data_root, "train", **train_kwargs) + + # if augmentor is not None: + # train_dataset = AugmentedDataset(train_dataset, augmentor) + + datamodule = pl.LightningDataModule.from_datasets( + train_dataset=train_dataset, + val_dataset=DivideAndRemasterDeterministicChunkDataset( + data_root, "val", **val_kwargs + ), + test_dataset=DivideAndRemasterDataset(data_root, "test", **test_kwargs), + batch_size=batch_size, + num_workers=num_workers, + **datamodule_kwargs + ) + + datamodule.predict_dataloader = datamodule.test_dataloader # type: ignore[method-assign] + + return datamodule diff --git a/mvsepless/models/bandit/core/data/dnr/dataset.py b/mvsepless/models/bandit/core/data/dnr/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8ae6dba590312c83b61addec26284c8a9979c1a0 --- /dev/null +++ b/mvsepless/models/bandit/core/data/dnr/dataset.py @@ -0,0 +1,366 @@ +import os +from abc import ABC +from typing import Any, Dict, List, Optional + +import numpy as np +import pedalboard as pb +import torch +import torchaudio as ta +from torch.utils import data + +from .._types import AudioDict, DataDict +from ..base import BaseSourceSeparationDataset + + +class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC): + ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"] + STEM_NAME_MAP = { + "mixture": "mix", + "speech": "speech", + "music": "music", + "effects": "sfx", + } + SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"} + + FULL_TRACK_LENGTH_SECOND = 60 + FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100 + + def __init__( + self, + split: str, + stems: List[str], + files: List[str], + data_path: str, + fs: int = 44100, + npy_memmap: bool = True, + recompute_mixture: bool = False, + ) -> None: + super().__init__( + split=split, + stems=stems, + files=files, + data_path=data_path, + fs=fs, + npy_memmap=npy_memmap, + recompute_mixture=recompute_mixture, + ) + + def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor: + + if stem == "mne": + return self.get_stem(stem="music", identifier=identifier) + self.get_stem( + stem="effects", identifier=identifier + ) + + track = identifier["track"] + path = os.path.join(self.data_path, track) + + if self.npy_memmap: + audio = np.load( + os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"), mmap_mode="r" + ) + else: + # noinspection PyUnresolvedReferences + audio, _ = ta.load(os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav")) + + return audio + + def get_identifier(self, index): + return dict(track=self.files[index]) + + def __getitem__(self, index: int) -> DataDict: + identifier = self.get_identifier(index) + audio = self.get_audio(identifier) + + return {"audio": audio, "track": f"{self.split}/{identifier['track']}"} + + +class DivideAndRemasterDataset(DivideAndRemasterBaseDataset): + def __init__( + self, + data_root: str, + split: str, + stems: Optional[List[str]] = None, + fs: int = 44100, + npy_memmap: bool = True, + ) -> None: + + if stems is None: + stems = self.ALLOWED_STEMS + self.stems = stems + + data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split]) + + files = sorted(os.listdir(data_path)) + files = [ + f + for f in files + if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f)) + ] + # pprint(list(enumerate(files))) + if split == "train": + assert len(files) == 3406, len(files) + elif split == "val": + assert len(files) == 487, len(files) + elif split == "test": + assert len(files) == 973, len(files) + + self.n_tracks = len(files) + + super().__init__( + data_path=data_path, + split=split, + stems=stems, + files=files, + fs=fs, + npy_memmap=npy_memmap, + ) + + def __len__(self) -> int: + return self.n_tracks + + +class DivideAndRemasterRandomChunkDataset(DivideAndRemasterBaseDataset): + def __init__( + self, + data_root: str, + split: str, + target_length: int, + chunk_size_second: float, + stems: Optional[List[str]] = None, + fs: int = 44100, + npy_memmap: bool = True, + ) -> None: + + if stems is None: + stems = self.ALLOWED_STEMS + self.stems = stems + + data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split]) + + files = sorted(os.listdir(data_path)) + files = [ + f + for f in files + if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f)) + ] + + if split == "train": + assert len(files) == 3406, len(files) + elif split == "val": + assert len(files) == 487, len(files) + elif split == "test": + assert len(files) == 973, len(files) + + self.n_tracks = len(files) + + self.target_length = target_length + self.chunk_size = int(chunk_size_second * fs) + + super().__init__( + data_path=data_path, + split=split, + stems=stems, + files=files, + fs=fs, + npy_memmap=npy_memmap, + ) + + def __len__(self) -> int: + return self.target_length + + def get_identifier(self, index): + return super().get_identifier(index % self.n_tracks) + + def get_stem( + self, + *, + stem: str, + identifier: Dict[str, Any], + chunk_here: bool = False, + ) -> torch.Tensor: + + stem = super().get_stem(stem=stem, identifier=identifier) + + if chunk_here: + start = np.random.randint( + 0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size + ) + end = start + self.chunk_size + + stem = stem[:, start:end] + + return stem + + def __getitem__(self, index: int) -> DataDict: + identifier = self.get_identifier(index) + # self.index_lock = index + audio = self.get_audio(identifier) + # self.index_lock = None + + start = np.random.randint(0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size) + end = start + self.chunk_size + + audio = {k: v[:, start:end] for k, v in audio.items()} + + return {"audio": audio, "track": f"{self.split}/{identifier['track']}"} + + +class DivideAndRemasterDeterministicChunkDataset(DivideAndRemasterBaseDataset): + def __init__( + self, + data_root: str, + split: str, + chunk_size_second: float, + hop_size_second: float, + stems: Optional[List[str]] = None, + fs: int = 44100, + npy_memmap: bool = True, + ) -> None: + + if stems is None: + stems = self.ALLOWED_STEMS + self.stems = stems + + data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split]) + + files = sorted(os.listdir(data_path)) + files = [ + f + for f in files + if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f)) + ] + # pprint(list(enumerate(files))) + if split == "train": + assert len(files) == 3406, len(files) + elif split == "val": + assert len(files) == 487, len(files) + elif split == "test": + assert len(files) == 973, len(files) + + self.n_tracks = len(files) + + self.chunk_size = int(chunk_size_second * fs) + self.hop_size = int(hop_size_second * fs) + self.n_chunks_per_track = int( + (self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second + ) + + self.length = self.n_tracks * self.n_chunks_per_track + + super().__init__( + data_path=data_path, + split=split, + stems=stems, + files=files, + fs=fs, + npy_memmap=npy_memmap, + ) + + def get_identifier(self, index): + return super().get_identifier(index % self.n_tracks) + + def __len__(self) -> int: + return self.length + + def __getitem__(self, item: int) -> DataDict: + + index = item % self.n_tracks + chunk = item // self.n_tracks + + data_ = super().__getitem__(index) + + audio = data_["audio"] + + start = chunk * self.hop_size + end = start + self.chunk_size + + for stem in self.stems: + data_["audio"][stem] = audio[stem][:, start:end] + + return data_ + + +class DivideAndRemasterRandomChunkDatasetWithSpeechReverb( + DivideAndRemasterRandomChunkDataset +): + def __init__( + self, + data_root: str, + split: str, + target_length: int, + chunk_size_second: float, + stems: Optional[List[str]] = None, + fs: int = 44100, + npy_memmap: bool = True, + ) -> None: + + if stems is None: + stems = self.ALLOWED_STEMS + + stems_no_mixture = [s for s in stems if s != "mixture"] + + super().__init__( + data_root=data_root, + split=split, + target_length=target_length, + chunk_size_second=chunk_size_second, + stems=stems_no_mixture, + fs=fs, + npy_memmap=npy_memmap, + ) + + self.stems = stems + self.stems_no_mixture = stems_no_mixture + + def __getitem__(self, index: int) -> DataDict: + + data_ = super().__getitem__(index) + + dry = data_["audio"]["speech"][:] + n_samples = dry.shape[-1] + + wet_level = np.random.rand() + + speech = pb.Reverb( + room_size=np.random.rand(), + damping=np.random.rand(), + wet_level=wet_level, + dry_level=(1 - wet_level), + width=np.random.rand(), + ).process(dry, self.fs, buffer_size=8192 * 4)[..., :n_samples] + + data_["audio"]["speech"] = speech + + data_["audio"]["mixture"] = sum( + [data_["audio"][s] for s in self.stems_no_mixture] + ) + + return data_ + + def __len__(self) -> int: + return super().__len__() + + +if __name__ == "__main__": + + from pprint import pprint + from tqdm.auto import tqdm + + for split_ in ["train", "val", "test"]: + ds = DivideAndRemasterRandomChunkDatasetWithSpeechReverb( + data_root="$DATA_ROOT/DnR/v2np", + split=split_, + target_length=100, + chunk_size_second=6.0, + ) + + print(split_, len(ds)) + + for track_ in tqdm(ds): # type: ignore + pprint(track_) + track_["audio"] = {k: v.shape for k, v in track_["audio"].items()} + pprint(track_) + # break + + break diff --git a/mvsepless/models/bandit/core/data/dnr/preprocess.py b/mvsepless/models/bandit/core/data/dnr/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..18d68b18fbe963647df1253190625ea639035572 --- /dev/null +++ b/mvsepless/models/bandit/core/data/dnr/preprocess.py @@ -0,0 +1,51 @@ +import glob +import os +from typing import Tuple + +import numpy as np +import torchaudio as ta +from tqdm.contrib.concurrent import process_map + + +def process_one(inputs: Tuple[str, str, int]) -> None: + infile, outfile, target_fs = inputs + + dir = os.path.dirname(outfile) + os.makedirs(dir, exist_ok=True) + + data, fs = ta.load(infile) + + if fs != target_fs: + data = ta.functional.resample( + data, fs, target_fs, resampling_method="sinc_interp_kaiser" + ) + fs = target_fs + + data = data.numpy() + data = data.astype(np.float32) + + if os.path.exists(outfile): + data_ = np.load(outfile) + if np.allclose(data, data_): + return + + np.save(outfile, data) + + +def preprocess(data_path: str, output_path: str, fs: int) -> None: + files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True) + print(files) + outfiles = [ + f.replace(data_path, output_path).replace(".wav", ".npy") for f in files + ] + + os.makedirs(output_path, exist_ok=True) + inputs = list(zip(files, outfiles, [fs] * len(files))) + + process_map(process_one, inputs, chunksize=32) + + +if __name__ == "__main__": + import fire + + fire.Fire() diff --git a/mvsepless/models/bandit/core/data/musdb/__init__.py b/mvsepless/models/bandit/core/data/musdb/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mvsepless/models/bandit/core/data/musdb/datamodule.py b/mvsepless/models/bandit/core/data/musdb/datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba2cb25e3fd7bbaf120b0d1cde90adc26d9d1df --- /dev/null +++ b/mvsepless/models/bandit/core/data/musdb/datamodule.py @@ -0,0 +1,75 @@ +import os.path +from typing import Mapping, Optional + +import pytorch_lightning as pl + +from .dataset import ( + MUSDB18BaseDataset, + MUSDB18FullTrackDataset, + MUSDB18SadDataset, + MUSDB18SadOnTheFlyAugmentedDataset, +) + + +def MUSDB18DataModule( + data_root: str = "$DATA_ROOT/MUSDB18/HQ", + target_stem: str = "vocals", + batch_size: int = 2, + num_workers: int = 8, + train_kwargs: Optional[Mapping] = None, + val_kwargs: Optional[Mapping] = None, + test_kwargs: Optional[Mapping] = None, + datamodule_kwargs: Optional[Mapping] = None, + use_on_the_fly: bool = True, + npy_memmap: bool = True, +) -> pl.LightningDataModule: + if train_kwargs is None: + train_kwargs = {} + + if val_kwargs is None: + val_kwargs = {} + + if test_kwargs is None: + test_kwargs = {} + + if datamodule_kwargs is None: + datamodule_kwargs = {} + + train_dataset: MUSDB18BaseDataset + + if use_on_the_fly: + train_dataset = MUSDB18SadOnTheFlyAugmentedDataset( + data_root=os.path.join(data_root, "saded-np"), + split="train", + target_stem=target_stem, + **train_kwargs + ) + else: + train_dataset = MUSDB18SadDataset( + data_root=os.path.join(data_root, "saded-np"), + split="train", + target_stem=target_stem, + **train_kwargs + ) + + datamodule = pl.LightningDataModule.from_datasets( + train_dataset=train_dataset, + val_dataset=MUSDB18SadDataset( + data_root=os.path.join(data_root, "saded-np"), + split="val", + target_stem=target_stem, + **val_kwargs + ), + test_dataset=MUSDB18FullTrackDataset( + data_root=os.path.join(data_root, "canonical"), split="test", **test_kwargs + ), + batch_size=batch_size, + num_workers=num_workers, + **datamodule_kwargs + ) + + datamodule.predict_dataloader = ( # type: ignore[method-assign] + datamodule.test_dataloader + ) + + return datamodule diff --git a/mvsepless/models/bandit/core/data/musdb/dataset.py b/mvsepless/models/bandit/core/data/musdb/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..82eb33bb32a6e2172ace1bb8ae1a9649041af5bf --- /dev/null +++ b/mvsepless/models/bandit/core/data/musdb/dataset.py @@ -0,0 +1,273 @@ +import os +from abc import ABC +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torchaudio as ta +from torch.utils import data + +from .._types import AudioDict, DataDict +from ..base import BaseSourceSeparationDataset + + +class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC): + + ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"] + + def __init__( + self, + split: str, + stems: List[str], + files: List[str], + data_path: str, + fs: int = 44100, + npy_memmap=False, + ) -> None: + super().__init__( + split=split, + stems=stems, + files=files, + data_path=data_path, + fs=fs, + npy_memmap=npy_memmap, + recompute_mixture=False, + ) + + def get_stem(self, *, stem: str, identifier) -> torch.Tensor: + track = identifier["track"] + path = os.path.join(self.data_path, track) + # noinspection PyUnresolvedReferences + + if self.npy_memmap: + audio = np.load(os.path.join(path, f"{stem}.wav.npy"), mmap_mode="r") + else: + audio, _ = ta.load(os.path.join(path, f"{stem}.wav")) + + return audio + + def get_identifier(self, index): + return dict(track=self.files[index]) + + def __getitem__(self, index: int) -> DataDict: + identifier = self.get_identifier(index) + audio = self.get_audio(identifier) + + return {"audio": audio, "track": f"{self.split}/{identifier['track']}"} + + +class MUSDB18FullTrackDataset(MUSDB18BaseDataset): + + N_TRAIN_TRACKS = 100 + N_TEST_TRACKS = 50 + VALIDATION_FILES = [ + "Actions - One Minute Smile", + "Clara Berry And Wooldog - Waltz For My Victims", + "Johnny Lokke - Promises & Lies", + "Patrick Talbot - A Reason To Leave", + "Triviul - Angelsaint", + "Alexander Ross - Goodbye Bolero", + "Fergessen - Nos Palpitants", + "Leaf - Summerghost", + "Skelpolu - Human Mistakes", + "Young Griffo - Pennies", + "ANiMAL - Rockshow", + "James May - On The Line", + "Meaxic - Take A Step", + "Traffic Experiment - Sirens", + ] + + def __init__( + self, data_root: str, split: str, stems: Optional[List[str]] = None + ) -> None: + + if stems is None: + stems = self.ALLOWED_STEMS + self.stems = stems + + if split == "test": + subset = "test" + elif split in ["train", "val"]: + subset = "train" + else: + raise NameError + + data_path = os.path.join(data_root, subset) + + files = sorted(os.listdir(data_path)) + files = [f for f in files if not f.startswith(".")] + # pprint(list(enumerate(files))) + if subset == "train": + assert len(files) == 100, len(files) + if split == "train": + files = [f for f in files if f not in self.VALIDATION_FILES] + assert len(files) == 100 - len(self.VALIDATION_FILES) + else: + files = [f for f in files if f in self.VALIDATION_FILES] + assert len(files) == len(self.VALIDATION_FILES) + else: + split = "test" + assert len(files) == 50 + + self.n_tracks = len(files) + + super().__init__(data_path=data_path, split=split, stems=stems, files=files) + + def __len__(self) -> int: + return self.n_tracks + + +class MUSDB18SadDataset(MUSDB18BaseDataset): + def __init__( + self, + data_root: str, + split: str, + target_stem: str, + stems: Optional[List[str]] = None, + target_length: Optional[int] = None, + npy_memmap=False, + ) -> None: + + if stems is None: + stems = self.ALLOWED_STEMS + + data_path = os.path.join(data_root, target_stem, split) + + files = sorted(os.listdir(data_path)) + files = [f for f in files if not f.startswith(".")] + + super().__init__( + data_path=data_path, + split=split, + stems=stems, + files=files, + npy_memmap=npy_memmap, + ) + self.n_segments = len(files) + self.target_stem = target_stem + self.target_length = ( + target_length if target_length is not None else self.n_segments + ) + + def __len__(self) -> int: + return self.target_length + + def __getitem__(self, index: int) -> DataDict: + + index = index % self.n_segments + + return super().__getitem__(index) + + def get_identifier(self, index): + return super().get_identifier(index % self.n_segments) + + +class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset): + def __init__( + self, + data_root: str, + split: str, + target_stem: str, + stems: Optional[List[str]] = None, + target_length: int = 20000, + apply_probability: Optional[float] = None, + chunk_size_second: float = 3.0, + random_scale_range_db: Tuple[float, float] = (-10, 10), + drop_probability: float = 0.1, + rescale: bool = True, + ) -> None: + super().__init__(data_root, split, target_stem, stems) + + if apply_probability is None: + apply_probability = (target_length - self.n_segments) / target_length + + self.apply_probability = apply_probability + self.drop_probability = drop_probability + self.chunk_size_second = chunk_size_second + self.random_scale_range_db = random_scale_range_db + self.rescale = rescale + + self.chunk_size_sample = int(self.chunk_size_second * self.fs) + self.target_length = target_length + + def __len__(self) -> int: + return self.target_length + + def __getitem__(self, index: int) -> DataDict: + + index = index % self.n_segments + + # if np.random.rand() > self.apply_probability: + # return super().__getitem__(index) + + audio = {} + identifier = self.get_identifier(index) + + # assert self.target_stem in self.stems_no_mixture + for stem in self.stems_no_mixture: + if stem == self.target_stem: + identifier_ = identifier + else: + if np.random.rand() < self.apply_probability: + index_ = np.random.randint(self.n_segments) + identifier_ = self.get_identifier(index_) + else: + identifier_ = identifier + + audio[stem] = self.get_stem(stem=stem, identifier=identifier_) + + # if stem == self.target_stem: + + if self.chunk_size_sample < audio[stem].shape[-1]: + chunk_start = np.random.randint( + audio[stem].shape[-1] - self.chunk_size_sample + ) + else: + chunk_start = 0 + + if np.random.rand() < self.drop_probability: + # db_scale = "-inf" + linear_scale = 0.0 + else: + db_scale = np.random.uniform(*self.random_scale_range_db) + linear_scale = np.power(10, db_scale / 20) + # db_scale = f"{db_scale:+2.1f}" + # print(linear_scale) + audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample] = ( + linear_scale + * audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample] + ) + + audio["mixture"] = self.compute_mixture(audio) + + if self.rescale: + max_abs_val = max( + [torch.max(torch.abs(audio[stem])) for stem in self.stems] + ) # type: ignore[type-var] + if max_abs_val > 1: + audio = {k: v / max_abs_val for k, v in audio.items()} + + track = identifier["track"] + + return {"audio": audio, "track": f"{self.split}/{track}"} + + +# if __name__ == "__main__": +# +# from pprint import pprint +# from tqdm.auto import tqdm +# +# for split_ in ["train", "val", "test"]: +# ds = MUSDB18SadOnTheFlyAugmentedDataset( +# data_root="$DATA_ROOT/MUSDB18/HQ/saded", +# split=split_, +# target_stem="vocals" +# ) +# +# print(split_, len(ds)) +# +# for track_ in tqdm(ds): +# track_["audio"] = { +# k: v.shape for k, v in track_["audio"].items() +# } +# pprint(track_) diff --git a/mvsepless/models/bandit/core/data/musdb/preprocess.py b/mvsepless/models/bandit/core/data/musdb/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..ad95caeb1ebcbac261ca10cd333896f54bad927b --- /dev/null +++ b/mvsepless/models/bandit/core/data/musdb/preprocess.py @@ -0,0 +1,226 @@ +import glob +import os + +import numpy as np +import torch +import torchaudio as ta +from torch import nn +from torch.nn import functional as F +from tqdm.contrib.concurrent import process_map + +from .._types import DataDict +from .dataset import MUSDB18FullTrackDataset +import pyloudnorm as pyln + + +class SourceActivityDetector(nn.Module): + def __init__( + self, + analysis_stem: str, + output_path: str, + fs: int = 44100, + segment_length_second: float = 6.0, + hop_length_second: float = 3.0, + n_chunks: int = 10, + chunk_epsilon: float = 1e-5, + energy_threshold_quantile: float = 0.15, + segment_epsilon: float = 1e-3, + salient_proportion_threshold: float = 0.5, + target_lufs: float = -24, + ) -> None: + super().__init__() + + self.fs = fs + self.segment_length = int(segment_length_second * self.fs) + self.hop_length = int(hop_length_second * self.fs) + self.n_chunks = n_chunks + assert self.segment_length % self.n_chunks == 0 + self.chunk_size = self.segment_length // self.n_chunks + self.chunk_epsilon = chunk_epsilon + self.energy_threshold_quantile = energy_threshold_quantile + self.segment_epsilon = segment_epsilon + self.salient_proportion_threshold = salient_proportion_threshold + self.analysis_stem = analysis_stem + + self.meter = pyln.Meter(self.fs) + self.target_lufs = target_lufs + + self.output_path = output_path + + def forward(self, data: DataDict) -> None: + + stem_ = self.analysis_stem if (self.analysis_stem != "none") else "mixture" + + x = data["audio"][stem_] + + xnp = x.numpy() + loudness = self.meter.integrated_loudness(xnp.T) + + for stem in data["audio"]: + s = data["audio"][stem] + s = pyln.normalize.loudness(s.numpy().T, loudness, self.target_lufs).T + s = torch.as_tensor(s) + data["audio"][stem] = s + + if x.ndim == 3: + assert x.shape[0] == 1 + x = x[0] + + n_chan, n_samples = x.shape + + n_segments = ( + int(np.ceil((n_samples - self.segment_length) / self.hop_length)) + 1 + ) + + segments = torch.zeros((n_segments, n_chan, self.segment_length)) + for i in range(n_segments): + start = i * self.hop_length + end = start + self.segment_length + end = min(end, n_samples) + + xseg = x[:, start:end] + + if end - start < self.segment_length: + xseg = F.pad( + xseg, pad=(0, self.segment_length - (end - start)), value=torch.nan + ) + + segments[i, :, :] = xseg + + chunks = segments.reshape((n_segments, n_chan, self.n_chunks, self.chunk_size)) + + if self.analysis_stem != "none": + chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3)) + chunk_energies = torch.nan_to_num(chunk_energies, nan=0) + chunk_energies[chunk_energies == 0] = self.chunk_epsilon + + energy_threshold = torch.nanquantile( + chunk_energies, q=self.energy_threshold_quantile + ) + + if energy_threshold < self.segment_epsilon: + energy_threshold = self.segment_epsilon # type: ignore[assignment] + + chunks_above_threshold = chunk_energies > energy_threshold + n_chunks_above_threshold = torch.mean( + chunks_above_threshold.to(torch.float), dim=-1 + ) + + segment_above_threshold = ( + n_chunks_above_threshold > self.salient_proportion_threshold + ) + + if torch.sum(segment_above_threshold) == 0: + return + + else: + segment_above_threshold = torch.ones((n_segments,)) + + for i in range(n_segments): + if not segment_above_threshold[i]: + continue + + outpath = os.path.join( + self.output_path, + self.analysis_stem, + f"{data['track']} - {self.analysis_stem}{i:03d}", + ) + os.makedirs(outpath, exist_ok=True) + + for stem in data["audio"]: + if stem == self.analysis_stem: + segment = torch.nan_to_num(segments[i, :, :], nan=0) + else: + start = i * self.hop_length + end = start + self.segment_length + end = min(n_samples, end) + + segment = data["audio"][stem][:, start:end] + + if end - start < self.segment_length: + segment = F.pad( + segment, (0, self.segment_length - (end - start)) + ) + + assert segment.shape[-1] == self.segment_length, segment.shape + + # ta.save(os.path.join(outpath, f"{stem}.wav"), segment, self.fs) + + np.save(os.path.join(outpath, f"{stem}.wav"), segment) + + +def preprocess( + analysis_stem: str, + output_path: str = "/data/MUSDB18/HQ/saded-np", + fs: int = 44100, + segment_length_second: float = 6.0, + hop_length_second: float = 3.0, + n_chunks: int = 10, + chunk_epsilon: float = 1e-5, + energy_threshold_quantile: float = 0.15, + segment_epsilon: float = 1e-3, + salient_proportion_threshold: float = 0.5, +) -> None: + + sad = SourceActivityDetector( + analysis_stem=analysis_stem, + output_path=output_path, + fs=fs, + segment_length_second=segment_length_second, + hop_length_second=hop_length_second, + n_chunks=n_chunks, + chunk_epsilon=chunk_epsilon, + energy_threshold_quantile=energy_threshold_quantile, + segment_epsilon=segment_epsilon, + salient_proportion_threshold=salient_proportion_threshold, + ) + + for split in ["train", "val", "test"]: + ds = MUSDB18FullTrackDataset( + data_root="/data/MUSDB18/HQ/canonical", + split=split, + ) + + tracks = [] + for i, track in enumerate(tqdm(ds, total=len(ds))): + if i % 32 == 0 and tracks: + process_map(sad, tracks, max_workers=8) + tracks = [] + tracks.append(track) + process_map(sad, tracks, max_workers=8) + + +def loudness_norm_one(inputs): + infile, outfile, target_lufs = inputs + + audio, fs = ta.load(infile) + audio = audio.mean(dim=0, keepdim=True).numpy().T + + meter = pyln.Meter(fs) + loudness = meter.integrated_loudness(audio) + audio = pyln.normalize.loudness(audio, loudness, target_lufs) + + os.makedirs(os.path.dirname(outfile), exist_ok=True) + np.save(outfile, audio.T) + + +def loudness_norm( + data_path: str, + # output_path: str, + target_lufs=-17.0, +): + files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True) + + outfiles = [f.replace(".wav", ".npy").replace("saded", "saded-np") for f in files] + + files = [(f, o, target_lufs) for f, o in zip(files, outfiles)] + + process_map(loudness_norm_one, files, chunksize=2) + + +if __name__ == "__main__": + + from tqdm.auto import tqdm + import fire + + fire.Fire() diff --git a/mvsepless/models/bandit/core/data/musdb/validation.yaml b/mvsepless/models/bandit/core/data/musdb/validation.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2f8752478d285d1d13d5e842225af1de95cae57a --- /dev/null +++ b/mvsepless/models/bandit/core/data/musdb/validation.yaml @@ -0,0 +1,15 @@ +validation: + - 'Actions - One Minute Smile' + - 'Clara Berry And Wooldog - Waltz For My Victims' + - 'Johnny Lokke - Promises & Lies' + - 'Patrick Talbot - A Reason To Leave' + - 'Triviul - Angelsaint' + - 'Alexander Ross - Goodbye Bolero' + - 'Fergessen - Nos Palpitants' + - 'Leaf - Summerghost' + - 'Skelpolu - Human Mistakes' + - 'Young Griffo - Pennies' + - 'ANiMAL - Rockshow' + - 'James May - On The Line' + - 'Meaxic - Take A Step' + - 'Traffic Experiment - Sirens' \ No newline at end of file diff --git a/mvsepless/models/bandit/core/loss/__init__.py b/mvsepless/models/bandit/core/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..993be521fa7ab8f06a2a012beabdb9fdd6cd0a80 --- /dev/null +++ b/mvsepless/models/bandit/core/loss/__init__.py @@ -0,0 +1,8 @@ +from ._multistem import MultiStemWrapperFromConfig +from ._timefreq import ( + ReImL1Loss, + ReImL2Loss, + TimeFreqL1Loss, + TimeFreqL2Loss, + TimeFreqSignalNoisePNormRatioLoss, +) diff --git a/mvsepless/models/bandit/core/loss/_complex.py b/mvsepless/models/bandit/core/loss/_complex.py new file mode 100644 index 0000000000000000000000000000000000000000..68c82f204709d07cba013f1582ca985bcf66dde6 --- /dev/null +++ b/mvsepless/models/bandit/core/loss/_complex.py @@ -0,0 +1,27 @@ +from typing import Any + +import torch +from torch import nn +from torch.nn.modules import loss as _loss +from torch.nn.modules.loss import _Loss + + +class ReImLossWrapper(_Loss): + def __init__(self, module: _Loss) -> None: + super().__init__() + self.module = module + + def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + return self.module(torch.view_as_real(preds), torch.view_as_real(target)) + + +class ReImL1Loss(ReImLossWrapper): + def __init__(self, **kwargs: Any) -> None: + l1_loss = _loss.L1Loss(**kwargs) + super().__init__(module=(l1_loss)) + + +class ReImL2Loss(ReImLossWrapper): + def __init__(self, **kwargs: Any) -> None: + l2_loss = _loss.MSELoss(**kwargs) + super().__init__(module=(l2_loss)) diff --git a/mvsepless/models/bandit/core/loss/_multistem.py b/mvsepless/models/bandit/core/loss/_multistem.py new file mode 100644 index 0000000000000000000000000000000000000000..e9c4a4f776a318b4450cc98c820a057983e2f9a3 --- /dev/null +++ b/mvsepless/models/bandit/core/loss/_multistem.py @@ -0,0 +1,43 @@ +from typing import Any, Dict + +import torch +from asteroid import losses as asteroid_losses +from torch import nn +from torch.nn.modules.loss import _Loss + +from . import snr + + +def parse_loss(name: str, kwargs: Dict[str, Any]) -> _Loss: + + for module in [nn.modules.loss, snr, asteroid_losses, asteroid_losses.sdr]: + if name in module.__dict__: + return module.__dict__[name](**kwargs) + + raise NameError + + +class MultiStemWrapper(_Loss): + def __init__(self, module: _Loss, modality: str = "audio") -> None: + super().__init__() + self.loss = module + self.modality = modality + + def forward( + self, + preds: Dict[str, Dict[str, torch.Tensor]], + target: Dict[str, Dict[str, torch.Tensor]], + ) -> torch.Tensor: + loss = { + stem: self.loss(preds[self.modality][stem], target[self.modality][stem]) + for stem in preds[self.modality] + if stem in target[self.modality] + } + + return sum(list(loss.values())) + + +class MultiStemWrapperFromConfig(MultiStemWrapper): + def __init__(self, name: str, kwargs: Any, modality: str = "audio") -> None: + loss = parse_loss(name, kwargs) + super().__init__(module=loss, modality=modality) diff --git a/mvsepless/models/bandit/core/loss/_timefreq.py b/mvsepless/models/bandit/core/loss/_timefreq.py new file mode 100644 index 0000000000000000000000000000000000000000..edf848e05650b7ce833f6d642a8a53896cd2d8fd --- /dev/null +++ b/mvsepless/models/bandit/core/loss/_timefreq.py @@ -0,0 +1,95 @@ +from typing import Any, Dict, Optional + +import torch +from torch import nn +from torch.nn.modules.loss import _Loss + +from ._multistem import MultiStemWrapper +from ._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper +from .snr import SignalNoisePNormRatio + + +class TimeFreqWrapper(_Loss): + def __init__( + self, + time_module: _Loss, + freq_module: Optional[_Loss] = None, + time_weight: float = 1.0, + freq_weight: float = 1.0, + multistem: bool = True, + ) -> None: + super().__init__() + + if freq_module is None: + freq_module = time_module + + if multistem: + time_module = MultiStemWrapper(time_module, modality="audio") + freq_module = MultiStemWrapper(freq_module, modality="spectrogram") + + self.time_module = time_module + self.freq_module = freq_module + + self.time_weight = time_weight + self.freq_weight = freq_weight + + # TODO: add better type hints + def forward(self, preds: Any, target: Any) -> torch.Tensor: + + return self.time_weight * self.time_module( + preds, target + ) + self.freq_weight * self.freq_module(preds, target) + + +class TimeFreqL1Loss(TimeFreqWrapper): + def __init__( + self, + time_weight: float = 1.0, + freq_weight: float = 1.0, + tkwargs: Optional[Dict[str, Any]] = None, + fkwargs: Optional[Dict[str, Any]] = None, + multistem: bool = True, + ) -> None: + if tkwargs is None: + tkwargs = {} + if fkwargs is None: + fkwargs = {} + time_module = nn.L1Loss(**tkwargs) + freq_module = ReImL1Loss(**fkwargs) + super().__init__(time_module, freq_module, time_weight, freq_weight, multistem) + + +class TimeFreqL2Loss(TimeFreqWrapper): + def __init__( + self, + time_weight: float = 1.0, + freq_weight: float = 1.0, + tkwargs: Optional[Dict[str, Any]] = None, + fkwargs: Optional[Dict[str, Any]] = None, + multistem: bool = True, + ) -> None: + if tkwargs is None: + tkwargs = {} + if fkwargs is None: + fkwargs = {} + time_module = nn.MSELoss(**tkwargs) + freq_module = ReImL2Loss(**fkwargs) + super().__init__(time_module, freq_module, time_weight, freq_weight, multistem) + + +class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper): + def __init__( + self, + time_weight: float = 1.0, + freq_weight: float = 1.0, + tkwargs: Optional[Dict[str, Any]] = None, + fkwargs: Optional[Dict[str, Any]] = None, + multistem: bool = True, + ) -> None: + if tkwargs is None: + tkwargs = {} + if fkwargs is None: + fkwargs = {} + time_module = SignalNoisePNormRatio(**tkwargs) + freq_module = SignalNoisePNormRatio(**fkwargs) + super().__init__(time_module, freq_module, time_weight, freq_weight, multistem) diff --git a/mvsepless/models/bandit/core/loss/snr.py b/mvsepless/models/bandit/core/loss/snr.py new file mode 100644 index 0000000000000000000000000000000000000000..8d712a525027417198c7072ef571166ef0c02afa --- /dev/null +++ b/mvsepless/models/bandit/core/loss/snr.py @@ -0,0 +1,139 @@ +import torch +from torch.nn.modules.loss import _Loss +from torch.nn import functional as F + + +class SignalNoisePNormRatio(_Loss): + def __init__( + self, + p: float = 1.0, + scale_invariant: bool = False, + zero_mean: bool = False, + take_log: bool = True, + reduction: str = "mean", + EPS: float = 1e-3, + ) -> None: + assert reduction != "sum", NotImplementedError + super().__init__(reduction=reduction) + assert not zero_mean + + self.p = p + + self.EPS = EPS + self.take_log = take_log + + self.scale_invariant = scale_invariant + + def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + + target_ = target + if self.scale_invariant: + ndim = target.ndim + dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True) + s_target_energy = torch.sum( + target * torch.conj(target), dim=-1, keepdim=True + ) + + if ndim > 2: + dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True) + s_target_energy = torch.sum( + s_target_energy, dim=list(range(1, ndim)), keepdim=True + ) + + target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8) + target = target_ * target_scaler + + if torch.is_complex(est_target): + est_target = torch.view_as_real(est_target) + target = torch.view_as_real(target) + + batch_size = est_target.shape[0] + est_target = est_target.reshape(batch_size, -1) + target = target.reshape(batch_size, -1) + # target_ = target_.reshape(batch_size, -1) + + if self.p == 1: + e_error = torch.abs(est_target - target).mean(dim=-1) + e_target = torch.abs(target).mean(dim=-1) + elif self.p == 2: + e_error = torch.square(est_target - target).mean(dim=-1) + e_target = torch.square(target).mean(dim=-1) + else: + raise NotImplementedError + + if self.take_log: + loss = 10 * ( + torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS) + ) + else: + loss = (e_error + self.EPS) / (e_target + self.EPS) + + if self.reduction == "mean": + loss = loss.mean() + elif self.reduction == "sum": + loss = loss.sum() + + return loss + + +class MultichannelSingleSrcNegSDR(_Loss): + def __init__( + self, + sdr_type: str, + p: float = 2.0, + zero_mean: bool = True, + take_log: bool = True, + reduction: str = "mean", + EPS: float = 1e-8, + ) -> None: + assert reduction != "sum", NotImplementedError + super().__init__(reduction=reduction) + + assert sdr_type in ["snr", "sisdr", "sdsdr"] + self.sdr_type = sdr_type + self.zero_mean = zero_mean + self.take_log = take_log + self.EPS = 1e-8 + + self.p = p + + def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + if target.size() != est_target.size() or target.ndim != 3: + raise TypeError( + f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead" + ) + # Step 1. Zero-mean norm + if self.zero_mean: + mean_source = torch.mean(target, dim=[1, 2], keepdim=True) + mean_estimate = torch.mean(est_target, dim=[1, 2], keepdim=True) + target = target - mean_source + est_target = est_target - mean_estimate + # Step 2. Pair-wise SI-SDR. + if self.sdr_type in ["sisdr", "sdsdr"]: + # [batch, 1] + dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True) + # [batch, 1] + s_target_energy = torch.sum(target**2, dim=[1, 2], keepdim=True) + self.EPS + # [batch, time] + scaled_target = dot * target / s_target_energy + else: + # [batch, time] + scaled_target = target + if self.sdr_type in ["sdsdr", "snr"]: + e_noise = est_target - target + else: + e_noise = est_target - scaled_target + # [batch] + + if self.p == 2.0: + losses = torch.sum(scaled_target**2, dim=[1, 2]) / ( + torch.sum(e_noise**2, dim=[1, 2]) + self.EPS + ) + else: + losses = torch.norm(scaled_target, p=self.p, dim=[1, 2]) / ( + torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS + ) + if self.take_log: + losses = 10 * torch.log10(losses + self.EPS) + losses = losses.mean() if self.reduction == "mean" else losses + return -losses diff --git a/mvsepless/models/bandit/core/metrics/__init__.py b/mvsepless/models/bandit/core/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c638b4df585ad6c3c6490d9e67b7fc197f0d06f4 --- /dev/null +++ b/mvsepless/models/bandit/core/metrics/__init__.py @@ -0,0 +1,9 @@ +from .snr import ( + ChunkMedianScaleInvariantSignalDistortionRatio, + ChunkMedianScaleInvariantSignalNoiseRatio, + ChunkMedianSignalDistortionRatio, + ChunkMedianSignalNoiseRatio, + SafeSignalDistortionRatio, +) + +# from .mushra import EstimatedMushraScore diff --git a/mvsepless/models/bandit/core/metrics/_squim.py b/mvsepless/models/bandit/core/metrics/_squim.py new file mode 100644 index 0000000000000000000000000000000000000000..71c993a2b6cb3da36849c2f87ef7bb7443a9095c --- /dev/null +++ b/mvsepless/models/bandit/core/metrics/_squim.py @@ -0,0 +1,443 @@ +from dataclasses import dataclass + +from torchaudio._internal import load_state_dict_from_url + +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def transform_wb_pesq_range(x: float) -> float: + """The metric defined by ITU-T P.862 is often called 'PESQ score', which is defined + for narrow-band signals and has a value range of [-0.5, 4.5] exactly. Here, we use the metric + defined by ITU-T P.862.2, commonly known as 'wide-band PESQ' and will be referred to as "PESQ score". + + Args: + x (float): Narrow-band PESQ score. + + Returns: + (float): Wide-band PESQ score. + """ + return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224)) + + +PESQRange: Tuple[float, float] = ( + 1.0, # P.862.2 uses a different input filter than P.862, and the lower bound of + # the raw score is not -0.5 anymore. It's hard to figure out the true lower bound. + # We are using 1.0 as a reasonable approximation. + transform_wb_pesq_range(4.5), +) + + +class RangeSigmoid(nn.Module): + def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None: + super(RangeSigmoid, self).__init__() + assert isinstance(val_range, tuple) and len(val_range) == 2 + self.val_range: Tuple[float, float] = val_range + self.sigmoid: nn.modules.Module = nn.Sigmoid() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = ( + self.sigmoid(x) * (self.val_range[1] - self.val_range[0]) + + self.val_range[0] + ) + return out + + +class Encoder(nn.Module): + """Encoder module that transform 1D waveform to 2D representations. + + Args: + feat_dim (int, optional): The feature dimension after Encoder module. (Default: 512) + win_len (int, optional): kernel size in the Conv1D layer. (Default: 32) + """ + + def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None: + super(Encoder, self).__init__() + + self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply waveforms to convolutional layer and ReLU layer. + + Args: + x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`. + + Returns: + (torch,Tensor): Feature Tensor with dimensions `(batch, channel, frame)`. + """ + out = x.unsqueeze(dim=1) + out = F.relu(self.conv1d(out)) + return out + + +class SingleRNN(nn.Module): + def __init__( + self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0 + ) -> None: + super(SingleRNN, self).__init__() + + self.rnn_type = rnn_type + self.input_size = input_size + self.hidden_size = hidden_size + + self.rnn: nn.modules.Module = getattr(nn, rnn_type)( + input_size, + hidden_size, + 1, + dropout=dropout, + batch_first=True, + bidirectional=True, + ) + + self.proj = nn.Linear(hidden_size * 2, input_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # input shape: batch, seq, dim + out, _ = self.rnn(x) + out = self.proj(out) + return out + + +class DPRNN(nn.Module): + """*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`. + + Args: + feat_dim (int, optional): The feature dimension after Encoder module. (Default: 64) + hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128) + num_blocks (int, optional): Number of DPRNN layers. (Default: 6) + rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM") + d_model (int, optional): The number of expected features in the input. (Default: 256) + chunk_size (int, optional): Chunk size of input for DPRNN. (Default: 100) + chunk_stride (int, optional): Stride of chunk input for DPRNN. (Default: 50) + """ + + def __init__( + self, + feat_dim: int = 64, + hidden_dim: int = 128, + num_blocks: int = 6, + rnn_type: str = "LSTM", + d_model: int = 256, + chunk_size: int = 100, + chunk_stride: int = 50, + ) -> None: + super(DPRNN, self).__init__() + + self.num_blocks = num_blocks + + self.row_rnn = nn.ModuleList([]) + self.col_rnn = nn.ModuleList([]) + self.row_norm = nn.ModuleList([]) + self.col_norm = nn.ModuleList([]) + for _ in range(num_blocks): + self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim)) + self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim)) + self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8)) + self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8)) + self.conv = nn.Sequential( + nn.Conv2d(feat_dim, d_model, 1), + nn.PReLU(), + ) + self.chunk_size = chunk_size + self.chunk_stride = chunk_stride + + def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]: + # input shape: (B, N, T) + seq_len = x.shape[-1] + + rest = ( + self.chunk_size + - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size + ) + out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride]) + + return out, rest + + def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]: + out, rest = self.pad_chunk(x) + batch_size, feat_dim, seq_len = out.shape + + segments1 = ( + out[:, :, : -self.chunk_stride] + .contiguous() + .view(batch_size, feat_dim, -1, self.chunk_size) + ) + segments2 = ( + out[:, :, self.chunk_stride :] + .contiguous() + .view(batch_size, feat_dim, -1, self.chunk_size) + ) + out = torch.cat([segments1, segments2], dim=3) + out = ( + out.view(batch_size, feat_dim, -1, self.chunk_size) + .transpose(2, 3) + .contiguous() + ) + + return out, rest + + def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor: + batch_size, dim, _, _ = x.shape + out = ( + x.transpose(2, 3) + .contiguous() + .view(batch_size, dim, -1, self.chunk_size * 2) + ) + out1 = ( + out[:, :, :, : self.chunk_size] + .contiguous() + .view(batch_size, dim, -1)[:, :, self.chunk_stride :] + ) + out2 = ( + out[:, :, :, self.chunk_size :] + .contiguous() + .view(batch_size, dim, -1)[:, :, : -self.chunk_stride] + ) + out = out1 + out2 + if rest > 0: + out = out[:, :, :-rest] + out = out.contiguous() + return out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, rest = self.chunking(x) + batch_size, _, dim1, dim2 = x.shape + out = x + for row_rnn, row_norm, col_rnn, col_norm in zip( + self.row_rnn, self.row_norm, self.col_rnn, self.col_norm + ): + row_in = ( + out.permute(0, 3, 2, 1) + .contiguous() + .view(batch_size * dim2, dim1, -1) + .contiguous() + ) + row_out = row_rnn(row_in) + row_out = ( + row_out.view(batch_size, dim2, dim1, -1) + .permute(0, 3, 2, 1) + .contiguous() + ) + row_out = row_norm(row_out) + out = out + row_out + + col_in = ( + out.permute(0, 2, 3, 1) + .contiguous() + .view(batch_size * dim1, dim2, -1) + .contiguous() + ) + col_out = col_rnn(col_in) + col_out = ( + col_out.view(batch_size, dim1, dim2, -1) + .permute(0, 3, 1, 2) + .contiguous() + ) + col_out = col_norm(col_out) + out = out + col_out + out = self.conv(out) + out = self.merging(out, rest) + out = out.transpose(1, 2).contiguous() + return out + + +class AutoPool(nn.Module): + def __init__(self, pool_dim: int = 1) -> None: + super(AutoPool, self).__init__() + self.pool_dim: int = pool_dim + self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim) + self.register_parameter("alpha", nn.Parameter(torch.ones(1))) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + weight = self.softmax(torch.mul(x, self.alpha)) + out = torch.sum(torch.mul(x, weight), dim=self.pool_dim) + return out + + +class SquimObjective(nn.Module): + """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores + for speech enhancement (e.g., STOI, PESQ, and SI-SDR). + + Args: + encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation. + dprnn (torch.nn.Module): DPRNN module to model sequential feature. + branches (torch.nn.ModuleList): Transformer branches in which each branch estimate one objective metirc score. + """ + + def __init__( + self, + encoder: nn.Module, + dprnn: nn.Module, + branches: nn.ModuleList, + ): + super(SquimObjective, self).__init__() + self.encoder = encoder + self.dprnn = dprnn + self.branches = branches + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + """ + Args: + x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`. + + Returns: + List(torch.Tensor): List of score Tenosrs. Each Tensor is with dimension `(batch,)`. + """ + if x.ndim != 2: + raise ValueError( + f"The input must be a 2D Tensor. Found dimension {x.ndim}." + ) + x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20) + out = self.encoder(x) + out = self.dprnn(out) + scores = [] + for branch in self.branches: + scores.append(branch(out).squeeze(dim=1)) + return scores + + +def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module: + """Create branch module after DPRNN model for predicting metric score. + + Args: + d_model (int): The number of expected features in the input. + nhead (int): Number of heads in the multi-head attention model. + metric (str): The metric name to predict. + + Returns: + (nn.Module): Returned module to predict corresponding metric score. + """ + layer1 = nn.TransformerEncoderLayer( + d_model, nhead, d_model * 4, dropout=0.0, batch_first=True + ) + layer2 = AutoPool() + if metric == "stoi": + layer3 = nn.Sequential( + nn.Linear(d_model, d_model), + nn.PReLU(), + nn.Linear(d_model, 1), + RangeSigmoid(), + ) + elif metric == "pesq": + layer3 = nn.Sequential( + nn.Linear(d_model, d_model), + nn.PReLU(), + nn.Linear(d_model, 1), + RangeSigmoid(val_range=PESQRange), + ) + else: + layer3: nn.modules.Module = nn.Sequential( + nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1) + ) + return nn.Sequential(layer1, layer2, layer3) + + +def squim_objective_model( + feat_dim: int, + win_len: int, + d_model: int, + nhead: int, + hidden_dim: int, + num_blocks: int, + rnn_type: str, + chunk_size: int, + chunk_stride: Optional[int] = None, +) -> SquimObjective: + """Build a custome :class:`torchaudio.prototype.models.SquimObjective` model. + + Args: + feat_dim (int, optional): The feature dimension after Encoder module. + win_len (int): Kernel size in the Encoder module. + d_model (int): The number of expected features in the input. + nhead (int): Number of heads in the multi-head attention model. + hidden_dim (int): Hidden dimension in the RNN layer of DPRNN. + num_blocks (int): Number of DPRNN layers. + rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. + chunk_size (int): Chunk size of input for DPRNN. + chunk_stride (int or None, optional): Stride of chunk input for DPRNN. + """ + if chunk_stride is None: + chunk_stride = chunk_size // 2 + encoder = Encoder(feat_dim, win_len) + dprnn = DPRNN( + feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride + ) + branches = nn.ModuleList( + [ + _create_branch(d_model, nhead, "stoi"), + _create_branch(d_model, nhead, "pesq"), + _create_branch(d_model, nhead, "sisdr"), + ] + ) + return SquimObjective(encoder, dprnn, branches) + + +def squim_objective_base() -> SquimObjective: + """Build :class:`torchaudio.prototype.models.SquimObjective` model with default arguments.""" + return squim_objective_model( + feat_dim=256, + win_len=64, + d_model=256, + nhead=4, + hidden_dim=256, + num_blocks=2, + rnn_type="LSTM", + chunk_size=71, + ) + + +@dataclass +class SquimObjectiveBundle: + + _path: str + _sample_rate: float + + def _get_state_dict(self, dl_kwargs): + url = f"https://download.pytorch.org/torchaudio/models/{self._path}" + dl_kwargs = {} if dl_kwargs is None else dl_kwargs + state_dict = load_state_dict_from_url(url, **dl_kwargs) + return state_dict + + def get_model(self, *, dl_kwargs=None) -> SquimObjective: + """Construct the SquimObjective model, and load the pretrained weight. + + The weight file is downloaded from the internet and cached with + :func:`torch.hub.load_state_dict_from_url` + + Args: + dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. + + Returns: + Variation of :py:class:`~torchaudio.models.SquimObjective`. + """ + model = squim_objective_base() + model.load_state_dict(self._get_state_dict(dl_kwargs)) + model.eval() + return model + + @property + def sample_rate(self): + """Sample rate of the audio that the model is trained on. + + :type: float + """ + return self._sample_rate + + +SQUIM_OBJECTIVE = SquimObjectiveBundle( + "squim_objective_dns2020.pth", + _sample_rate=16000, +) +SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in + :cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`. + + The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`. + The weights are under `Creative Commons Attribution 4.0 International License + `__. + + Please refer to :py:class:`SquimObjectiveBundle` for usage instructions. + """ diff --git a/mvsepless/models/bandit/core/metrics/snr.py b/mvsepless/models/bandit/core/metrics/snr.py new file mode 100644 index 0000000000000000000000000000000000000000..6b7a168756204ddd8044c7c74e31c15dc1643aa1 --- /dev/null +++ b/mvsepless/models/bandit/core/metrics/snr.py @@ -0,0 +1,127 @@ +from typing import Any, Callable + +import numpy as np +import torch +import torchmetrics as tm +from torch._C import _LinAlgError +from torchmetrics import functional as tmF + + +class SafeSignalDistortionRatio(tm.SignalDistortionRatio): + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def update(self, *args, **kwargs) -> Any: + try: + super().update(*args, **kwargs) + except: + pass + + def compute(self) -> Any: + if self.total == 0: + return torch.tensor(torch.nan) + return super().compute() + + +class BaseChunkMedianSignalRatio(tm.Metric): + def __init__( + self, + func: Callable, + window_size: int, + hop_size: int = None, + zero_mean: bool = False, + ) -> None: + super().__init__() + + # self.zero_mean = zero_mean + self.func = func + self.window_size = window_size + if hop_size is None: + hop_size = window_size + self.hop_size = hop_size + + self.add_state("sum_snr", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: + + n_samples = target.shape[-1] + + n_chunks = int(np.ceil((n_samples - self.window_size) / self.hop_size) + 1) + + snr_chunk = [] + + for i in range(n_chunks): + start = i * self.hop_size + + if n_samples - start < self.window_size: + continue + + end = start + self.window_size + + try: + chunk_snr = self.func(preds[..., start:end], target[..., start:end]) + + # print(preds.shape, chunk_snr.shape) + + if torch.all(torch.isfinite(chunk_snr)): + snr_chunk.append(chunk_snr) + except _LinAlgError: + pass + + snr_chunk = torch.stack(snr_chunk, dim=-1) + snr_batch, _ = torch.nanmedian(snr_chunk, dim=-1) + + self.sum_snr += snr_batch.sum() + self.total += snr_batch.numel() + + def compute(self) -> Any: + return self.sum_snr / self.total + + +class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio): + def __init__( + self, window_size: int, hop_size: int = None, zero_mean: bool = False + ) -> None: + super().__init__( + func=tmF.signal_noise_ratio, + window_size=window_size, + hop_size=hop_size, + zero_mean=zero_mean, + ) + + +class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio): + def __init__( + self, window_size: int, hop_size: int = None, zero_mean: bool = False + ) -> None: + super().__init__( + func=tmF.scale_invariant_signal_noise_ratio, + window_size=window_size, + hop_size=hop_size, + zero_mean=zero_mean, + ) + + +class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio): + def __init__( + self, window_size: int, hop_size: int = None, zero_mean: bool = False + ) -> None: + super().__init__( + func=tmF.signal_distortion_ratio, + window_size=window_size, + hop_size=hop_size, + zero_mean=zero_mean, + ) + + +class ChunkMedianScaleInvariantSignalDistortionRatio(BaseChunkMedianSignalRatio): + def __init__( + self, window_size: int, hop_size: int = None, zero_mean: bool = False + ) -> None: + super().__init__( + func=tmF.scale_invariant_signal_distortion_ratio, + window_size=window_size, + hop_size=hop_size, + zero_mean=zero_mean, + ) diff --git a/mvsepless/models/bandit/core/model/__init__.py b/mvsepless/models/bandit/core/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..54ac48eb69d6f844ba5b73b213eae4cfab157cac --- /dev/null +++ b/mvsepless/models/bandit/core/model/__init__.py @@ -0,0 +1,3 @@ +from .bsrnn.wrapper import ( + MultiMaskMultiSourceBandSplitRNNSimple, +) diff --git a/mvsepless/models/bandit/core/model/_spectral.py b/mvsepless/models/bandit/core/model/_spectral.py new file mode 100644 index 0000000000000000000000000000000000000000..6af5cbd0dcb6ed0a4babd6b8554184d91c406655 --- /dev/null +++ b/mvsepless/models/bandit/core/model/_spectral.py @@ -0,0 +1,54 @@ +from typing import Dict, Optional + +import torch +import torchaudio as ta +from torch import nn + + +class _SpectralComponent(nn.Module): + def __init__( + self, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + **kwargs, + ) -> None: + super().__init__() + + assert power is None + + window_fn = torch.__dict__[window_fn] + + self.stft = ta.transforms.Spectrogram( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + pad_mode=pad_mode, + pad=0, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + normalized=normalized, + center=center, + onesided=onesided, + ) + + self.istft = ta.transforms.InverseSpectrogram( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + pad_mode=pad_mode, + pad=0, + window_fn=window_fn, + wkwargs=wkwargs, + normalized=normalized, + center=center, + onesided=onesided, + ) diff --git a/mvsepless/models/bandit/core/model/bsrnn/__init__.py b/mvsepless/models/bandit/core/model/bsrnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4736e9dec4f2e731ebf34bd51bf4ce1af728b269 --- /dev/null +++ b/mvsepless/models/bandit/core/model/bsrnn/__init__.py @@ -0,0 +1,23 @@ +from abc import ABC +from typing import Iterable, Mapping, Union + +from torch import nn + +from .bandsplit import BandSplitModule +from .tfmodel import ( + SeqBandModellingModule, + TransformerTimeFreqModule, +) + + +class BandsplitCoreBase(nn.Module, ABC): + band_split: nn.Module + tf_model: nn.Module + mask_estim: Union[nn.Module, Mapping[str, nn.Module], Iterable[nn.Module]] + + def __init__(self) -> None: + super().__init__() + + @staticmethod + def mask(x, m): + return x * m diff --git a/mvsepless/models/bandit/core/model/bsrnn/bandsplit.py b/mvsepless/models/bandit/core/model/bsrnn/bandsplit.py new file mode 100644 index 0000000000000000000000000000000000000000..171c8800e5b45fa6bb8dff800d6ba61be8328375 --- /dev/null +++ b/mvsepless/models/bandit/core/model/bsrnn/bandsplit.py @@ -0,0 +1,135 @@ +from typing import List, Tuple + +import torch +from torch import nn + +from .utils import ( + band_widths_from_specs, + check_no_gap, + check_no_overlap, + check_nonzero_bandwidth, +) + + +class NormFC(nn.Module): + def __init__( + self, + emb_dim: int, + bandwidth: int, + in_channel: int, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + ) -> None: + super().__init__() + + self.treat_channel_as_feature = treat_channel_as_feature + + if normalize_channel_independently: + raise NotImplementedError + + reim = 2 + + self.norm = nn.LayerNorm(in_channel * bandwidth * reim) + + fc_in = bandwidth * reim + + if treat_channel_as_feature: + fc_in *= in_channel + else: + assert emb_dim % in_channel == 0 + emb_dim = emb_dim // in_channel + + self.fc = nn.Linear(fc_in, emb_dim) + + def forward(self, xb): + # xb = (batch, n_time, in_chan, reim * band_width) + + batch, n_time, in_chan, ribw = xb.shape + xb = self.norm(xb.reshape(batch, n_time, in_chan * ribw)) + # (batch, n_time, in_chan * reim * band_width) + + if not self.treat_channel_as_feature: + xb = xb.reshape(batch, n_time, in_chan, ribw) + # (batch, n_time, in_chan, reim * band_width) + + zb = self.fc(xb) + # (batch, n_time, emb_dim) + # OR + # (batch, n_time, in_chan, emb_dim_per_chan) + + if not self.treat_channel_as_feature: + batch, n_time, in_chan, emb_dim_per_chan = zb.shape + # (batch, n_time, in_chan, emb_dim_per_chan) + zb = zb.reshape((batch, n_time, in_chan * emb_dim_per_chan)) + + return zb # (batch, n_time, emb_dim) + + +class BandSplitModule(nn.Module): + def __init__( + self, + band_specs: List[Tuple[float, float]], + emb_dim: int, + in_channel: int, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + ) -> None: + super().__init__() + + check_nonzero_bandwidth(band_specs) + + if require_no_gap: + check_no_gap(band_specs) + + if require_no_overlap: + check_no_overlap(band_specs) + + self.band_specs = band_specs + # list of [fstart, fend) in index. + # Note that fend is exclusive. + self.band_widths = band_widths_from_specs(band_specs) + self.n_bands = len(band_specs) + self.emb_dim = emb_dim + + self.norm_fc_modules = nn.ModuleList( + [ # type: ignore + ( + NormFC( + emb_dim=emb_dim, + bandwidth=bw, + in_channel=in_channel, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + ) + ) + for bw in self.band_widths + ] + ) + + def forward(self, x: torch.Tensor): + # x = complex spectrogram (batch, in_chan, n_freq, n_time) + + batch, in_chan, _, n_time = x.shape + + z = torch.zeros( + size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device + ) + + xr = torch.view_as_real(x) # batch, in_chan, n_freq, n_time, 2 + xr = torch.permute(xr, (0, 3, 1, 4, 2)) # batch, n_time, in_chan, 2, n_freq + batch, n_time, in_chan, reim, band_width = xr.shape + for i, nfm in enumerate(self.norm_fc_modules): + # print(f"bandsplit/band{i:02d}") + fstart, fend = self.band_specs[i] + xb = xr[..., fstart:fend] + # (batch, n_time, in_chan, reim, band_width) + xb = torch.reshape(xb, (batch, n_time, in_chan, -1)) + # (batch, n_time, in_chan, reim * band_width) + # z.append(nfm(xb)) # (batch, n_time, emb_dim) + z[:, i, :, :] = nfm(xb.contiguous()) + + # z = torch.stack(z, dim=1) + + return z diff --git a/mvsepless/models/bandit/core/model/bsrnn/core.py b/mvsepless/models/bandit/core/model/bsrnn/core.py new file mode 100644 index 0000000000000000000000000000000000000000..412b37dc6994ce18f227021c5458c12e06e600a5 --- /dev/null +++ b/mvsepless/models/bandit/core/model/bsrnn/core.py @@ -0,0 +1,651 @@ +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn +from torch.nn import functional as F + +from . import BandsplitCoreBase +from .bandsplit import BandSplitModule +from .maskestim import ( + MaskEstimationModule, + OverlappingMaskEstimationModule, +) +from .tfmodel import ( + ConvolutionalTimeFreqModule, + SeqBandModellingModule, + TransformerTimeFreqModule, +) + + +class MultiMaskBandSplitCoreBase(BandsplitCoreBase): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, cond=None, compute_residual: bool = True): + # x = complex spectrogram (batch, in_chan, n_freq, n_time) + # print(x.shape) + batch, in_chan, n_freq, n_time = x.shape + x = torch.reshape(x, (-1, 1, n_freq, n_time)) + + z = self.band_split(x) # (batch, emb_dim, n_band, n_time) + + # if torch.any(torch.isnan(z)): + # raise ValueError("z nan") + + # print(z) + q = self.tf_model(z) # (batch, emb_dim, n_band, n_time) + # print(q) + + # if torch.any(torch.isnan(q)): + # raise ValueError("q nan") + + out = {} + + for stem, mem in self.mask_estim.items(): + m = mem(q, cond=cond) + + # if torch.any(torch.isnan(m)): + # raise ValueError("m nan", stem) + + s = self.mask(x, m) + s = torch.reshape(s, (batch, in_chan, n_freq, n_time)) + out[stem] = s + + return {"spectrogram": out} + + def instantiate_mask_estim( + self, + in_channel: int, + stems: List[str], + band_specs: List[Tuple[float, float]], + emb_dim: int, + mlp_dim: int, + cond_dim: int, + hidden_activation: str, + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + overlapping_band: bool = False, + freq_weights: Optional[List[torch.Tensor]] = None, + n_freq: Optional[int] = None, + use_freq_weights: bool = True, + mult_add_mask: bool = False, + ): + if hidden_activation_kwargs is None: + hidden_activation_kwargs = {} + + if "mne:+" in stems: + stems = [s for s in stems if s != "mne:+"] + + if overlapping_band: + assert freq_weights is not None + assert n_freq is not None + + if mult_add_mask: + + self.mask_estim = nn.ModuleDict( + { + stem: MultAddMaskEstimationModule( + band_specs=band_specs, + freq_weights=freq_weights, + n_freq=n_freq, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + use_freq_weights=use_freq_weights, + ) + for stem in stems + } + ) + else: + self.mask_estim = nn.ModuleDict( + { + stem: OverlappingMaskEstimationModule( + band_specs=band_specs, + freq_weights=freq_weights, + n_freq=n_freq, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + use_freq_weights=use_freq_weights, + ) + for stem in stems + } + ) + else: + self.mask_estim = nn.ModuleDict( + { + stem: MaskEstimationModule( + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + for stem in stems + } + ) + + def instantiate_bandsplit( + self, + in_channel: int, + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + emb_dim: int = 128, + ): + self.band_split = BandSplitModule( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + ) + + +class SingleMaskBandsplitCoreBase(BandsplitCoreBase): + def __init__(self, **kwargs) -> None: + super().__init__() + + def forward(self, x): + # x = complex spectrogram (batch, in_chan, n_freq, n_time) + z = self.band_split(x) # (batch, emb_dim, n_band, n_time) + q = self.tf_model(z) # (batch, emb_dim, n_band, n_time) + m = self.mask_estim(q) # (batch, in_chan, n_freq, n_time) + + s = self.mask(x, m) + + return s + + +class SingleMaskBandsplitCoreRNN( + SingleMaskBandsplitCoreBase, +): + def __init__( + self, + in_channel: int, + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + ) -> None: + super().__init__() + self.band_split = BandSplitModule( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + ) + self.tf_model = SeqBandModellingModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ) + self.mask_estim = MaskEstimationModule( + in_channel=in_channel, + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + + +class SingleMaskBandsplitCoreTransformer( + SingleMaskBandsplitCoreBase, +): + def __init__( + self, + in_channel: int, + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + tf_dropout: float = 0.0, + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + ) -> None: + super().__init__() + self.band_split = BandSplitModule( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + ) + self.tf_model = TransformerTimeFreqModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + dropout=tf_dropout, + ) + self.mask_estim = MaskEstimationModule( + in_channel=in_channel, + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + + +class MultiSourceMultiMaskBandSplitCoreRNN(MultiMaskBandSplitCoreBase): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + cond_dim: int = 0, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + overlapping_band: bool = False, + freq_weights: Optional[List[torch.Tensor]] = None, + n_freq: Optional[int] = None, + use_freq_weights: bool = True, + mult_add_mask: bool = False, + ) -> None: + + super().__init__() + self.instantiate_bandsplit( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + ) + + self.tf_model = SeqBandModellingModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ) + + self.mult_add_mask = mult_add_mask + + self.instantiate_mask_estim( + in_channel=in_channel, + stems=stems, + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=overlapping_band, + freq_weights=freq_weights, + n_freq=n_freq, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask, + ) + + @staticmethod + def _mult_add_mask(x, m): + + assert m.ndim == 5 + + mm = m[..., 0] + am = m[..., 1] + + # print(mm.shape, am.shape, x.shape, m.shape) + + return x * mm + am + + def mask(self, x, m): + if self.mult_add_mask: + + return self._mult_add_mask(x, m) + else: + return super().mask(x, m) + + +class MultiSourceMultiMaskBandSplitCoreTransformer( + MultiMaskBandSplitCoreBase, +): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + tf_dropout: float = 0.0, + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + overlapping_band: bool = False, + freq_weights: Optional[List[torch.Tensor]] = None, + n_freq: Optional[int] = None, + use_freq_weights: bool = True, + rnn_type: str = "LSTM", + cond_dim: int = 0, + mult_add_mask: bool = False, + ) -> None: + super().__init__() + self.instantiate_bandsplit( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + ) + self.tf_model = TransformerTimeFreqModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + dropout=tf_dropout, + ) + + self.instantiate_mask_estim( + in_channel=in_channel, + stems=stems, + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=overlapping_band, + freq_weights=freq_weights, + n_freq=n_freq, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask, + ) + + +class MultiSourceMultiMaskBandSplitCoreConv( + MultiMaskBandSplitCoreBase, +): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + tf_dropout: float = 0.0, + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + overlapping_band: bool = False, + freq_weights: Optional[List[torch.Tensor]] = None, + n_freq: Optional[int] = None, + use_freq_weights: bool = True, + rnn_type: str = "LSTM", + cond_dim: int = 0, + mult_add_mask: bool = False, + ) -> None: + super().__init__() + self.instantiate_bandsplit( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + ) + self.tf_model = ConvolutionalTimeFreqModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + dropout=tf_dropout, + ) + + self.instantiate_mask_estim( + in_channel=in_channel, + stems=stems, + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=overlapping_band, + freq_weights=freq_weights, + n_freq=n_freq, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask, + ) + + +class PatchingMaskBandsplitCoreBase(MultiMaskBandSplitCoreBase): + def __init__(self) -> None: + super().__init__() + + def mask(self, x, m): + # x.shape = (batch, n_channel, n_freq, n_time) + # m.shape = (kernel_freq, kernel_time, batch, n_channel, n_freq, n_time) + + _, n_channel, kernel_freq, kernel_time, n_freq, n_time = m.shape + padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2) + + xf = F.unfold( + x, + kernel_size=(kernel_freq, kernel_time), + padding=padding, + stride=(1, 1), + ) + + xf = xf.view( + -1, + n_channel, + kernel_freq, + kernel_time, + n_freq, + n_time, + ) + + sf = xf * m + + sf = sf.view( + -1, + n_channel * kernel_freq * kernel_time, + n_freq * n_time, + ) + + s = F.fold( + sf, + output_size=(n_freq, n_time), + kernel_size=(kernel_freq, kernel_time), + padding=padding, + stride=(1, 1), + ).view( + -1, + n_channel, + n_freq, + n_time, + ) + + return s + + def old_mask(self, x, m): + # x.shape = (batch, n_channel, n_freq, n_time) + # m.shape = (kernel_freq, kernel_time, batch, n_channel, n_freq, n_time) + + s = torch.zeros_like(x) + + _, n_channel, n_freq, n_time = x.shape + kernel_freq, kernel_time, _, _, _, _ = m.shape + + # print(x.shape, m.shape) + + kernel_freq_half = (kernel_freq - 1) // 2 + kernel_time_half = (kernel_time - 1) // 2 + + for ifreq in range(kernel_freq): + for itime in range(kernel_time): + df, dt = kernel_freq_half - ifreq, kernel_time_half - itime + x = x.roll(shifts=(df, dt), dims=(2, 3)) + + # if `df` > 0: + # x[:, :, :df, :] = 0 + # elif `df` < 0: + # x[:, :, df:, :] = 0 + + # if `dt` > 0: + # x[:, :, :, :dt] = 0 + # elif `dt` < 0: + # x[:, :, :, dt:] = 0 + + fslice = slice(max(0, df), min(n_freq, n_freq + df)) + tslice = slice(max(0, dt), min(n_time, n_time + dt)) + + s[:, :, fslice, tslice] += ( + x[:, :, fslice, tslice] * m[ifreq, itime, :, :, fslice, tslice] + ) + + return s + + +class MultiSourceMultiPatchingMaskBandSplitCoreRNN(PatchingMaskBandsplitCoreBase): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: List[Tuple[float, float]], + mask_kernel_freq: int, + mask_kernel_time: int, + conv_kernel_freq: int, + conv_kernel_time: int, + kernel_norm_mlp_version: int, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + overlapping_band: bool = False, + freq_weights: Optional[List[torch.Tensor]] = None, + n_freq: Optional[int] = None, + ) -> None: + + super().__init__() + self.band_split = BandSplitModule( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + ) + + self.tf_model = SeqBandModellingModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ) + + if hidden_activation_kwargs is None: + hidden_activation_kwargs = {} + + if overlapping_band: + assert freq_weights is not None + assert n_freq is not None + self.mask_estim = nn.ModuleDict( + { + stem: PatchingMaskEstimationModule( + band_specs=band_specs, + freq_weights=freq_weights, + n_freq=n_freq, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + mask_kernel_freq=mask_kernel_freq, + mask_kernel_time=mask_kernel_time, + conv_kernel_freq=conv_kernel_freq, + conv_kernel_time=conv_kernel_time, + kernel_norm_mlp_version=kernel_norm_mlp_version, + ) + for stem in stems + } + ) + else: + raise NotImplementedError diff --git a/mvsepless/models/bandit/core/model/bsrnn/maskestim.py b/mvsepless/models/bandit/core/model/bsrnn/maskestim.py new file mode 100644 index 0000000000000000000000000000000000000000..067f79ce28b5e39de1259710b3444e837c4eb960 --- /dev/null +++ b/mvsepless/models/bandit/core/model/bsrnn/maskestim.py @@ -0,0 +1,351 @@ +import warnings +from typing import Dict, List, Optional, Tuple, Type + +import torch +from torch import nn +from torch.nn.modules import activation + +from .utils import ( + band_widths_from_specs, + check_no_gap, + check_no_overlap, + check_nonzero_bandwidth, +) + + +class BaseNormMLP(nn.Module): + def __init__( + self, + emb_dim: int, + mlp_dim: int, + bandwidth: int, + in_channel: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs=None, + complex_mask: bool = True, + ): + + super().__init__() + if hidden_activation_kwargs is None: + hidden_activation_kwargs = {} + self.hidden_activation_kwargs = hidden_activation_kwargs + self.norm = nn.LayerNorm(emb_dim) + self.hidden = torch.jit.script( + nn.Sequential( + nn.Linear(in_features=emb_dim, out_features=mlp_dim), + activation.__dict__[hidden_activation](**self.hidden_activation_kwargs), + ) + ) + + self.bandwidth = bandwidth + self.in_channel = in_channel + + self.complex_mask = complex_mask + self.reim = 2 if complex_mask else 1 + self.glu_mult = 2 + + +class NormMLP(BaseNormMLP): + def __init__( + self, + emb_dim: int, + mlp_dim: int, + bandwidth: int, + in_channel: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs=None, + complex_mask: bool = True, + ) -> None: + super().__init__( + emb_dim=emb_dim, + mlp_dim=mlp_dim, + bandwidth=bandwidth, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + + self.output = torch.jit.script( + nn.Sequential( + nn.Linear( + in_features=mlp_dim, + out_features=bandwidth * in_channel * self.reim * 2, + ), + nn.GLU(dim=-1), + ) + ) + + def reshape_output(self, mb): + # print(mb.shape) + batch, n_time, _ = mb.shape + if self.complex_mask: + mb = mb.reshape( + batch, n_time, self.in_channel, self.bandwidth, self.reim + ).contiguous() + # print(mb.shape) + mb = torch.view_as_complex(mb) # (batch, n_time, in_channel, bandwidth) + else: + mb = mb.reshape(batch, n_time, self.in_channel, self.bandwidth) + + mb = torch.permute(mb, (0, 2, 3, 1)) # (batch, in_channel, bandwidth, n_time) + + return mb + + def forward(self, qb): + # qb = (batch, n_time, emb_dim) + + # if torch.any(torch.isnan(qb)): + # raise ValueError("qb0") + + qb = self.norm(qb) # (batch, n_time, emb_dim) + + # if torch.any(torch.isnan(qb)): + # raise ValueError("qb1") + + qb = self.hidden(qb) # (batch, n_time, mlp_dim) + # if torch.any(torch.isnan(qb)): + # raise ValueError("qb2") + mb = self.output(qb) # (batch, n_time, bandwidth * in_channel * reim) + # if torch.any(torch.isnan(qb)): + # raise ValueError("mb") + mb = self.reshape_output(mb) # (batch, in_channel, bandwidth, n_time) + + return mb + + +class MultAddNormMLP(NormMLP): + def __init__( + self, + emb_dim: int, + mlp_dim: int, + bandwidth: int, + in_channel: "int | None", + hidden_activation: str = "Tanh", + hidden_activation_kwargs=None, + complex_mask: bool = True, + ) -> None: + super().__init__( + emb_dim, + mlp_dim, + bandwidth, + in_channel, + hidden_activation, + hidden_activation_kwargs, + complex_mask, + ) + + self.output2 = torch.jit.script( + nn.Sequential( + nn.Linear( + in_features=mlp_dim, + out_features=bandwidth * in_channel * self.reim * 2, + ), + nn.GLU(dim=-1), + ) + ) + + def forward(self, qb): + + qb = self.norm(qb) # (batch, n_time, emb_dim) + qb = self.hidden(qb) # (batch, n_time, mlp_dim) + mmb = self.output(qb) # (batch, n_time, bandwidth * in_channel * reim) + mmb = self.reshape_output(mmb) # (batch, in_channel, bandwidth, n_time) + amb = self.output2(qb) # (batch, n_time, bandwidth * in_channel * reim) + amb = self.reshape_output(amb) # (batch, in_channel, bandwidth, n_time) + + return mmb, amb + + +class MaskEstimationModuleSuperBase(nn.Module): + pass + + +class MaskEstimationModuleBase(MaskEstimationModuleSuperBase): + def __init__( + self, + band_specs: List[Tuple[float, float]], + emb_dim: int, + mlp_dim: int, + in_channel: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Dict = None, + complex_mask: bool = True, + norm_mlp_cls: Type[nn.Module] = NormMLP, + norm_mlp_kwargs: Dict = None, + ) -> None: + super().__init__() + + self.band_widths = band_widths_from_specs(band_specs) + self.n_bands = len(band_specs) + + if hidden_activation_kwargs is None: + hidden_activation_kwargs = {} + + if norm_mlp_kwargs is None: + norm_mlp_kwargs = {} + + self.norm_mlp = nn.ModuleList( + [ + ( + norm_mlp_cls( + bandwidth=self.band_widths[b], + emb_dim=emb_dim, + mlp_dim=mlp_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + **norm_mlp_kwargs, + ) + ) + for b in range(self.n_bands) + ] + ) + + def compute_masks(self, q): + batch, n_bands, n_time, emb_dim = q.shape + + masks = [] + + for b, nmlp in enumerate(self.norm_mlp): + # print(f"maskestim/{b:02d}") + qb = q[:, b, :, :] + mb = nmlp(qb) + masks.append(mb) + + return masks + + +class OverlappingMaskEstimationModule(MaskEstimationModuleBase): + def __init__( + self, + in_channel: int, + band_specs: List[Tuple[float, float]], + freq_weights: List[torch.Tensor], + n_freq: int, + emb_dim: int, + mlp_dim: int, + cond_dim: int = 0, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Dict = None, + complex_mask: bool = True, + norm_mlp_cls: Type[nn.Module] = NormMLP, + norm_mlp_kwargs: Dict = None, + use_freq_weights: bool = True, + ) -> None: + check_nonzero_bandwidth(band_specs) + check_no_gap(band_specs) + + # if cond_dim > 0: + # raise NotImplementedError + + super().__init__( + band_specs=band_specs, + emb_dim=emb_dim + cond_dim, + mlp_dim=mlp_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + norm_mlp_cls=norm_mlp_cls, + norm_mlp_kwargs=norm_mlp_kwargs, + ) + + self.n_freq = n_freq + self.band_specs = band_specs + self.in_channel = in_channel + + if freq_weights is not None: + for i, fw in enumerate(freq_weights): + self.register_buffer(f"freq_weights/{i}", fw) + + self.use_freq_weights = use_freq_weights + else: + self.use_freq_weights = False + + self.cond_dim = cond_dim + + def forward(self, q, cond=None): + # q = (batch, n_bands, n_time, emb_dim) + + batch, n_bands, n_time, emb_dim = q.shape + + if cond is not None: + print(cond) + if cond.ndim == 2: + cond = cond[:, None, None, :].expand(-1, n_bands, n_time, -1) + elif cond.ndim == 3: + assert cond.shape[1] == n_time + else: + raise ValueError(f"Invalid cond shape: {cond.shape}") + + q = torch.cat([q, cond], dim=-1) + elif self.cond_dim > 0: + cond = torch.ones( + (batch, n_bands, n_time, self.cond_dim), + device=q.device, + dtype=q.dtype, + ) + q = torch.cat([q, cond], dim=-1) + else: + pass + + mask_list = self.compute_masks( + q + ) # [n_bands * (batch, in_channel, bandwidth, n_time)] + + masks = torch.zeros( + (batch, self.in_channel, self.n_freq, n_time), + device=q.device, + dtype=mask_list[0].dtype, + ) + + for im, mask in enumerate(mask_list): + fstart, fend = self.band_specs[im] + if self.use_freq_weights: + fw = self.get_buffer(f"freq_weights/{im}")[:, None] + mask = mask * fw + masks[:, :, fstart:fend, :] += mask + + return masks + + +class MaskEstimationModule(OverlappingMaskEstimationModule): + def __init__( + self, + band_specs: List[Tuple[float, float]], + emb_dim: int, + mlp_dim: int, + in_channel: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Dict = None, + complex_mask: bool = True, + **kwargs, + ) -> None: + check_nonzero_bandwidth(band_specs) + check_no_gap(band_specs) + check_no_overlap(band_specs) + super().__init__( + in_channel=in_channel, + band_specs=band_specs, + freq_weights=None, + n_freq=None, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + + def forward(self, q, cond=None): + # q = (batch, n_bands, n_time, emb_dim) + + masks = self.compute_masks( + q + ) # [n_bands * (batch, in_channel, bandwidth, n_time)] + + # TODO: currently this requires band specs to have no gap and no overlap + masks = torch.concat(masks, dim=2) # (batch, in_channel, n_freq, n_time) + + return masks diff --git a/mvsepless/models/bandit/core/model/bsrnn/tfmodel.py b/mvsepless/models/bandit/core/model/bsrnn/tfmodel.py new file mode 100644 index 0000000000000000000000000000000000000000..f482a118f5a9ac7f9f2d5bc36725155b3d8049db --- /dev/null +++ b/mvsepless/models/bandit/core/model/bsrnn/tfmodel.py @@ -0,0 +1,320 @@ +import warnings + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.modules import rnn + +import torch.backends.cuda + + +class TimeFrequencyModellingModule(nn.Module): + def __init__(self) -> None: + super().__init__() + + +class ResidualRNN(nn.Module): + def __init__( + self, + emb_dim: int, + rnn_dim: int, + bidirectional: bool = True, + rnn_type: str = "LSTM", + use_batch_trick: bool = True, + use_layer_norm: bool = True, + ) -> None: + # n_group is the size of the 2nd dim + super().__init__() + + self.use_layer_norm = use_layer_norm + if use_layer_norm: + self.norm = nn.LayerNorm(emb_dim) + else: + self.norm = nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim) + + self.rnn = rnn.__dict__[rnn_type]( + input_size=emb_dim, + hidden_size=rnn_dim, + num_layers=1, + batch_first=True, + bidirectional=bidirectional, + ) + + self.fc = nn.Linear( + in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim + ) + + self.use_batch_trick = use_batch_trick + if not self.use_batch_trick: + warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!") + + def forward(self, z): + # z = (batch, n_uncrossed, n_across, emb_dim) + + z0 = torch.clone(z) + + # print(z.device) + + if self.use_layer_norm: + z = self.norm(z) # (batch, n_uncrossed, n_across, emb_dim) + else: + z = torch.permute( + z, (0, 3, 1, 2) + ) # (batch, emb_dim, n_uncrossed, n_across) + + z = self.norm(z) # (batch, emb_dim, n_uncrossed, n_across) + + z = torch.permute( + z, (0, 2, 3, 1) + ) # (batch, n_uncrossed, n_across, emb_dim) + + batch, n_uncrossed, n_across, emb_dim = z.shape + + if self.use_batch_trick: + z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim)) + + z = self.rnn(z.contiguous())[ + 0 + ] # (batch * n_uncrossed, n_across, dir_rnn_dim) + + z = torch.reshape(z, (batch, n_uncrossed, n_across, -1)) + # (batch, n_uncrossed, n_across, dir_rnn_dim) + else: + # Note: this is EXTREMELY SLOW + zlist = [] + for i in range(n_uncrossed): + zi = self.rnn(z[:, i, :, :])[0] # (batch, n_across, emb_dim) + zlist.append(zi) + + z = torch.stack(zlist, dim=1) # (batch, n_uncrossed, n_across, dir_rnn_dim) + + z = self.fc(z) # (batch, n_uncrossed, n_across, emb_dim) + + z = z + z0 + + return z + + +class SeqBandModellingModule(TimeFrequencyModellingModule): + def __init__( + self, + n_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + parallel_mode=False, + ) -> None: + super().__init__() + self.seqband = nn.ModuleList([]) + + if parallel_mode: + for _ in range(n_modules): + self.seqband.append( + nn.ModuleList( + [ + ResidualRNN( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ), + ResidualRNN( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ), + ] + ) + ) + else: + + for _ in range(2 * n_modules): + self.seqband.append( + ResidualRNN( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ) + ) + + self.parallel_mode = parallel_mode + + def forward(self, z): + # z = (batch, n_bands, n_time, emb_dim) + + if self.parallel_mode: + for sbm_pair in self.seqband: + # z: (batch, n_bands, n_time, emb_dim) + sbm_t, sbm_f = sbm_pair[0], sbm_pair[1] + zt = sbm_t(z) # (batch, n_bands, n_time, emb_dim) + zf = sbm_f(z.transpose(1, 2)) # (batch, n_time, n_bands, emb_dim) + z = zt + zf.transpose(1, 2) + else: + for sbm in self.seqband: + z = sbm(z) + z = z.transpose(1, 2) + + # (batch, n_bands, n_time, emb_dim) + # --> (batch, n_time, n_bands, emb_dim) + # OR + # (batch, n_time, n_bands, emb_dim) + # --> (batch, n_bands, n_time, emb_dim) + + q = z + return q # (batch, n_bands, n_time, emb_dim) + + +class ResidualTransformer(nn.Module): + def __init__( + self, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + dropout: float = 0.0, + ) -> None: + # n_group is the size of the 2nd dim + super().__init__() + + self.tf = nn.TransformerEncoderLayer( + d_model=emb_dim, nhead=4, dim_feedforward=rnn_dim, batch_first=True + ) + + self.is_causal = not bidirectional + self.dropout = dropout + + def forward(self, z): + batch, n_uncrossed, n_across, emb_dim = z.shape + z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim)) + z = self.tf( + z, is_causal=self.is_causal + ) # (batch, n_uncrossed, n_across, emb_dim) + z = torch.reshape(z, (batch, n_uncrossed, n_across, emb_dim)) + + return z + + +class TransformerTimeFreqModule(TimeFrequencyModellingModule): + def __init__( + self, + n_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.norm = nn.LayerNorm(emb_dim) + self.seqband = nn.ModuleList([]) + + for _ in range(2 * n_modules): + self.seqband.append( + ResidualTransformer( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + dropout=dropout, + ) + ) + + def forward(self, z): + # z = (batch, n_bands, n_time, emb_dim) + z = self.norm(z) # (batch, n_bands, n_time, emb_dim) + + for sbm in self.seqband: + z = sbm(z) + z = z.transpose(1, 2) + + # (batch, n_bands, n_time, emb_dim) + # --> (batch, n_time, n_bands, emb_dim) + # OR + # (batch, n_time, n_bands, emb_dim) + # --> (batch, n_bands, n_time, emb_dim) + + q = z + return q # (batch, n_bands, n_time, emb_dim) + + +class ResidualConvolution(nn.Module): + def __init__( + self, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + dropout: float = 0.0, + ) -> None: + # n_group is the size of the 2nd dim + super().__init__() + self.norm = nn.InstanceNorm2d(emb_dim, affine=True) + + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=emb_dim, + out_channels=rnn_dim, + kernel_size=(3, 3), + padding="same", + stride=(1, 1), + ), + nn.Tanhshrink(), + ) + + self.is_causal = not bidirectional + self.dropout = dropout + + self.fc = nn.Conv2d( + in_channels=rnn_dim, + out_channels=emb_dim, + kernel_size=(1, 1), + padding="same", + stride=(1, 1), + ) + + def forward(self, z): + # z = (batch, n_uncrossed, n_across, emb_dim) + + z0 = torch.clone(z) + + z = self.norm(z) # (batch, n_uncrossed, n_across, emb_dim) + z = self.conv(z) # (batch, n_uncrossed, n_across, emb_dim) + z = self.fc(z) # (batch, n_uncrossed, n_across, emb_dim) + z = z + z0 + + return z + + +class ConvolutionalTimeFreqModule(TimeFrequencyModellingModule): + def __init__( + self, + n_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.seqband = torch.jit.script( + nn.Sequential( + *[ + ResidualConvolution( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + dropout=dropout, + ) + for _ in range(2 * n_modules) + ] + ) + ) + + def forward(self, z): + # z = (batch, n_bands, n_time, emb_dim) + + z = torch.permute(z, (0, 3, 1, 2)) # (batch, emb_dim, n_bands, n_time) + + z = self.seqband(z) # (batch, emb_dim, n_bands, n_time) + + z = torch.permute(z, (0, 2, 3, 1)) # (batch, n_bands, n_time, emb_dim) + + return z diff --git a/mvsepless/models/bandit/core/model/bsrnn/utils.py b/mvsepless/models/bandit/core/model/bsrnn/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0f470b180303c48e2e5a02b1319bff743c0c83db --- /dev/null +++ b/mvsepless/models/bandit/core/model/bsrnn/utils.py @@ -0,0 +1,525 @@ +import os +from abc import abstractmethod +from typing import Any, Callable + +import numpy as np +import torch +from librosa import hz_to_midi, midi_to_hz +from torch import Tensor +from torchaudio import functional as taF +from spafe.fbanks import bark_fbanks +from spafe.utils.converters import erb2hz, hz2bark, hz2erb +from torchaudio.functional.functional import _create_triangular_filterbank + + +def band_widths_from_specs(band_specs): + return [e - i for i, e in band_specs] + + +def check_nonzero_bandwidth(band_specs): + # pprint(band_specs) + for fstart, fend in band_specs: + if fend - fstart <= 0: + raise ValueError("Bands cannot be zero-width") + + +def check_no_overlap(band_specs): + fend_prev = -1 + for fstart_curr, fend_curr in band_specs: + if fstart_curr <= fend_prev: + raise ValueError("Bands cannot overlap") + + +def check_no_gap(band_specs): + fstart, _ = band_specs[0] + assert fstart == 0 + + fend_prev = -1 + for fstart_curr, fend_curr in band_specs: + if fstart_curr - fend_prev > 1: + raise ValueError("Bands cannot leave gap") + fend_prev = fend_curr + + +class BandsplitSpecification: + def __init__(self, nfft: int, fs: int) -> None: + self.fs = fs + self.nfft = nfft + self.nyquist = fs / 2 + self.max_index = nfft // 2 + 1 + + self.split500 = self.hertz_to_index(500) + self.split1k = self.hertz_to_index(1000) + self.split2k = self.hertz_to_index(2000) + self.split4k = self.hertz_to_index(4000) + self.split8k = self.hertz_to_index(8000) + self.split16k = self.hertz_to_index(16000) + self.split20k = self.hertz_to_index(20000) + + self.above20k = [(self.split20k, self.max_index)] + self.above16k = [(self.split16k, self.split20k)] + self.above20k + + def index_to_hertz(self, index: int): + return index * self.fs / self.nfft + + def hertz_to_index(self, hz: float, round: bool = True): + index = hz * self.nfft / self.fs + + if round: + index = int(np.round(index)) + + return index + + def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz): + band_specs = [] + lower = start_index + + while lower < end_index: + upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz))) + upper = min(upper, end_index) + + band_specs.append((lower, upper)) + lower = upper + + return band_specs + + @abstractmethod + def get_band_specs(self): + raise NotImplementedError + + +class VocalBandsplitSpecification(BandsplitSpecification): + def __init__(self, nfft: int, fs: int, version: str = "7") -> None: + super().__init__(nfft=nfft, fs=fs) + + self.version = version + + def get_band_specs(self): + return getattr(self, f"version{self.version}")() + + @property + def version1(self): + return self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.max_index, bandwidth_hz=1000 + ) + + def version2(self): + below16k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split16k, bandwidth_hz=1000 + ) + below20k = self.get_band_specs_with_bandwidth( + start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000 + ) + + return below16k + below20k + self.above20k + + def version3(self): + below8k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split8k, bandwidth_hz=1000 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000 + ) + + return below8k + below16k + self.above16k + + def version4(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=100 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000 + ) + + return below1k + below8k + below16k + self.above16k + + def version5(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=100 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000 + ) + below20k = self.get_band_specs_with_bandwidth( + start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000 + ) + return below1k + below16k + below20k + self.above20k + + def version6(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=100 + ) + below4k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000 + ) + return below1k + below4k + below8k + below16k + self.above16k + + def version7(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=100 + ) + below4k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000 + ) + below20k = self.get_band_specs_with_bandwidth( + start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000 + ) + return below1k + below4k + below8k + below16k + below20k + self.above20k + + +class OtherBandsplitSpecification(VocalBandsplitSpecification): + def __init__(self, nfft: int, fs: int) -> None: + super().__init__(nfft=nfft, fs=fs, version="7") + + +class BassBandsplitSpecification(BandsplitSpecification): + def __init__(self, nfft: int, fs: int, version: str = "7") -> None: + super().__init__(nfft=nfft, fs=fs) + + def get_band_specs(self): + below500 = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split500, bandwidth_hz=50 + ) + below1k = self.get_band_specs_with_bandwidth( + start_index=self.split500, end_index=self.split1k, bandwidth_hz=100 + ) + below4k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000 + ) + above16k = [(self.split16k, self.max_index)] + + return below500 + below1k + below4k + below8k + below16k + above16k + + +class DrumBandsplitSpecification(BandsplitSpecification): + def __init__(self, nfft: int, fs: int) -> None: + super().__init__(nfft=nfft, fs=fs) + + def get_band_specs(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=50 + ) + below2k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100 + ) + below4k = self.get_band_specs_with_bandwidth( + start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000 + ) + above16k = [(self.split16k, self.max_index)] + + return below1k + below2k + below4k + below8k + below16k + above16k + + +class PerceptualBandsplitSpecification(BandsplitSpecification): + def __init__( + self, + nfft: int, + fs: int, + fbank_fn: Callable[[int, int, float, float, int], torch.Tensor], + n_bands: int, + f_min: float = 0.0, + f_max: float = None, + ) -> None: + super().__init__(nfft=nfft, fs=fs) + self.n_bands = n_bands + if f_max is None: + f_max = fs / 2 + + self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index) + + weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True) # (1, n_freqs) + normalized_mel_fb = self.filterbank / weight_per_bin # (n_mels, n_freqs) + + freq_weights = [] + band_specs = [] + for i in range(self.n_bands): + active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist() + if isinstance(active_bins, int): + active_bins = (active_bins, active_bins) + if len(active_bins) == 0: + continue + start_index = active_bins[0] + end_index = active_bins[-1] + 1 + band_specs.append((start_index, end_index)) + freq_weights.append(normalized_mel_fb[i, start_index:end_index]) + + self.freq_weights = freq_weights + self.band_specs = band_specs + + def get_band_specs(self): + return self.band_specs + + def get_freq_weights(self): + return self.freq_weights + + def save_to_file(self, dir_path: str) -> None: + + os.makedirs(dir_path, exist_ok=True) + + import pickle + + with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f: + pickle.dump( + { + "band_specs": self.band_specs, + "freq_weights": self.freq_weights, + "filterbank": self.filterbank, + }, + f, + ) + + +def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs): + fb = taF.melscale_fbanks( + n_mels=n_bands, + sample_rate=fs, + f_min=f_min, + f_max=f_max, + n_freqs=n_freqs, + ).T + + fb[0, 0] = 1.0 + + return fb + + +class MelBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None + ) -> None: + super().__init__( + fbank_fn=mel_filterbank, + nfft=nfft, + fs=fs, + n_bands=n_bands, + f_min=f_min, + f_max=f_max, + ) + + +def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"): + + nfft = 2 * (n_freqs - 1) + df = fs / nfft + # init freqs + f_max = f_max or fs / 2 + f_min = f_min or 0 + f_min = fs / nfft + + n_octaves = np.log2(f_max / f_min) + n_octaves_per_band = n_octaves / n_bands + bandwidth_mult = np.power(2.0, n_octaves_per_band) + + low_midi = max(0, hz_to_midi(f_min)) + high_midi = hz_to_midi(f_max) + midi_points = np.linspace(low_midi, high_midi, n_bands) + hz_pts = midi_to_hz(midi_points) + + low_pts = hz_pts / bandwidth_mult + high_pts = hz_pts * bandwidth_mult + + low_bins = np.floor(low_pts / df).astype(int) + high_bins = np.ceil(high_pts / df).astype(int) + + fb = np.zeros((n_bands, n_freqs)) + + for i in range(n_bands): + fb[i, low_bins[i] : high_bins[i] + 1] = 1.0 + + fb[0, : low_bins[0]] = 1.0 + fb[-1, high_bins[-1] + 1 :] = 1.0 + + return torch.as_tensor(fb) + + +class MusicalBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None + ) -> None: + super().__init__( + fbank_fn=musical_filterbank, + nfft=nfft, + fs=fs, + n_bands=n_bands, + f_min=f_min, + f_max=f_max, + ) + + +def bark_filterbank(n_bands, fs, f_min, f_max, n_freqs): + nfft = 2 * (n_freqs - 1) + fb, _ = bark_fbanks.bark_filter_banks( + nfilts=n_bands, + nfft=nfft, + fs=fs, + low_freq=f_min, + high_freq=f_max, + scale="constant", + ) + + return torch.as_tensor(fb) + + +class BarkBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None + ) -> None: + super().__init__( + fbank_fn=bark_filterbank, + nfft=nfft, + fs=fs, + n_bands=n_bands, + f_min=f_min, + f_max=f_max, + ) + + +def triangular_bark_filterbank(n_bands, fs, f_min, f_max, n_freqs): + + all_freqs = torch.linspace(0, fs // 2, n_freqs) + + # calculate mel freq bins + m_min = hz2bark(f_min) + m_max = hz2bark(f_max) + + m_pts = torch.linspace(m_min, m_max, n_bands + 2) + f_pts = 600 * torch.sinh(m_pts / 6) + + # create filterbank + fb = _create_triangular_filterbank(all_freqs, f_pts) + + fb = fb.T + + first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0] + first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0] + + fb[first_active_band, :first_active_bin] = 1.0 + + return fb + + +class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None + ) -> None: + super().__init__( + fbank_fn=triangular_bark_filterbank, + nfft=nfft, + fs=fs, + n_bands=n_bands, + f_min=f_min, + f_max=f_max, + ) + + +def minibark_filterbank(n_bands, fs, f_min, f_max, n_freqs): + fb = bark_filterbank(n_bands, fs, f_min, f_max, n_freqs) + + fb[fb < np.sqrt(0.5)] = 0.0 + + return fb + + +class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None + ) -> None: + super().__init__( + fbank_fn=minibark_filterbank, + nfft=nfft, + fs=fs, + n_bands=n_bands, + f_min=f_min, + f_max=f_max, + ) + + +def erb_filterbank( + n_bands: int, + fs: int, + f_min: float, + f_max: float, + n_freqs: int, +) -> Tensor: + # freq bins + A = (1000 * np.log(10)) / (24.7 * 4.37) + all_freqs = torch.linspace(0, fs // 2, n_freqs) + + # calculate mel freq bins + m_min = hz2erb(f_min) + m_max = hz2erb(f_max) + + m_pts = torch.linspace(m_min, m_max, n_bands + 2) + f_pts = (torch.pow(10, (m_pts / A)) - 1) / 0.00437 + + # create filterbank + fb = _create_triangular_filterbank(all_freqs, f_pts) + + fb = fb.T + + first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0] + first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0] + + fb[first_active_band, :first_active_bin] = 1.0 + + return fb + + +class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None + ) -> None: + super().__init__( + fbank_fn=erb_filterbank, + nfft=nfft, + fs=fs, + n_bands=n_bands, + f_min=f_min, + f_max=f_max, + ) + + +if __name__ == "__main__": + import pandas as pd + + band_defs = [] + + for bands in [VocalBandsplitSpecification]: + band_name = bands.__name__.replace("BandsplitSpecification", "") + + mbs = bands(nfft=2048, fs=44100).get_band_specs() + + for i, (f_min, f_max) in enumerate(mbs): + band_defs.append( + {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max} + ) + + df = pd.DataFrame(band_defs) + df.to_csv("vox7bands.csv", index=False) diff --git a/mvsepless/models/bandit/core/model/bsrnn/wrapper.py b/mvsepless/models/bandit/core/model/bsrnn/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..b077b03b1712f186bdb51764537a13a3731c0aa1 --- /dev/null +++ b/mvsepless/models/bandit/core/model/bsrnn/wrapper.py @@ -0,0 +1,829 @@ +from pprint import pprint +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch import nn + +from .._spectral import _SpectralComponent +from .utils import ( + BarkBandsplitSpecification, + BassBandsplitSpecification, + DrumBandsplitSpecification, + EquivalentRectangularBandsplitSpecification, + MelBandsplitSpecification, + MusicalBandsplitSpecification, + OtherBandsplitSpecification, + TriangularBarkBandsplitSpecification, + VocalBandsplitSpecification, +) +from .core import ( + MultiSourceMultiMaskBandSplitCoreConv, + MultiSourceMultiMaskBandSplitCoreRNN, + MultiSourceMultiMaskBandSplitCoreTransformer, + MultiSourceMultiPatchingMaskBandSplitCoreRNN, + SingleMaskBandsplitCoreRNN, + SingleMaskBandsplitCoreTransformer, +) + +import pytorch_lightning as pl + + +def get_band_specs(band_specs, n_fft, fs, n_bands=None): + if band_specs in ["dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"]: + bsm = VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs() + freq_weights = None + overlapping_band = False + elif "tribark" in band_specs: + assert n_bands is not None + specs = TriangularBarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands) + bsm = specs.get_band_specs() + freq_weights = specs.get_freq_weights() + overlapping_band = True + elif "bark" in band_specs: + assert n_bands is not None + specs = BarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands) + bsm = specs.get_band_specs() + freq_weights = specs.get_freq_weights() + overlapping_band = True + elif "erb" in band_specs: + assert n_bands is not None + specs = EquivalentRectangularBandsplitSpecification( + nfft=n_fft, fs=fs, n_bands=n_bands + ) + bsm = specs.get_band_specs() + freq_weights = specs.get_freq_weights() + overlapping_band = True + elif "musical" in band_specs: + assert n_bands is not None + specs = MusicalBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands) + bsm = specs.get_band_specs() + freq_weights = specs.get_freq_weights() + overlapping_band = True + elif band_specs == "dnr:mel" or "mel" in band_specs: + assert n_bands is not None + specs = MelBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands) + bsm = specs.get_band_specs() + freq_weights = specs.get_freq_weights() + overlapping_band = True + else: + raise NameError + + return bsm, freq_weights, overlapping_band + + +def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None): + if band_specs_map == "musdb:all": + bsm = { + "vocals": VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(), + "drums": DrumBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(), + "bass": BassBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(), + "other": OtherBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(), + } + freq_weights = None + overlapping_band = False + elif band_specs_map == "dnr:vox7": + bsm_, freq_weights, overlapping_band = get_band_specs( + "dnr:speech", n_fft, fs, n_bands + ) + bsm = {"speech": bsm_, "music": bsm_, "effects": bsm_} + elif "dnr:vox7:" in band_specs_map: + stem = band_specs_map.split(":")[-1] + bsm_, freq_weights, overlapping_band = get_band_specs( + "dnr:speech", n_fft, fs, n_bands + ) + bsm = {stem: bsm_} + else: + raise NameError + + return bsm, freq_weights, overlapping_band + + +class BandSplitWrapperBase(pl.LightningModule): + bsrnn: nn.Module + + def __init__(self, **kwargs): + super().__init__() + + +class SingleMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent): + def __init__( + self, + band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]], + fs: int = 44100, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + ) -> None: + super().__init__( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + ) + + if isinstance(band_specs_map, str): + self.band_specs_map, self.freq_weights, self.overlapping_band = ( + get_band_specs_map(band_specs_map, n_fft, fs, n_bands=n_bands) + ) + + self.stems = list(self.band_specs_map.keys()) + + def forward(self, batch): + audio = batch["audio"] + + with torch.no_grad(): + batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio} + + X = batch["spectrogram"]["mixture"] + length = batch["audio"]["mixture"].shape[-1] + + output = {"spectrogram": {}, "audio": {}} + + for stem, bsrnn in self.bsrnn.items(): + S = bsrnn(X) + s = self.istft(S, length) + output["spectrogram"][stem] = S + output["audio"][stem] = s + + return batch, output + + +class MultiMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent): + def __init__( + self, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + ) -> None: + super().__init__( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + ) + + if isinstance(band_specs, str): + self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs( + band_specs, n_fft, fs, n_bands + ) + + self.stems = stems + + def forward(self, batch): + # with torch.no_grad(): + audio = batch["audio"] + cond = batch.get("condition", None) + with torch.no_grad(): + batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio} + + X = batch["spectrogram"]["mixture"] + length = batch["audio"]["mixture"].shape[-1] + + output = self.bsrnn(X, cond=cond) + output["audio"] = {} + + for stem, S in output["spectrogram"].items(): + s = self.istft(S, length) + output["audio"][stem] = s + + return batch, output + + +class MultiMaskMultiSourceBandSplitBaseSimple(BandSplitWrapperBase, _SpectralComponent): + def __init__( + self, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + ) -> None: + super().__init__( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + ) + + if isinstance(band_specs, str): + self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs( + band_specs, n_fft, fs, n_bands + ) + + self.stems = stems + + def forward(self, batch): + with torch.no_grad(): + X = self.stft(batch) + length = batch.shape[-1] + output = self.bsrnn(X, cond=None) + res = [] + for stem, S in output["spectrogram"].items(): + s = self.istft(S, length) + res.append(s) + res = torch.stack(res, dim=1) + return res + + +class SingleMaskMultiSourceBandSplitRNN(SingleMaskMultiSourceBandSplitBase): + def __init__( + self, + in_channel: int, + band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + ) -> None: + super().__init__( + band_specs_map=band_specs_map, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + ) + + self.bsrnn = nn.ModuleDict( + { + src: SingleMaskBandsplitCoreRNN( + band_specs=specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + for src, specs in self.band_specs_map.items() + } + ) + + +class SingleMaskMultiSourceBandSplitTransformer(SingleMaskMultiSourceBandSplitBase): + def __init__( + self, + in_channel: int, + band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + tf_dropout: float = 0.0, + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + ) -> None: + super().__init__( + band_specs_map=band_specs_map, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + ) + + self.bsrnn = nn.ModuleDict( + { + src: SingleMaskBandsplitCoreTransformer( + band_specs=specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + tf_dropout=tf_dropout, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + for src, specs in self.band_specs_map.items() + } + ) + + +class MultiMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + cond_dim: int = 0, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + use_freq_weights: bool = True, + normalize_input: bool = False, + mult_add_mask: bool = False, + freeze_encoder: bool = False, + ) -> None: + super().__init__( + stems=stems, + band_specs=band_specs, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + n_bands=n_bands, + ) + + self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN( + stems=stems, + band_specs=self.band_specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=self.overlapping_band, + freq_weights=self.freq_weights, + n_freq=n_fft // 2 + 1, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask, + ) + + self.normalize_input = normalize_input + self.cond_dim = cond_dim + + if freeze_encoder: + for param in self.bsrnn.band_split.parameters(): + param.requires_grad = False + + for param in self.bsrnn.tf_model.parameters(): + param.requires_grad = False + + +class MultiMaskMultiSourceBandSplitRNNSimple(MultiMaskMultiSourceBandSplitBaseSimple): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + cond_dim: int = 0, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + use_freq_weights: bool = True, + normalize_input: bool = False, + mult_add_mask: bool = False, + freeze_encoder: bool = False, + ) -> None: + super().__init__( + stems=stems, + band_specs=band_specs, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + n_bands=n_bands, + ) + + self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN( + stems=stems, + band_specs=self.band_specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=self.overlapping_band, + freq_weights=self.freq_weights, + n_freq=n_fft // 2 + 1, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask, + ) + + self.normalize_input = normalize_input + self.cond_dim = cond_dim + + if freeze_encoder: + for param in self.bsrnn.band_split.parameters(): + param.requires_grad = False + + for param in self.bsrnn.tf_model.parameters(): + param.requires_grad = False + + +class MultiMaskMultiSourceBandSplitTransformer(MultiMaskMultiSourceBandSplitBase): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + cond_dim: int = 0, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + use_freq_weights: bool = True, + normalize_input: bool = False, + mult_add_mask: bool = False, + ) -> None: + super().__init__( + stems=stems, + band_specs=band_specs, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + n_bands=n_bands, + ) + + self.bsrnn = MultiSourceMultiMaskBandSplitCoreTransformer( + stems=stems, + band_specs=self.band_specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=self.overlapping_band, + freq_weights=self.freq_weights, + n_freq=n_fft // 2 + 1, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask, + ) + + +class MultiMaskMultiSourceBandSplitConv(MultiMaskMultiSourceBandSplitBase): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + cond_dim: int = 0, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + use_freq_weights: bool = True, + normalize_input: bool = False, + mult_add_mask: bool = False, + ) -> None: + super().__init__( + stems=stems, + band_specs=band_specs, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + n_bands=n_bands, + ) + + self.bsrnn = MultiSourceMultiMaskBandSplitCoreConv( + stems=stems, + band_specs=self.band_specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=self.overlapping_band, + freq_weights=self.freq_weights, + n_freq=n_fft // 2 + 1, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask, + ) + + +class PatchingMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + kernel_norm_mlp_version: int = 1, + mask_kernel_freq: int = 3, + mask_kernel_time: int = 3, + conv_kernel_freq: int = 1, + conv_kernel_time: int = 1, + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + ) -> None: + super().__init__( + stems=stems, + band_specs=band_specs, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + n_bands=n_bands, + ) + + self.bsrnn = MultiSourceMultiPatchingMaskBandSplitCoreRNN( + stems=stems, + band_specs=self.band_specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=self.overlapping_band, + freq_weights=self.freq_weights, + n_freq=n_fft // 2 + 1, + mask_kernel_freq=mask_kernel_freq, + mask_kernel_time=mask_kernel_time, + conv_kernel_freq=conv_kernel_freq, + conv_kernel_time=conv_kernel_time, + kernel_norm_mlp_version=kernel_norm_mlp_version, + ) diff --git a/mvsepless/models/bandit/core/utils/__init__.py b/mvsepless/models/bandit/core/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mvsepless/models/bandit/core/utils/audio.py b/mvsepless/models/bandit/core/utils/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..20fbfa6d1f20e763f380e5cc3c341384b4e3af16 --- /dev/null +++ b/mvsepless/models/bandit/core/utils/audio.py @@ -0,0 +1,412 @@ +from collections import defaultdict + +from tqdm.auto import tqdm +from typing import Callable, Dict, List, Optional, Tuple + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + + +@torch.jit.script +def merge( + combined: torch.Tensor, + original_batch_size: int, + n_channel: int, + n_chunks: int, + chunk_size: int, +): + combined = torch.reshape( + combined, (original_batch_size, n_chunks, n_channel, chunk_size) + ) + combined = torch.permute(combined, (0, 2, 3, 1)).reshape( + original_batch_size * n_channel, chunk_size, n_chunks + ) + + return combined + + +@torch.jit.script +def unfold( + padded_audio: torch.Tensor, + original_batch_size: int, + n_channel: int, + chunk_size: int, + hop_size: int, +) -> torch.Tensor: + + unfolded_input = F.unfold( + padded_audio[:, :, None, :], kernel_size=(1, chunk_size), stride=(1, hop_size) + ) + + _, _, n_chunks = unfolded_input.shape + unfolded_input = unfolded_input.view( + original_batch_size, n_channel, chunk_size, n_chunks + ) + unfolded_input = torch.permute(unfolded_input, (0, 3, 1, 2)).reshape( + original_batch_size * n_chunks, n_channel, chunk_size + ) + + return unfolded_input + + +@torch.jit.script +# @torch.compile +def merge_chunks_all( + combined: torch.Tensor, + original_batch_size: int, + n_channel: int, + n_samples: int, + n_padded_samples: int, + n_chunks: int, + chunk_size: int, + hop_size: int, + edge_frame_pad_sizes: Tuple[int, int], + standard_window: torch.Tensor, + first_window: torch.Tensor, + last_window: torch.Tensor, +): + combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size) + + combined = combined * standard_window[:, None].to(combined.device) + + combined = F.fold( + combined.to(torch.float32), + output_size=(1, n_padded_samples), + kernel_size=(1, chunk_size), + stride=(1, hop_size), + ) + + combined = combined.view(original_batch_size, n_channel, n_padded_samples) + + pad_front, pad_back = edge_frame_pad_sizes + combined = combined[..., pad_front:-pad_back] + + combined = combined[..., :n_samples] + + return combined + + # @torch.jit.script + + +def merge_chunks_edge( + combined: torch.Tensor, + original_batch_size: int, + n_channel: int, + n_samples: int, + n_padded_samples: int, + n_chunks: int, + chunk_size: int, + hop_size: int, + edge_frame_pad_sizes: Tuple[int, int], + standard_window: torch.Tensor, + first_window: torch.Tensor, + last_window: torch.Tensor, +): + combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size) + + combined[..., 0] = combined[..., 0] * first_window + combined[..., -1] = combined[..., -1] * last_window + combined[..., 1:-1] = combined[..., 1:-1] * standard_window[:, None] + + combined = F.fold( + combined, + output_size=(1, n_padded_samples), + kernel_size=(1, chunk_size), + stride=(1, hop_size), + ) + + combined = combined.view(original_batch_size, n_channel, n_padded_samples) + + combined = combined[..., :n_samples] + + return combined + + +class BaseFader(nn.Module): + def __init__( + self, + chunk_size_second: float, + hop_size_second: float, + fs: int, + fade_edge_frames: bool, + batch_size: int, + ) -> None: + super().__init__() + + self.chunk_size = int(chunk_size_second * fs) + self.hop_size = int(hop_size_second * fs) + self.overlap_size = self.chunk_size - self.hop_size + self.fade_edge_frames = fade_edge_frames + self.batch_size = batch_size + + # @torch.jit.script + def prepare(self, audio): + + if self.fade_edge_frames: + audio = F.pad(audio, self.edge_frame_pad_sizes, mode="reflect") + + n_samples = audio.shape[-1] + n_chunks = int(np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1) + + padded_size = (n_chunks - 1) * self.hop_size + self.chunk_size + pad_size = padded_size - n_samples + + padded_audio = F.pad(audio, (0, pad_size)) + + return padded_audio, n_chunks + + def forward( + self, + audio: torch.Tensor, + model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]], + ): + + original_dtype = audio.dtype + original_device = audio.device + + audio = audio.to("cpu") + + original_batch_size, n_channel, n_samples = audio.shape + padded_audio, n_chunks = self.prepare(audio) + del audio + n_padded_samples = padded_audio.shape[-1] + + if n_channel > 1: + padded_audio = padded_audio.view( + original_batch_size * n_channel, 1, n_padded_samples + ) + + unfolded_input = unfold( + padded_audio, original_batch_size, n_channel, self.chunk_size, self.hop_size + ) + + n_total_chunks, n_channel, chunk_size = unfolded_input.shape + + n_batch = np.ceil(n_total_chunks / self.batch_size).astype(int) + + chunks_in = [ + unfolded_input[b * self.batch_size : (b + 1) * self.batch_size, ...].clone() + for b in range(n_batch) + ] + + all_chunks_out = defaultdict( + lambda: torch.zeros_like(unfolded_input, device="cpu") + ) + + # for b, cin in enumerate(tqdm(chunks_in)): + for b, cin in enumerate(chunks_in): + if torch.allclose(cin, torch.tensor(0.0)): + del cin + continue + + chunks_out = model_fn(cin.to(original_device)) + del cin + for s, c in chunks_out.items(): + all_chunks_out[s][ + b * self.batch_size : (b + 1) * self.batch_size, ... + ] = c.cpu() + del chunks_out + + del unfolded_input + del padded_audio + + if self.fade_edge_frames: + fn = merge_chunks_all + else: + fn = merge_chunks_edge + outputs = {} + + torch.cuda.empty_cache() + + for s, c in all_chunks_out.items(): + combined: torch.Tensor = fn( + c, + original_batch_size, + n_channel, + n_samples, + n_padded_samples, + n_chunks, + self.chunk_size, + self.hop_size, + self.edge_frame_pad_sizes, + self.standard_window, + self.__dict__.get("first_window", self.standard_window), + self.__dict__.get("last_window", self.standard_window), + ) + + outputs[s] = combined.to(dtype=original_dtype, device=original_device) + + return {"audio": outputs} + + # + # def old_forward( + # self, + # audio: torch.Tensor, + # model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]], + # ): + # + # n_samples = audio.shape[-1] + # original_batch_size = audio.shape[0] + # + # padded_audio, n_chunks = self.prepare(audio) + # + # ndim = padded_audio.ndim + # broadcaster = [1 for _ in range(ndim - 1)] + [self.chunk_size] + # + # outputs = defaultdict( + # lambda: torch.zeros_like( + # padded_audio, device=audio.device, dtype=torch.float64 + # ) + # ) + # + # all_chunks_out = [] + # len_chunks_in = [] + # + # batch_size_ = int(self.batch_size // original_batch_size) + # for b in range(int(np.ceil(n_chunks / batch_size_))): + # chunks_in = [] + # for j in range(batch_size_): + # i = b * batch_size_ + j + # if i == n_chunks: + # break + # + # start = i * hop_size + # end = start + self.chunk_size + # chunk_in = padded_audio[..., start:end] + # chunks_in.append(chunk_in) + # + # chunks_in = torch.concat(chunks_in, dim=0) + # chunks_out = model_fn(chunks_in) + # all_chunks_out.append(chunks_out) + # len_chunks_in.append(len(chunks_in)) + # + # for b, (chunks_out, lci) in enumerate( + # zip(all_chunks_out, len_chunks_in) + # ): + # for stem in chunks_out: + # for j in range(lci // original_batch_size): + # i = b * batch_size_ + j + # + # if self.fade_edge_frames: + # window = self.standard_window + # else: + # if i == 0: + # window = self.first_window + # elif i == n_chunks - 1: + # window = self.last_window + # else: + # window = self.standard_window + # + # start = i * hop_size + # end = start + self.chunk_size + # + # chunk_out = chunks_out[stem][j * original_batch_size: (j + 1) * original_batch_size, + # ...] + # contrib = window.view(*broadcaster) * chunk_out + # outputs[stem][..., start:end] = ( + # outputs[stem][..., start:end] + contrib + # ) + # + # if self.fade_edge_frames: + # pad_front, pad_back = self.edge_frame_pad_sizes + # outputs = {k: v[..., pad_front:-pad_back] for k, v in + # outputs.items()} + # + # outputs = {k: v[..., :n_samples].to(audio.dtype) for k, v in + # outputs.items()} + # + # return { + # "audio": outputs + # } + + +class LinearFader(BaseFader): + def __init__( + self, + chunk_size_second: float, + hop_size_second: float, + fs: int, + fade_edge_frames: bool = False, + batch_size: int = 1, + ) -> None: + + assert hop_size_second >= chunk_size_second / 2 + + super().__init__( + chunk_size_second=chunk_size_second, + hop_size_second=hop_size_second, + fs=fs, + fade_edge_frames=fade_edge_frames, + batch_size=batch_size, + ) + + in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1] + out_fade = torch.linspace(1.0, 0.0, self.overlap_size + 1)[1:] + center_ones = torch.ones(self.chunk_size - 2 * self.overlap_size) + inout_ones = torch.ones(self.overlap_size) + + # using nn.Parameters allows lightning to take care of devices for us + self.register_buffer( + "standard_window", torch.concat([in_fade, center_ones, out_fade]) + ) + + self.fade_edge_frames = fade_edge_frames + self.edge_frame_pad_size = (self.overlap_size, self.overlap_size) + + if not self.fade_edge_frames: + self.first_window = nn.Parameter( + torch.concat([inout_ones, center_ones, out_fade]), requires_grad=False + ) + self.last_window = nn.Parameter( + torch.concat([in_fade, center_ones, inout_ones]), requires_grad=False + ) + + +class OverlapAddFader(BaseFader): + def __init__( + self, + window_type: str, + chunk_size_second: float, + hop_size_second: float, + fs: int, + batch_size: int = 1, + ) -> None: + assert (chunk_size_second / hop_size_second) % 2 == 0 + assert int(chunk_size_second * fs) % 2 == 0 + + super().__init__( + chunk_size_second=chunk_size_second, + hop_size_second=hop_size_second, + fs=fs, + fade_edge_frames=True, + batch_size=batch_size, + ) + + self.hop_multiplier = self.chunk_size / (2 * self.hop_size) + # print(f"hop multiplier: {self.hop_multiplier}") + + self.edge_frame_pad_sizes = (2 * self.overlap_size, 2 * self.overlap_size) + + self.register_buffer( + "standard_window", + torch.windows.__dict__[window_type]( + self.chunk_size, + sym=False, # dtype=torch.float64 + ) + / self.hop_multiplier, + ) + + +if __name__ == "__main__": + import torchaudio as ta + + fs = 44100 + ola = OverlapAddFader("hann", 6.0, 1.0, fs, batch_size=16) + audio_, _ = ta.load( + "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too " "Much/vocals.wav" + ) + audio_ = audio_[None, ...] + out = ola(audio_, lambda x: {"stem": x})["audio"]["stem"] + print(torch.allclose(out, audio_)) diff --git a/mvsepless/models/bandit/model_from_config.py b/mvsepless/models/bandit/model_from_config.py new file mode 100644 index 0000000000000000000000000000000000000000..3007853faf7de657be8fceb18d6b5e720ee4a60f --- /dev/null +++ b/mvsepless/models/bandit/model_from_config.py @@ -0,0 +1,26 @@ +import sys +import os.path +import torch + +import yaml +from ml_collections import ConfigDict + +torch.set_float32_matmul_precision("medium") + + +def get_model( + config_path, + weights_path, + device, +): + from .core.model import MultiMaskMultiSourceBandSplitRNNSimple + + f = open(config_path) + config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader)) + f.close() + + model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model) + d = torch.load(code_path + "model_bandit_plus_dnr_sdr_11.47.chpt") + model.load_state_dict(d) + model.to(device) + return model, config diff --git a/mvsepless/models/bandit_v2/bandit.py b/mvsepless/models/bandit_v2/bandit.py new file mode 100644 index 0000000000000000000000000000000000000000..fba32962f242ce44f64ab7d8ca433290cbe6d7e2 --- /dev/null +++ b/mvsepless/models/bandit_v2/bandit.py @@ -0,0 +1,363 @@ +from typing import Dict, List, Optional + +import torch +import torchaudio as ta +from torch import nn +import pytorch_lightning as pl + +from .bandsplit import BandSplitModule +from .maskestim import OverlappingMaskEstimationModule +from .tfmodel import SeqBandModellingModule +from .utils import MusicalBandsplitSpecification + + +class BaseEndToEndModule(pl.LightningModule): + def __init__( + self, + ) -> None: + super().__init__() + + +class BaseBandit(BaseEndToEndModule): + def __init__( + self, + in_channels: int, + fs: int, + band_type: str = "musical", + n_bands: int = 64, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + + self.instantitate_spectral( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + normalized=normalized, + center=center, + pad_mode=pad_mode, + onesided=onesided, + ) + + self.instantiate_bandsplit( + in_channels=in_channels, + band_type=band_type, + n_bands=n_bands, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + n_fft=n_fft, + fs=fs, + ) + + self.instantiate_tf_modelling( + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ) + + def instantitate_spectral( + self, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + normalized: bool = True, + center: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + ): + assert power is None + + window_fn = torch.__dict__[window_fn] + + self.stft = ta.transforms.Spectrogram( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + pad_mode=pad_mode, + pad=0, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + normalized=normalized, + center=center, + onesided=onesided, + ) + + self.istft = ta.transforms.InverseSpectrogram( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + pad_mode=pad_mode, + pad=0, + window_fn=window_fn, + wkwargs=wkwargs, + normalized=normalized, + center=center, + onesided=onesided, + ) + + def instantiate_bandsplit( + self, + in_channels: int, + band_type: str = "musical", + n_bands: int = 64, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + emb_dim: int = 128, + n_fft: int = 2048, + fs: int = 44100, + ): + assert band_type == "musical" + + self.band_specs = MusicalBandsplitSpecification( + nfft=n_fft, fs=fs, n_bands=n_bands + ) + + self.band_split = BandSplitModule( + in_channels=in_channels, + band_specs=self.band_specs.get_band_specs(), + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + ) + + def instantiate_tf_modelling( + self, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + ): + try: + self.tf_model = torch.compile( + SeqBandModellingModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ), + disable=True, + ) + except Exception as e: + self.tf_model = SeqBandModellingModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ) + + def mask(self, x, m): + return x * m + + def forward(self, batch, mode="train"): + # Model takes mono as input we give stereo, so we do process of each channel independently + init_shape = batch.shape + if not isinstance(batch, dict): + mono = batch.view(-1, 1, batch.shape[-1]) + batch = {"mixture": {"audio": mono}} + + with torch.no_grad(): + mixture = batch["mixture"]["audio"] + + x = self.stft(mixture) + batch["mixture"]["spectrogram"] = x + + if "sources" in batch.keys(): + for stem in batch["sources"].keys(): + s = batch["sources"][stem]["audio"] + s = self.stft(s) + batch["sources"][stem]["spectrogram"] = s + + batch = self.separate(batch) + + if 1: + b = [] + for s in self.stems: + # We need to obtain stereo again + r = batch["estimates"][s]["audio"].view( + -1, init_shape[1], init_shape[2] + ) + b.append(r) + # And we need to return back tensor and not independent stems + batch = torch.stack(b, dim=1) + return batch + + def encode(self, batch): + x = batch["mixture"]["spectrogram"] + length = batch["mixture"]["audio"].shape[-1] + + z = self.band_split(x) # (batch, emb_dim, n_band, n_time) + q = self.tf_model(z) # (batch, emb_dim, n_band, n_time) + + return x, q, length + + def separate(self, batch): + raise NotImplementedError + + +class Bandit(BaseBandit): + def __init__( + self, + in_channels: int, + stems: List[str], + band_type: str = "musical", + n_bands: int = 64, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Dict | None = None, + complex_mask: bool = True, + use_freq_weights: bool = True, + n_fft: int = 2048, + win_length: int | None = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Dict | None = None, + power: int | None = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + fs: int = 44100, + stft_precisions="32", + bandsplit_precisions="bf16", + tf_model_precisions="bf16", + mask_estim_precisions="bf16", + ): + super().__init__( + in_channels=in_channels, + band_type=band_type, + n_bands=n_bands, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + fs=fs, + ) + + self.stems = stems + + self.instantiate_mask_estim( + in_channels=in_channels, + stems=stems, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + n_freq=n_fft // 2 + 1, + use_freq_weights=use_freq_weights, + ) + + def instantiate_mask_estim( + self, + in_channels: int, + stems: List[str], + emb_dim: int, + mlp_dim: int, + hidden_activation: str, + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_freq: Optional[int] = None, + use_freq_weights: bool = False, + ): + if hidden_activation_kwargs is None: + hidden_activation_kwargs = {} + + assert n_freq is not None + + self.mask_estim = nn.ModuleDict( + { + stem: OverlappingMaskEstimationModule( + band_specs=self.band_specs.get_band_specs(), + freq_weights=self.band_specs.get_freq_weights(), + n_freq=n_freq, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + in_channels=in_channels, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + use_freq_weights=use_freq_weights, + ) + for stem in stems + } + ) + + def separate(self, batch): + batch["estimates"] = {} + + x, q, length = self.encode(batch) + + for stem, mem in self.mask_estim.items(): + m = mem(q) + + s = self.mask(x, m.to(x.dtype)) + s = torch.reshape(s, x.shape) + batch["estimates"][stem] = { + "audio": self.istft(s, length), + "spectrogram": s, + } + + return batch diff --git a/mvsepless/models/bandit_v2/bandsplit.py b/mvsepless/models/bandit_v2/bandsplit.py new file mode 100644 index 0000000000000000000000000000000000000000..a14ea52bfa318264d536c9f934d0e28db63e15dc --- /dev/null +++ b/mvsepless/models/bandit_v2/bandsplit.py @@ -0,0 +1,130 @@ +from typing import List, Tuple + +import torch +from torch import nn +from torch.utils.checkpoint import checkpoint_sequential + +from .utils import ( + band_widths_from_specs, + check_no_gap, + check_no_overlap, + check_nonzero_bandwidth, +) + + +class NormFC(nn.Module): + def __init__( + self, + emb_dim: int, + bandwidth: int, + in_channels: int, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + ) -> None: + super().__init__() + + if not treat_channel_as_feature: + raise NotImplementedError + + self.treat_channel_as_feature = treat_channel_as_feature + + if normalize_channel_independently: + raise NotImplementedError + + reim = 2 + + norm = nn.LayerNorm(in_channels * bandwidth * reim) + + fc_in = bandwidth * reim + + if treat_channel_as_feature: + fc_in *= in_channels + else: + assert emb_dim % in_channels == 0 + emb_dim = emb_dim // in_channels + + fc = nn.Linear(fc_in, emb_dim) + + self.combined = nn.Sequential(norm, fc) + + def forward(self, xb): + return checkpoint_sequential(self.combined, 1, xb, use_reentrant=False) + + +class BandSplitModule(nn.Module): + def __init__( + self, + band_specs: List[Tuple[float, float]], + emb_dim: int, + in_channels: int, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + ) -> None: + super().__init__() + + check_nonzero_bandwidth(band_specs) + + if require_no_gap: + check_no_gap(band_specs) + + if require_no_overlap: + check_no_overlap(band_specs) + + self.band_specs = band_specs + # list of [fstart, fend) in index. + # Note that fend is exclusive. + self.band_widths = band_widths_from_specs(band_specs) + self.n_bands = len(band_specs) + self.emb_dim = emb_dim + + try: + self.norm_fc_modules = nn.ModuleList( + [ # type: ignore + torch.compile( + NormFC( + emb_dim=emb_dim, + bandwidth=bw, + in_channels=in_channels, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + ), + disable=True, + ) + for bw in self.band_widths + ] + ) + except Exception as e: + self.norm_fc_modules = nn.ModuleList( + [ # type: ignore + NormFC( + emb_dim=emb_dim, + bandwidth=bw, + in_channels=in_channels, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + ) + for bw in self.band_widths + ] + ) + + def forward(self, x: torch.Tensor): + # x = complex spectrogram (batch, in_chan, n_freq, n_time) + + batch, in_chan, band_width, n_time = x.shape + + z = torch.zeros( + size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device + ) + + x = torch.permute(x, (0, 3, 1, 2)).contiguous() + + for i, nfm in enumerate(self.norm_fc_modules): + fstart, fend = self.band_specs[i] + xb = x[:, :, :, fstart:fend] + xb = torch.view_as_real(xb) + xb = torch.reshape(xb, (batch, n_time, -1)) + z[:, i, :, :] = nfm(xb) + + return z diff --git a/mvsepless/models/bandit_v2/film.py b/mvsepless/models/bandit_v2/film.py new file mode 100644 index 0000000000000000000000000000000000000000..253594ad0154cee4ef7467036ed71f4d5f836db8 --- /dev/null +++ b/mvsepless/models/bandit_v2/film.py @@ -0,0 +1,23 @@ +from torch import nn +import torch + + +class FiLM(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, gamma, beta): + return gamma * x + beta + + +class BTFBroadcastedFiLM(nn.Module): + def __init__(self): + super().__init__() + self.film = FiLM() + + def forward(self, x, gamma, beta): + + gamma = gamma[None, None, None, :] + beta = beta[None, None, None, :] + + return self.film(x, gamma, beta) diff --git a/mvsepless/models/bandit_v2/maskestim.py b/mvsepless/models/bandit_v2/maskestim.py new file mode 100644 index 0000000000000000000000000000000000000000..65215d86a5e94dafdb71744aafadf7aaab93330d --- /dev/null +++ b/mvsepless/models/bandit_v2/maskestim.py @@ -0,0 +1,281 @@ +from typing import Dict, List, Optional, Tuple, Type + +import torch +from torch import nn +from torch.nn.modules import activation +from torch.utils.checkpoint import checkpoint_sequential + +from .utils import ( + band_widths_from_specs, + check_no_gap, + check_no_overlap, + check_nonzero_bandwidth, +) + + +class BaseNormMLP(nn.Module): + def __init__( + self, + emb_dim: int, + mlp_dim: int, + bandwidth: int, + in_channels: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs=None, + complex_mask: bool = True, + ): + super().__init__() + if hidden_activation_kwargs is None: + hidden_activation_kwargs = {} + self.hidden_activation_kwargs = hidden_activation_kwargs + self.norm = nn.LayerNorm(emb_dim) + self.hidden = nn.Sequential( + nn.Linear(in_features=emb_dim, out_features=mlp_dim), + activation.__dict__[hidden_activation](**self.hidden_activation_kwargs), + ) + + self.bandwidth = bandwidth + self.in_channels = in_channels + + self.complex_mask = complex_mask + self.reim = 2 if complex_mask else 1 + self.glu_mult = 2 + + +class NormMLP(BaseNormMLP): + def __init__( + self, + emb_dim: int, + mlp_dim: int, + bandwidth: int, + in_channels: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs=None, + complex_mask: bool = True, + ) -> None: + super().__init__( + emb_dim=emb_dim, + mlp_dim=mlp_dim, + bandwidth=bandwidth, + in_channels=in_channels, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + + self.output = nn.Sequential( + nn.Linear( + in_features=mlp_dim, + out_features=bandwidth * in_channels * self.reim * 2, + ), + nn.GLU(dim=-1), + ) + + try: + self.combined = torch.compile( + nn.Sequential(self.norm, self.hidden, self.output), disable=True + ) + except Exception as e: + self.combined = nn.Sequential(self.norm, self.hidden, self.output) + + def reshape_output(self, mb): + # print(mb.shape) + batch, n_time, _ = mb.shape + if self.complex_mask: + mb = mb.reshape( + batch, n_time, self.in_channels, self.bandwidth, self.reim + ).contiguous() + # print(mb.shape) + mb = torch.view_as_complex(mb) # (batch, n_time, in_channels, bandwidth) + else: + mb = mb.reshape(batch, n_time, self.in_channels, self.bandwidth) + + mb = torch.permute(mb, (0, 2, 3, 1)) # (batch, in_channels, bandwidth, n_time) + + return mb + + def forward(self, qb): + # qb = (batch, n_time, emb_dim) + # qb = self.norm(qb) # (batch, n_time, emb_dim) + # qb = self.hidden(qb) # (batch, n_time, mlp_dim) + # mb = self.output(qb) # (batch, n_time, bandwidth * in_channels * reim) + + mb = checkpoint_sequential(self.combined, 2, qb, use_reentrant=False) + mb = self.reshape_output(mb) # (batch, in_channels, bandwidth, n_time) + + return mb + + +class MaskEstimationModuleSuperBase(nn.Module): + pass + + +class MaskEstimationModuleBase(MaskEstimationModuleSuperBase): + def __init__( + self, + band_specs: List[Tuple[float, float]], + emb_dim: int, + mlp_dim: int, + in_channels: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Dict = None, + complex_mask: bool = True, + norm_mlp_cls: Type[nn.Module] = NormMLP, + norm_mlp_kwargs: Dict = None, + ) -> None: + super().__init__() + + self.band_widths = band_widths_from_specs(band_specs) + self.n_bands = len(band_specs) + + if hidden_activation_kwargs is None: + hidden_activation_kwargs = {} + + if norm_mlp_kwargs is None: + norm_mlp_kwargs = {} + + self.norm_mlp = nn.ModuleList( + [ + norm_mlp_cls( + bandwidth=self.band_widths[b], + emb_dim=emb_dim, + mlp_dim=mlp_dim, + in_channels=in_channels, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + **norm_mlp_kwargs, + ) + for b in range(self.n_bands) + ] + ) + + def compute_masks(self, q): + batch, n_bands, n_time, emb_dim = q.shape + + masks = [] + + for b, nmlp in enumerate(self.norm_mlp): + # print(f"maskestim/{b:02d}") + qb = q[:, b, :, :] + mb = nmlp(qb) + masks.append(mb) + + return masks + + def compute_mask(self, q, b): + batch, n_bands, n_time, emb_dim = q.shape + qb = q[:, b, :, :] + mb = self.norm_mlp[b](qb) + return mb + + +class OverlappingMaskEstimationModule(MaskEstimationModuleBase): + def __init__( + self, + in_channels: int, + band_specs: List[Tuple[float, float]], + freq_weights: List[torch.Tensor], + n_freq: int, + emb_dim: int, + mlp_dim: int, + cond_dim: int = 0, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Dict = None, + complex_mask: bool = True, + norm_mlp_cls: Type[nn.Module] = NormMLP, + norm_mlp_kwargs: Dict = None, + use_freq_weights: bool = False, + ) -> None: + check_nonzero_bandwidth(band_specs) + check_no_gap(band_specs) + + if cond_dim > 0: + raise NotImplementedError + + super().__init__( + band_specs=band_specs, + emb_dim=emb_dim + cond_dim, + mlp_dim=mlp_dim, + in_channels=in_channels, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + norm_mlp_cls=norm_mlp_cls, + norm_mlp_kwargs=norm_mlp_kwargs, + ) + + self.n_freq = n_freq + self.band_specs = band_specs + self.in_channels = in_channels + + if freq_weights is not None and use_freq_weights: + for i, fw in enumerate(freq_weights): + self.register_buffer(f"freq_weights/{i}", fw) + + self.use_freq_weights = use_freq_weights + else: + self.use_freq_weights = False + + def forward(self, q): + # q = (batch, n_bands, n_time, emb_dim) + + batch, n_bands, n_time, emb_dim = q.shape + + masks = torch.zeros( + (batch, self.in_channels, self.n_freq, n_time), + device=q.device, + dtype=torch.complex64, + ) + + for im in range(n_bands): + fstart, fend = self.band_specs[im] + + mask = self.compute_mask(q, im) + + if self.use_freq_weights: + fw = self.get_buffer(f"freq_weights/{im}")[:, None] + mask = mask * fw + masks[:, :, fstart:fend, :] += mask + + return masks + + +class MaskEstimationModule(OverlappingMaskEstimationModule): + def __init__( + self, + band_specs: List[Tuple[float, float]], + emb_dim: int, + mlp_dim: int, + in_channels: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Dict = None, + complex_mask: bool = True, + **kwargs, + ) -> None: + check_nonzero_bandwidth(band_specs) + check_no_gap(band_specs) + check_no_overlap(band_specs) + super().__init__( + in_channels=in_channels, + band_specs=band_specs, + freq_weights=None, + n_freq=None, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + + def forward(self, q, cond=None): + # q = (batch, n_bands, n_time, emb_dim) + + masks = self.compute_masks( + q + ) # [n_bands * (batch, in_channels, bandwidth, n_time)] + + # TODO: currently this requires band specs to have no gap and no overlap + masks = torch.concat(masks, dim=2) # (batch, in_channels, n_freq, n_time) + + return masks diff --git a/mvsepless/models/bandit_v2/tfmodel.py b/mvsepless/models/bandit_v2/tfmodel.py new file mode 100644 index 0000000000000000000000000000000000000000..21aef03d1f0e814c20db05fe7d14f8019f07713b --- /dev/null +++ b/mvsepless/models/bandit_v2/tfmodel.py @@ -0,0 +1,145 @@ +import warnings + +import torch +import torch.backends.cuda +from torch import nn +from torch.nn.modules import rnn +from torch.utils.checkpoint import checkpoint_sequential + + +class TimeFrequencyModellingModule(nn.Module): + def __init__(self) -> None: + super().__init__() + + +class ResidualRNN(nn.Module): + def __init__( + self, + emb_dim: int, + rnn_dim: int, + bidirectional: bool = True, + rnn_type: str = "LSTM", + use_batch_trick: bool = True, + use_layer_norm: bool = True, + ) -> None: + # n_group is the size of the 2nd dim + super().__init__() + + assert use_layer_norm + assert use_batch_trick + + self.use_layer_norm = use_layer_norm + self.norm = nn.LayerNorm(emb_dim) + self.rnn = rnn.__dict__[rnn_type]( + input_size=emb_dim, + hidden_size=rnn_dim, + num_layers=1, + batch_first=True, + bidirectional=bidirectional, + ) + + self.fc = nn.Linear( + in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim + ) + + self.use_batch_trick = use_batch_trick + if not self.use_batch_trick: + warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!") + + def forward(self, z): + # z = (batch, n_uncrossed, n_across, emb_dim) + + z0 = torch.clone(z) + z = self.norm(z) + + batch, n_uncrossed, n_across, emb_dim = z.shape + z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim)) + z = self.rnn(z)[0] + z = torch.reshape(z, (batch, n_uncrossed, n_across, -1)) + + z = self.fc(z) # (batch, n_uncrossed, n_across, emb_dim) + + z = z + z0 + + return z + + +class Transpose(nn.Module): + def __init__(self, dim0: int, dim1: int) -> None: + super().__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, z): + return z.transpose(self.dim0, self.dim1) + + +class SeqBandModellingModule(TimeFrequencyModellingModule): + def __init__( + self, + n_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + parallel_mode=False, + ) -> None: + super().__init__() + + self.n_modules = n_modules + + if parallel_mode: + self.seqband = nn.ModuleList([]) + for _ in range(n_modules): + self.seqband.append( + nn.ModuleList( + [ + ResidualRNN( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ), + ResidualRNN( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ), + ] + ) + ) + else: + seqband = [] + for _ in range(2 * n_modules): + seqband += [ + ResidualRNN( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ), + Transpose(1, 2), + ] + + self.seqband = nn.Sequential(*seqband) + + self.parallel_mode = parallel_mode + + def forward(self, z): + # z = (batch, n_bands, n_time, emb_dim) + + if self.parallel_mode: + for sbm_pair in self.seqband: + # z: (batch, n_bands, n_time, emb_dim) + sbm_t, sbm_f = sbm_pair[0], sbm_pair[1] + zt = sbm_t(z) # (batch, n_bands, n_time, emb_dim) + zf = sbm_f(z.transpose(1, 2)) # (batch, n_time, n_bands, emb_dim) + z = zt + zf.transpose(1, 2) + else: + z = checkpoint_sequential( + self.seqband, self.n_modules, z, use_reentrant=False + ) + + q = z + return q # (batch, n_bands, n_time, emb_dim) diff --git a/mvsepless/models/bandit_v2/utils.py b/mvsepless/models/bandit_v2/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ad4eab5d8c5b5396ed717f5b9c365a6900eddd2f --- /dev/null +++ b/mvsepless/models/bandit_v2/utils.py @@ -0,0 +1,523 @@ +import os +from abc import abstractmethod +from typing import Callable + +import numpy as np +import torch +from librosa import hz_to_midi, midi_to_hz +from torchaudio import functional as taF + +# from spafe.fbanks import bark_fbanks +# from spafe.utils.converters import erb2hz, hz2bark, hz2erb + + +def band_widths_from_specs(band_specs): + return [e - i for i, e in band_specs] + + +def check_nonzero_bandwidth(band_specs): + # pprint(band_specs) + for fstart, fend in band_specs: + if fend - fstart <= 0: + raise ValueError("Bands cannot be zero-width") + + +def check_no_overlap(band_specs): + fend_prev = -1 + for fstart_curr, fend_curr in band_specs: + if fstart_curr <= fend_prev: + raise ValueError("Bands cannot overlap") + + +def check_no_gap(band_specs): + fstart, _ = band_specs[0] + assert fstart == 0 + + fend_prev = -1 + for fstart_curr, fend_curr in band_specs: + if fstart_curr - fend_prev > 1: + raise ValueError("Bands cannot leave gap") + fend_prev = fend_curr + + +class BandsplitSpecification: + def __init__(self, nfft: int, fs: int) -> None: + self.fs = fs + self.nfft = nfft + self.nyquist = fs / 2 + self.max_index = nfft // 2 + 1 + + self.split500 = self.hertz_to_index(500) + self.split1k = self.hertz_to_index(1000) + self.split2k = self.hertz_to_index(2000) + self.split4k = self.hertz_to_index(4000) + self.split8k = self.hertz_to_index(8000) + self.split16k = self.hertz_to_index(16000) + self.split20k = self.hertz_to_index(20000) + + self.above20k = [(self.split20k, self.max_index)] + self.above16k = [(self.split16k, self.split20k)] + self.above20k + + def index_to_hertz(self, index: int): + return index * self.fs / self.nfft + + def hertz_to_index(self, hz: float, round: bool = True): + index = hz * self.nfft / self.fs + + if round: + index = int(np.round(index)) + + return index + + def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz): + band_specs = [] + lower = start_index + + while lower < end_index: + upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz))) + upper = min(upper, end_index) + + band_specs.append((lower, upper)) + lower = upper + + return band_specs + + @abstractmethod + def get_band_specs(self): + raise NotImplementedError + + +class VocalBandsplitSpecification(BandsplitSpecification): + def __init__(self, nfft: int, fs: int, version: str = "7") -> None: + super().__init__(nfft=nfft, fs=fs) + + self.version = version + + def get_band_specs(self): + return getattr(self, f"version{self.version}")() + + @property + def version1(self): + return self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.max_index, bandwidth_hz=1000 + ) + + def version2(self): + below16k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split16k, bandwidth_hz=1000 + ) + below20k = self.get_band_specs_with_bandwidth( + start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000 + ) + + return below16k + below20k + self.above20k + + def version3(self): + below8k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split8k, bandwidth_hz=1000 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000 + ) + + return below8k + below16k + self.above16k + + def version4(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=100 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000 + ) + + return below1k + below8k + below16k + self.above16k + + def version5(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=100 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000 + ) + below20k = self.get_band_specs_with_bandwidth( + start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000 + ) + return below1k + below16k + below20k + self.above20k + + def version6(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=100 + ) + below4k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000 + ) + return below1k + below4k + below8k + below16k + self.above16k + + def version7(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=100 + ) + below4k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000 + ) + below20k = self.get_band_specs_with_bandwidth( + start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000 + ) + return below1k + below4k + below8k + below16k + below20k + self.above20k + + +class OtherBandsplitSpecification(VocalBandsplitSpecification): + def __init__(self, nfft: int, fs: int) -> None: + super().__init__(nfft=nfft, fs=fs, version="7") + + +class BassBandsplitSpecification(BandsplitSpecification): + def __init__(self, nfft: int, fs: int, version: str = "7") -> None: + super().__init__(nfft=nfft, fs=fs) + + def get_band_specs(self): + below500 = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split500, bandwidth_hz=50 + ) + below1k = self.get_band_specs_with_bandwidth( + start_index=self.split500, end_index=self.split1k, bandwidth_hz=100 + ) + below4k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000 + ) + above16k = [(self.split16k, self.max_index)] + + return below500 + below1k + below4k + below8k + below16k + above16k + + +class DrumBandsplitSpecification(BandsplitSpecification): + def __init__(self, nfft: int, fs: int) -> None: + super().__init__(nfft=nfft, fs=fs) + + def get_band_specs(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=50 + ) + below2k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100 + ) + below4k = self.get_band_specs_with_bandwidth( + start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000 + ) + above16k = [(self.split16k, self.max_index)] + + return below1k + below2k + below4k + below8k + below16k + above16k + + +class PerceptualBandsplitSpecification(BandsplitSpecification): + def __init__( + self, + nfft: int, + fs: int, + fbank_fn: Callable[[int, int, float, float, int], torch.Tensor], + n_bands: int, + f_min: float = 0.0, + f_max: float = None, + ) -> None: + super().__init__(nfft=nfft, fs=fs) + self.n_bands = n_bands + if f_max is None: + f_max = fs / 2 + + self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index) + + weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True) # (1, n_freqs) + normalized_mel_fb = self.filterbank / weight_per_bin # (n_mels, n_freqs) + + freq_weights = [] + band_specs = [] + for i in range(self.n_bands): + active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist() + if isinstance(active_bins, int): + active_bins = (active_bins, active_bins) + if len(active_bins) == 0: + continue + start_index = active_bins[0] + end_index = active_bins[-1] + 1 + band_specs.append((start_index, end_index)) + freq_weights.append(normalized_mel_fb[i, start_index:end_index]) + + self.freq_weights = freq_weights + self.band_specs = band_specs + + def get_band_specs(self): + return self.band_specs + + def get_freq_weights(self): + return self.freq_weights + + def save_to_file(self, dir_path: str) -> None: + os.makedirs(dir_path, exist_ok=True) + + import pickle + + with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f: + pickle.dump( + { + "band_specs": self.band_specs, + "freq_weights": self.freq_weights, + "filterbank": self.filterbank, + }, + f, + ) + + +def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs): + fb = taF.melscale_fbanks( + n_mels=n_bands, + sample_rate=fs, + f_min=f_min, + f_max=f_max, + n_freqs=n_freqs, + ).T + + fb[0, 0] = 1.0 + + return fb + + +class MelBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None + ) -> None: + super().__init__( + fbank_fn=mel_filterbank, + nfft=nfft, + fs=fs, + n_bands=n_bands, + f_min=f_min, + f_max=f_max, + ) + + +def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"): + nfft = 2 * (n_freqs - 1) + df = fs / nfft + # init freqs + f_max = f_max or fs / 2 + f_min = f_min or 0 + f_min = fs / nfft + + n_octaves = np.log2(f_max / f_min) + n_octaves_per_band = n_octaves / n_bands + bandwidth_mult = np.power(2.0, n_octaves_per_band) + + low_midi = max(0, hz_to_midi(f_min)) + high_midi = hz_to_midi(f_max) + midi_points = np.linspace(low_midi, high_midi, n_bands) + hz_pts = midi_to_hz(midi_points) + + low_pts = hz_pts / bandwidth_mult + high_pts = hz_pts * bandwidth_mult + + low_bins = np.floor(low_pts / df).astype(int) + high_bins = np.ceil(high_pts / df).astype(int) + + fb = np.zeros((n_bands, n_freqs)) + + for i in range(n_bands): + fb[i, low_bins[i] : high_bins[i] + 1] = 1.0 + + fb[0, : low_bins[0]] = 1.0 + fb[-1, high_bins[-1] + 1 :] = 1.0 + + return torch.as_tensor(fb) + + +class MusicalBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None + ) -> None: + super().__init__( + fbank_fn=musical_filterbank, + nfft=nfft, + fs=fs, + n_bands=n_bands, + f_min=f_min, + f_max=f_max, + ) + + +# def bark_filterbank( +# n_bands, fs, f_min, f_max, n_freqs +# ): +# nfft = 2 * (n_freqs -1) +# fb, _ = bark_fbanks.bark_filter_banks( +# nfilts=n_bands, +# nfft=nfft, +# fs=fs, +# low_freq=f_min, +# high_freq=f_max, +# scale="constant" +# ) + +# return torch.as_tensor(fb) + +# class BarkBandsplitSpecification(PerceptualBandsplitSpecification): +# def __init__( +# self, +# nfft: int, +# fs: int, +# n_bands: int, +# f_min: float = 0.0, +# f_max: float = None +# ) -> None: +# super().__init__(fbank_fn=bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + + +# def triangular_bark_filterbank( +# n_bands, fs, f_min, f_max, n_freqs +# ): + +# all_freqs = torch.linspace(0, fs // 2, n_freqs) + +# # calculate mel freq bins +# m_min = hz2bark(f_min) +# m_max = hz2bark(f_max) + +# m_pts = torch.linspace(m_min, m_max, n_bands + 2) +# f_pts = 600 * torch.sinh(m_pts / 6) + +# # create filterbank +# fb = _create_triangular_filterbank(all_freqs, f_pts) + +# fb = fb.T + +# first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0] +# first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0] + +# fb[first_active_band, :first_active_bin] = 1.0 + +# return fb + +# class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification): +# def __init__( +# self, +# nfft: int, +# fs: int, +# n_bands: int, +# f_min: float = 0.0, +# f_max: float = None +# ) -> None: +# super().__init__(fbank_fn=triangular_bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + + +# def minibark_filterbank( +# n_bands, fs, f_min, f_max, n_freqs +# ): +# fb = bark_filterbank( +# n_bands, +# fs, +# f_min, +# f_max, +# n_freqs +# ) + +# fb[fb < np.sqrt(0.5)] = 0.0 + +# return fb + +# class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification): +# def __init__( +# self, +# nfft: int, +# fs: int, +# n_bands: int, +# f_min: float = 0.0, +# f_max: float = None +# ) -> None: +# super().__init__(fbank_fn=minibark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + + +# def erb_filterbank( +# n_bands: int, +# fs: int, +# f_min: float, +# f_max: float, +# n_freqs: int, +# ) -> Tensor: +# # freq bins +# A = (1000 * np.log(10)) / (24.7 * 4.37) +# all_freqs = torch.linspace(0, fs // 2, n_freqs) + +# # calculate mel freq bins +# m_min = hz2erb(f_min) +# m_max = hz2erb(f_max) + +# m_pts = torch.linspace(m_min, m_max, n_bands + 2) +# f_pts = (torch.pow(10, (m_pts / A)) - 1)/ 0.00437 + +# # create filterbank +# fb = _create_triangular_filterbank(all_freqs, f_pts) + +# fb = fb.T + + +# first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0] +# first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0] + +# fb[first_active_band, :first_active_bin] = 1.0 + +# return fb + + +# class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification): +# def __init__( +# self, +# nfft: int, +# fs: int, +# n_bands: int, +# f_min: float = 0.0, +# f_max: float = None +# ) -> None: +# super().__init__(fbank_fn=erb_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + +if __name__ == "__main__": + import pandas as pd + + band_defs = [] + + for bands in [VocalBandsplitSpecification]: + band_name = bands.__name__.replace("BandsplitSpecification", "") + + mbs = bands(nfft=2048, fs=44100).get_band_specs() + + for i, (f_min, f_max) in enumerate(mbs): + band_defs.append( + {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max} + ) + + df = pd.DataFrame(band_defs) + df.to_csv("vox7bands.csv", index=False) diff --git a/mvsepless/models/bs_roformer/__init__.py b/mvsepless/models/bs_roformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..378a20b247a018ac062afd8db0bc2997841c22e5 --- /dev/null +++ b/mvsepless/models/bs_roformer/__init__.py @@ -0,0 +1,4 @@ +from .bs_roformer import BSRoformer +from .bs_roformer_sw import BSRoformer_SW +from .bs_roformer_fno import BSRoformer_FNO +from .mel_band_roformer import MelBandRoformer diff --git a/mvsepless/models/bs_roformer/attend.py b/mvsepless/models/bs_roformer/attend.py new file mode 100644 index 0000000000000000000000000000000000000000..9ebb7c937268ff4568ef6dbdcbc90abbfc8f0887 --- /dev/null +++ b/mvsepless/models/bs_roformer/attend.py @@ -0,0 +1,144 @@ +from functools import wraps +from packaging import version +from collections import namedtuple + +import os +import torch +from torch import nn, einsum +import torch.nn.functional as F + +from einops import rearrange, reduce + +# constants + +FlashAttentionConfig = namedtuple( + "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"] +) + +# helpers + + +def exists(val): + return val is not None + + +def default(v, d): + return v if exists(v) else d + + +def once(fn): + called = False + + @wraps(fn) + def inner(x): + nonlocal called + if called: + return + called = True + return fn(x) + + return inner + + +print_once = once(print) + +# main class + + +class Attend(nn.Module): + def __init__(self, dropout=0.0, flash=False, scale=None): + super().__init__() + self.scale = scale + self.dropout = dropout + self.attn_dropout = nn.Dropout(dropout) + + self.flash = flash + assert not ( + flash and version.parse(torch.__version__) < version.parse("2.0.0") + ), "in order to use flash attention, you must be using pytorch 2.0 or above" + + # determine efficient attention configs for cuda and cpu + + self.cpu_config = FlashAttentionConfig(True, True, True) + self.cuda_config = None + + if not torch.cuda.is_available() or not flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device("cuda")) + device_version = version.parse( + f"{device_properties.major}.{device_properties.minor}" + ) + + if device_version >= version.parse("8.0"): + if os.name == "nt": + print_once( + "Windows OS detected, using math or mem efficient attention if input tensor is on cuda" + ) + self.cuda_config = FlashAttentionConfig(False, True, True) + else: + print_once( + "GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda" + ) + self.cuda_config = FlashAttentionConfig(True, False, False) + else: + print_once( + "GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda" + ) + self.cuda_config = FlashAttentionConfig(False, True, True) + + def flash_attn(self, q, k, v): + _, heads, q_len, _, k_len, is_cuda, device = ( + *q.shape, + k.shape[-2], + q.is_cuda, + q.device, + ) + + if exists(self.scale): + default_scale = q.shape[-1] ** -0.5 + q = q * (self.scale / default_scale) + + # Check if there is a compatible device for flash attention + + config = self.cuda_config if is_cuda else self.cpu_config + + # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale + + with torch.backends.cuda.sdp_kernel(**config._asdict()): + out = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout if self.training else 0.0 + ) + + return out + + def forward(self, q, k, v): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + q_len, k_len, device = q.shape[-2], k.shape[-2], q.device + + scale = default(self.scale, q.shape[-1] ** -0.5) + + if self.flash: + return self.flash_attn(q, k, v) + + # similarity + + sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale + + # attention + + attn = sim.softmax(dim=-1) + attn = self.attn_dropout(attn) + + # aggregate values + + out = einsum(f"b h i j, b h j d -> b h i d", attn, v) + + return out diff --git a/mvsepless/models/bs_roformer/attend_sw.py b/mvsepless/models/bs_roformer/attend_sw.py new file mode 100644 index 0000000000000000000000000000000000000000..fc73e3ea8875ba351901fda08f7b4353248e8ae8 --- /dev/null +++ b/mvsepless/models/bs_roformer/attend_sw.py @@ -0,0 +1,100 @@ +import logging +import os + +import torch +import torch.nn.functional as F +from packaging import version +from torch import Tensor, einsum, nn +from torch.nn.attention import SDPBackend, sdpa_kernel + +logger = logging.getLogger(__name__) + + +class Attend(nn.Module): + def __init__(self, dropout: float = 0.0, flash: bool = False, scale=None): + super().__init__() + self.scale = scale + self.dropout = dropout + self.attn_dropout = nn.Dropout(dropout) + + self.flash = flash + assert not ( + flash and version.parse(torch.__version__) < version.parse("2.0.0") + ), "expected pytorch >= 2.0.0 to use flash attention" + + # determine efficient attention configs for cuda and cpu + self.cpu_backends = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, + ] + self.cuda_backends: list | None = None + + if not torch.cuda.is_available() or not flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device("cuda")) + device_version = version.parse( + f"{device_properties.major}.{device_properties.minor}" + ) + + if device_version >= version.parse("8.0"): + if os.name == "nt": + cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH] + logger.info(f"windows detected, {cuda_backends=}") + else: + cuda_backends = [SDPBackend.FLASH_ATTENTION] + logger.info(f"gpu compute capability >= 8.0, {cuda_backends=}") + else: + cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH] + logger.info(f"gpu compute capability < 8.0, {cuda_backends=}") + + self.cuda_backends = cuda_backends + + def flash_attn(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + _, _heads, _q_len, _, _k_len, is_cuda, _device = ( + *q.shape, + k.shape[-2], + q.is_cuda, + q.device, + ) # type: ignore + + if self.scale is not None: + default_scale = q.shape[-1] ** -0.5 + q = q * (self.scale / default_scale) + + backends = self.cuda_backends if is_cuda else self.cpu_backends + # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale + with sdpa_kernel(backends=backends): # type: ignore + out = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout if self.training else 0.0 + ) + + return out + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + _q_len, _k_len, _device = q.shape[-2], k.shape[-2], q.device + + scale = self.scale or q.shape[-1] ** -0.5 + + if self.flash: + return self.flash_attn(q, k, v) + + # similarity + sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale + + # attention + attn = sim.softmax(dim=-1) + attn = self.attn_dropout(attn) + + # aggregate values + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + return out diff --git a/mvsepless/models/bs_roformer/bs_roformer.py b/mvsepless/models/bs_roformer/bs_roformer.py new file mode 100644 index 0000000000000000000000000000000000000000..adf2faed4c2a8a4db7606cb820207834eab256c5 --- /dev/null +++ b/mvsepless/models/bs_roformer/bs_roformer.py @@ -0,0 +1,708 @@ +from functools import partial + +import torch +from torch import nn, einsum, Tensor +from torch.nn import Module, ModuleList +import torch.nn.functional as F + +from .attend import Attend +from torch.utils.checkpoint import checkpoint + +from beartype.typing import Tuple, Optional, List, Callable +from beartype import beartype + +from rotary_embedding_torch import RotaryEmbedding + +from einops import rearrange, pack, unpack +from einops.layers.torch import Rearrange + +# helper functions + + +def exists(val): + return val is not None + + +def default(v, d): + return v if exists(v) else d + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +# norm + + +def l2norm(t): + return F.normalize(t, dim=-1, p=2) + + +class RMSNorm(Module): + def __init__(self, dim): + super().__init__() + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return F.normalize(x, dim=-1) * self.scale * self.gamma + + +# attention + + +class FeedForward(Module): + def __init__(self, dim, mult=4, dropout=0.0): + super().__init__() + dim_inner = int(dim * mult) + self.net = nn.Sequential( + RMSNorm(dim), + nn.Linear(dim, dim_inner), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim_inner, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Attention(Module): + def __init__( + self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True + ): + super().__init__() + self.heads = heads + self.scale = dim_head**-0.5 + dim_inner = heads * dim_head + + self.rotary_embed = rotary_embed + + self.attend = Attend(flash=flash, dropout=dropout) + + self.norm = RMSNorm(dim) + self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False) + + self.to_gates = nn.Linear(dim, heads) + + self.to_out = nn.Sequential( + nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout) + ) + + def forward(self, x): + x = self.norm(x) + + q, k, v = rearrange( + self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads + ) + + if exists(self.rotary_embed): + q = self.rotary_embed.rotate_queries_or_keys(q) + k = self.rotary_embed.rotate_queries_or_keys(k) + + out = self.attend(q, k, v) + + gates = self.to_gates(x) + out = out * rearrange(gates, "b n h -> b h n 1").sigmoid() + + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +class LinearAttention(Module): + """ + this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al. + """ + + @beartype + def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0): + super().__init__() + dim_inner = dim_head * heads + self.norm = RMSNorm(dim) + + self.to_qkv = nn.Sequential( + nn.Linear(dim, dim_inner * 3, bias=False), + Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads), + ) + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.attend = Attend(scale=scale, dropout=dropout, flash=flash) + + self.to_out = nn.Sequential( + Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False) + ) + + def forward(self, x): + x = self.norm(x) + + q, k, v = self.to_qkv(x) + + q, k = map(l2norm, (q, k)) + q = q * self.temperature.exp() + + out = self.attend(q, k, v) + + return self.to_out(out) + + +class Transformer(Module): + def __init__( + self, + *, + dim, + depth, + dim_head=64, + heads=8, + attn_dropout=0.0, + ff_dropout=0.0, + ff_mult=4, + norm_output=True, + rotary_embed=None, + flash_attn=True, + linear_attn=False, + ): + super().__init__() + self.layers = ModuleList([]) + + for _ in range(depth): + if linear_attn: + attn = LinearAttention( + dim=dim, + dim_head=dim_head, + heads=heads, + dropout=attn_dropout, + flash=flash_attn, + ) + else: + attn = Attention( + dim=dim, + dim_head=dim_head, + heads=heads, + dropout=attn_dropout, + rotary_embed=rotary_embed, + flash=flash_attn, + ) + + self.layers.append( + ModuleList( + [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)] + ) + ) + + self.norm = RMSNorm(dim) if norm_output else nn.Identity() + + def forward(self, x): + + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + + return self.norm(x) + + +# bandsplit module + + +class BandSplit(Module): + @beartype + def __init__(self, dim, dim_inputs: Tuple[int, ...]): + super().__init__() + self.dim_inputs = dim_inputs + self.to_features = ModuleList([]) + + for dim_in in dim_inputs: + net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim)) + + self.to_features.append(net) + + def forward(self, x): + x = x.split(self.dim_inputs, dim=-1) + + outs = [] + for split_input, to_feature in zip(x, self.to_features): + split_output = to_feature(split_input) + outs.append(split_output) + + return torch.stack(outs, dim=-2) + + +def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh): + dim_hidden = default(dim_hidden, dim_in) + + net = [] + dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out) + + for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])): + is_last = ind == (len(dims) - 2) + + net.append(nn.Linear(layer_dim_in, layer_dim_out)) + + if is_last: + continue + + net.append(activation()) + + return nn.Sequential(*net) + + +class MaskEstimator(Module): + @beartype + def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4): + super().__init__() + self.dim_inputs = dim_inputs + self.to_freqs = ModuleList([]) + dim_hidden = dim * mlp_expansion_factor + + for dim_in in dim_inputs: + net = [] + + mlp = nn.Sequential( + MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1) + ) + + self.to_freqs.append(mlp) + + def forward(self, x): + x = x.unbind(dim=-2) + + outs = [] + + for band_features, mlp in zip(x, self.to_freqs): + freq_out = mlp(band_features) + outs.append(freq_out) + + return torch.cat(outs, dim=-1) + + +# main class + +DEFAULT_FREQS_PER_BANDS = ( + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 48, + 48, + 48, + 48, + 48, + 48, + 48, + 48, + 128, + 129, +) + + +class BSRoformer(Module): + + @beartype + def __init__( + self, + dim, + *, + depth, + stereo=False, + num_stems=1, + time_transformer_depth=2, + freq_transformer_depth=2, + linear_transformer_depth=0, + freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS, + # in the paper, they divide into ~60 bands, test with 1 for starters + dim_head=64, + heads=8, + attn_dropout=0.0, + ff_dropout=0.0, + flash_attn=True, + dim_freqs_in=1025, + stft_n_fft=2048, + stft_hop_length=512, + # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction + stft_win_length=2048, + stft_normalized=False, + stft_window_fn: Optional[Callable] = None, + mask_estimator_depth=2, + multi_stft_resolution_loss_weight=1.0, + multi_stft_resolutions_window_sizes: Tuple[int, ...] = ( + 4096, + 2048, + 1024, + 512, + 256, + ), + multi_stft_hop_size=147, + multi_stft_normalized=False, + multi_stft_window_fn: Callable = torch.hann_window, + mlp_expansion_factor=4, + use_torch_checkpoint=False, + skip_connection=False, + ): + super().__init__() + + self.stereo = stereo + self.audio_channels = 2 if stereo else 1 + self.num_stems = num_stems + self.use_torch_checkpoint = use_torch_checkpoint + self.skip_connection = skip_connection + + self.layers = ModuleList([]) + + transformer_kwargs = dict( + dim=dim, + heads=heads, + dim_head=dim_head, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + flash_attn=flash_attn, + norm_output=False, + ) + + time_rotary_embed = RotaryEmbedding(dim=dim_head) + freq_rotary_embed = RotaryEmbedding(dim=dim_head) + + for _ in range(depth): + tran_modules = [] + if linear_transformer_depth > 0: + tran_modules.append( + Transformer( + depth=linear_transformer_depth, + linear_attn=True, + **transformer_kwargs, + ) + ) + tran_modules.append( + Transformer( + depth=time_transformer_depth, + rotary_embed=time_rotary_embed, + **transformer_kwargs, + ) + ) + tran_modules.append( + Transformer( + depth=freq_transformer_depth, + rotary_embed=freq_rotary_embed, + **transformer_kwargs, + ) + ) + self.layers.append(nn.ModuleList(tran_modules)) + + self.final_norm = RMSNorm(dim) + + self.stft_kwargs = dict( + n_fft=stft_n_fft, + hop_length=stft_hop_length, + win_length=stft_win_length, + normalized=stft_normalized, + ) + + self.stft_window_fn = partial( + default(stft_window_fn, torch.hann_window), stft_win_length + ) + + freqs = torch.stft( + torch.randn(1, 4096), + **self.stft_kwargs, + window=torch.ones(stft_win_length), + return_complex=True, + ).shape[1] + + assert len(freqs_per_bands) > 1 + assert ( + sum(freqs_per_bands) == freqs + ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}" + + freqs_per_bands_with_complex = tuple( + 2 * f * self.audio_channels for f in freqs_per_bands + ) + + self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex) + + self.mask_estimators = nn.ModuleList([]) + + for _ in range(num_stems): + mask_estimator = MaskEstimator( + dim=dim, + dim_inputs=freqs_per_bands_with_complex, + depth=mask_estimator_depth, + mlp_expansion_factor=mlp_expansion_factor, + ) + + self.mask_estimators.append(mask_estimator) + + # for the multi-resolution stft loss + + self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight + self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes + self.multi_stft_n_fft = stft_n_fft + self.multi_stft_window_fn = multi_stft_window_fn + + self.multi_stft_kwargs = dict( + hop_length=multi_stft_hop_size, normalized=multi_stft_normalized + ) + + def forward(self, raw_audio, target=None, return_loss_breakdown=False): + """ + einops + + b - batch + f - freq + t - time + s - audio channel (1 for mono, 2 for stereo) + n - number of 'stems' + c - complex (2) + d - feature dimension + """ + + device = raw_audio.device + + # defining whether model is loaded on MPS (MacOS GPU accelerator) + x_is_mps = True if device.type == "mps" else False + + if raw_audio.ndim == 2: + raw_audio = rearrange(raw_audio, "b t -> b 1 t") + + channels = raw_audio.shape[1] + assert (not self.stereo and channels == 1) or ( + self.stereo and channels == 2 + ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)" + + # to stft + + raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t") + + stft_window = self.stft_window_fn(device=device) + + # RuntimeError: FFT operations are only supported on MacOS 14+ + # Since it's tedious to define whether we're on correct MacOS version - simple try-catch is used + try: + stft_repr = torch.stft( + raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True + ) + except: + stft_repr = torch.stft( + raw_audio.cpu() if x_is_mps else raw_audio, + **self.stft_kwargs, + window=stft_window.cpu() if x_is_mps else stft_window, + return_complex=True, + ).to(device) + + stft_repr = torch.view_as_real(stft_repr) + + stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c") + stft_repr = rearrange( + stft_repr, "b s f t c -> b (f s) t c" + ) # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting + + x = rearrange(stft_repr, "b f t c -> b t (f c)") + + if self.use_torch_checkpoint: + x = checkpoint(self.band_split, x, use_reentrant=False) + else: + x = self.band_split(x) + + # axial / hierarchical attention + + store = [None] * len(self.layers) + for i, transformer_block in enumerate(self.layers): + + if len(transformer_block) == 3: + linear_transformer, time_transformer, freq_transformer = ( + transformer_block + ) + + x, ft_ps = pack([x], "b * d") + if self.use_torch_checkpoint: + x = checkpoint(linear_transformer, x, use_reentrant=False) + else: + x = linear_transformer(x) + (x,) = unpack(x, ft_ps, "b * d") + else: + time_transformer, freq_transformer = transformer_block + + if self.skip_connection: + # Sum all previous + for j in range(i): + x = x + store[j] + + x = rearrange(x, "b t f d -> b f t d") + x, ps = pack([x], "* t d") + + if self.use_torch_checkpoint: + x = checkpoint(time_transformer, x, use_reentrant=False) + else: + x = time_transformer(x) + + (x,) = unpack(x, ps, "* t d") + x = rearrange(x, "b f t d -> b t f d") + x, ps = pack([x], "* f d") + + if self.use_torch_checkpoint: + x = checkpoint(freq_transformer, x, use_reentrant=False) + else: + x = freq_transformer(x) + + (x,) = unpack(x, ps, "* f d") + + if self.skip_connection: + store[i] = x + + x = self.final_norm(x) + + num_stems = len(self.mask_estimators) + + if self.use_torch_checkpoint: + mask = torch.stack( + [checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], + dim=1, + ) + else: + mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1) + mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2) + + # modulate frequency representation + + stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c") + + # complex number multiplication + + stft_repr = torch.view_as_complex(stft_repr) + mask = torch.view_as_complex(mask) + + stft_repr = stft_repr * mask + + # istft + + stft_repr = rearrange( + stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels + ) + + # same as torch.stft() fix for MacOS MPS above + try: + recon_audio = torch.istft( + stft_repr, + **self.stft_kwargs, + window=stft_window, + return_complex=False, + length=raw_audio.shape[-1], + ) + except: + recon_audio = torch.istft( + stft_repr.cpu() if x_is_mps else stft_repr, + **self.stft_kwargs, + window=stft_window.cpu() if x_is_mps else stft_window, + return_complex=False, + length=raw_audio.shape[-1], + ).to(device) + + recon_audio = rearrange( + recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems + ) + + if num_stems == 1: + recon_audio = rearrange(recon_audio, "b 1 s t -> b s t") + + # if a target is passed in, calculate loss for learning + + if not exists(target): + return recon_audio + + if self.num_stems > 1: + assert target.ndim == 4 and target.shape[1] == self.num_stems + + if target.ndim == 2: + target = rearrange(target, "... t -> ... 1 t") + + target = target[ + ..., : recon_audio.shape[-1] + ] # protect against lost length on istft + + loss = F.l1_loss(recon_audio, target) + + multi_stft_resolution_loss = 0.0 + + for window_size in self.multi_stft_resolutions_window_sizes: + res_stft_kwargs = dict( + n_fft=max( + window_size, self.multi_stft_n_fft + ), # not sure what n_fft is across multi resolution stft + win_length=window_size, + return_complex=True, + window=self.multi_stft_window_fn(window_size, device=device), + **self.multi_stft_kwargs, + ) + + recon_Y = torch.stft( + rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs + ) + target_Y = torch.stft( + rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs + ) + + multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss( + recon_Y, target_Y + ) + + weighted_multi_resolution_loss = ( + multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight + ) + + total_loss = loss + weighted_multi_resolution_loss + + if not return_loss_breakdown: + return total_loss + + return total_loss, (loss, multi_stft_resolution_loss) diff --git a/mvsepless/models/bs_roformer/bs_roformer_fno.py b/mvsepless/models/bs_roformer/bs_roformer_fno.py new file mode 100644 index 0000000000000000000000000000000000000000..858a8aa52f74263797ce22c66f2dcf92180f635a --- /dev/null +++ b/mvsepless/models/bs_roformer/bs_roformer_fno.py @@ -0,0 +1,758 @@ +from functools import partial + +import torch +from torch import nn, einsum, Tensor +from torch.nn import Module, ModuleList +import torch.nn.functional as F +from neuralop.models import FNO1d + +from .attend_sw import Attend + +try: + from .attend_sage import Attend as AttendSage +except: + pass +from torch.utils.checkpoint import checkpoint + +from beartype.typing import Tuple, Optional, List, Callable +from beartype import beartype + +from rotary_embedding_torch import RotaryEmbedding + +from einops import rearrange, pack, unpack +from einops.layers.torch import Rearrange + +# helper functions + + +def exists(val): + return val is not None + + +def default(v, d): + return v if exists(v) else d + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +# norm + + +def l2norm(t): + return F.normalize(t, dim=-1, p=2) + + +class RMSNorm(Module): + def __init__(self, dim): + super().__init__() + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return F.normalize(x, dim=-1) * self.scale * self.gamma + + +# attention + + +class FeedForward(Module): + def __init__(self, dim, mult=4, dropout=0.0): + super().__init__() + dim_inner = int(dim * mult) + self.net = nn.Sequential( + RMSNorm(dim), + nn.Linear(dim, dim_inner), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim_inner, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Attention(Module): + def __init__( + self, + dim, + heads=8, + dim_head=64, + dropout=0.0, + rotary_embed=None, + flash=True, + sage_attention=False, + ): + super().__init__() + self.heads = heads + self.scale = dim_head**-0.5 + dim_inner = heads * dim_head + + self.rotary_embed = rotary_embed + + if sage_attention: + self.attend = AttendSage(flash=flash, dropout=dropout) + else: + self.attend = Attend(flash=flash, dropout=dropout) + + self.norm = RMSNorm(dim) + self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False) + + self.to_gates = nn.Linear(dim, heads) + + self.to_out = nn.Sequential( + nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout) + ) + + def forward(self, x): + x = self.norm(x) + + q, k, v = rearrange( + self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads + ) + + if exists(self.rotary_embed): + q = self.rotary_embed.rotate_queries_or_keys(q) + k = self.rotary_embed.rotate_queries_or_keys(k) + + out = self.attend(q, k, v) + + gates = self.to_gates(x) + out = out * rearrange(gates, "b n h -> b h n 1").sigmoid() + + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +class LinearAttention(Module): + """ + this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al. + """ + + @beartype + def __init__( + self, + *, + dim, + dim_head=32, + heads=8, + scale=8, + flash=False, + dropout=0.0, + sage_attention=False, + ): + super().__init__() + dim_inner = dim_head * heads + self.norm = RMSNorm(dim) + + self.to_qkv = nn.Sequential( + nn.Linear(dim, dim_inner * 3, bias=False), + Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads), + ) + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + if sage_attention: + self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash) + else: + self.attend = Attend(scale=scale, dropout=dropout, flash=flash) + + self.to_out = nn.Sequential( + Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False) + ) + + def forward(self, x): + x = self.norm(x) + + q, k, v = self.to_qkv(x) + + q, k = map(l2norm, (q, k)) + q = q * self.temperature.exp() + + out = self.attend(q, k, v) + + return self.to_out(out) + + +class Transformer(Module): + def __init__( + self, + *, + dim, + depth, + dim_head=64, + heads=8, + attn_dropout=0.0, + ff_dropout=0.0, + ff_mult=4, + norm_output=True, + rotary_embed=None, + flash_attn=True, + linear_attn=False, + sage_attention=False, + ): + super().__init__() + self.layers = ModuleList([]) + + for _ in range(depth): + if linear_attn: + attn = LinearAttention( + dim=dim, + dim_head=dim_head, + heads=heads, + dropout=attn_dropout, + flash=flash_attn, + sage_attention=sage_attention, + ) + else: + attn = Attention( + dim=dim, + dim_head=dim_head, + heads=heads, + dropout=attn_dropout, + rotary_embed=rotary_embed, + flash=flash_attn, + sage_attention=sage_attention, + ) + + self.layers.append( + ModuleList( + [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)] + ) + ) + + self.norm = RMSNorm(dim) if norm_output else nn.Identity() + + def forward(self, x): + + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + + return self.norm(x) + + +# bandsplit module + + +class BandSplit(Module): + @beartype + def __init__(self, dim, dim_inputs: Tuple[int, ...]): + super().__init__() + self.dim_inputs = dim_inputs + self.to_features = ModuleList([]) + + for dim_in in dim_inputs: + net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim)) + + self.to_features.append(net) + + def forward(self, x): + x = x.split(self.dim_inputs, dim=-1) + + outs = [] + for split_input, to_feature in zip(x, self.to_features): + split_output = to_feature(split_input) + outs.append(split_output) + + return torch.stack(outs, dim=-2) + + +def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh): + dim_hidden = default(dim_hidden, dim_in) + + net = [] + dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out) + + for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])): + is_last = ind == (len(dims) - 2) + + net.append(nn.Linear(layer_dim_in, layer_dim_out)) + + if is_last: + continue + + net.append(activation()) + + return nn.Sequential(*net) + + +class MaskEstimator(Module): + @beartype + def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4): + super().__init__() + self.dim_inputs = dim_inputs + self.to_freqs = ModuleList([]) + dim_hidden = dim * mlp_expansion_factor + + for dim_in in dim_inputs: + net = [] + + mlp = nn.Sequential( + FNO1d( + n_modes_height=64, + hidden_channels=dim, + in_channels=dim, + out_channels=dim_in * 2, + lifting_channels=dim, + projection_channels=dim, + n_layers=3, + separable=True, + ), + nn.GLU(dim=-2), + ) + + self.to_freqs.append(mlp) + + def forward(self, x): + x = x.unbind(dim=-2) + + outs = [] + + for band_features, mlp in zip(x, self.to_freqs): + band_features = rearrange(band_features, "b t c -> b c t") + with torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32): + freq_out = mlp(band_features).float() + freq_out = rearrange(freq_out, "b c t -> b t c") + outs.append(freq_out) + + return torch.cat(outs, dim=-1) + + +# main class + +DEFAULT_FREQS_PER_BANDS = ( + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 48, + 48, + 48, + 48, + 48, + 48, + 48, + 48, + 128, + 129, +) + + +class BSRoformer_FNO(Module): + + @beartype + def __init__( + self, + dim, + *, + depth, + stereo=False, + num_stems=1, + time_transformer_depth=2, + freq_transformer_depth=2, + linear_transformer_depth=0, + freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS, + # in the paper, they divide into ~60 bands, test with 1 for starters + dim_head=64, + heads=8, + attn_dropout=0.0, + ff_dropout=0.0, + flash_attn=True, + dim_freqs_in=1025, + stft_n_fft=2048, + stft_hop_length=512, + # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction + stft_win_length=2048, + stft_normalized=False, + stft_window_fn: Optional[Callable] = None, + mask_estimator_depth=2, + multi_stft_resolution_loss_weight=1.0, + multi_stft_resolutions_window_sizes: Tuple[int, ...] = ( + 4096, + 2048, + 1024, + 512, + 256, + ), + multi_stft_hop_size=147, + multi_stft_normalized=False, + multi_stft_window_fn: Callable = torch.hann_window, + mlp_expansion_factor=4, + use_torch_checkpoint=False, + skip_connection=False, + sage_attention=False, + fno=True, + ): + super().__init__() + + self.stereo = stereo + self.audio_channels = 2 if stereo else 1 + self.num_stems = num_stems + self.use_torch_checkpoint = use_torch_checkpoint + self.skip_connection = skip_connection + + self.layers = ModuleList([]) + + if sage_attention: + print("Use Sage Attention") + + transformer_kwargs = dict( + dim=dim, + heads=heads, + dim_head=dim_head, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + flash_attn=flash_attn, + norm_output=False, + sage_attention=sage_attention, + ) + + time_rotary_embed = RotaryEmbedding(dim=dim_head) + freq_rotary_embed = RotaryEmbedding(dim=dim_head) + + for _ in range(depth): + tran_modules = [] + if linear_transformer_depth > 0: + tran_modules.append( + Transformer( + depth=linear_transformer_depth, + linear_attn=True, + **transformer_kwargs, + ) + ) + tran_modules.append( + Transformer( + depth=time_transformer_depth, + rotary_embed=time_rotary_embed, + **transformer_kwargs, + ) + ) + tran_modules.append( + Transformer( + depth=freq_transformer_depth, + rotary_embed=freq_rotary_embed, + **transformer_kwargs, + ) + ) + self.layers.append(nn.ModuleList(tran_modules)) + + self.final_norm = RMSNorm(dim) + + self.stft_kwargs = dict( + n_fft=stft_n_fft, + hop_length=stft_hop_length, + win_length=stft_win_length, + normalized=stft_normalized, + ) + + self.stft_window_fn = partial( + default(stft_window_fn, torch.hann_window), stft_win_length + ) + + freqs = torch.stft( + torch.randn(1, 4096), + **self.stft_kwargs, + window=torch.ones(stft_win_length), + return_complex=True, + ).shape[1] + + assert len(freqs_per_bands) > 1 + assert ( + sum(freqs_per_bands) == freqs + ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}" + + freqs_per_bands_with_complex = tuple( + 2 * f * self.audio_channels for f in freqs_per_bands + ) + + self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex) + + self.mask_estimators = nn.ModuleList([]) + + for _ in range(num_stems): + mask_estimator = MaskEstimator( + dim=dim, + dim_inputs=freqs_per_bands_with_complex, + depth=mask_estimator_depth, + mlp_expansion_factor=mlp_expansion_factor, + ) + + self.mask_estimators.append(mask_estimator) + + # for the multi-resolution stft loss + + self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight + self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes + self.multi_stft_n_fft = stft_n_fft + self.multi_stft_window_fn = multi_stft_window_fn + + self.multi_stft_kwargs = dict( + hop_length=multi_stft_hop_size, normalized=multi_stft_normalized + ) + + def forward(self, raw_audio, target=None, return_loss_breakdown=False): + """ + einops + + b - batch + f - freq + t - time + s - audio channel (1 for mono, 2 for stereo) + n - number of 'stems' + c - complex (2) + d - feature dimension + """ + + device = raw_audio.device + + # defining whether model is loaded on MPS (MacOS GPU accelerator) + x_is_mps = True if device.type == "mps" else False + + if raw_audio.ndim == 2: + raw_audio = rearrange(raw_audio, "b t -> b 1 t") + + channels = raw_audio.shape[1] + assert (not self.stereo and channels == 1) or ( + self.stereo and channels == 2 + ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)" + + # to stft + + raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t") + + stft_window = self.stft_window_fn(device=device) + + # RuntimeError: FFT operations are only supported on MacOS 14+ + # Since it's tedious to define whether we're on correct MacOS version - simple try-catch is used + try: + stft_repr = torch.stft( + raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True + ) + except: + stft_repr = torch.stft( + raw_audio.cpu() if x_is_mps else raw_audio, + **self.stft_kwargs, + window=stft_window.cpu() if x_is_mps else stft_window, + return_complex=True, + ).to(device) + stft_repr = torch.view_as_real(stft_repr) + + stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c") + + # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting + stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c") + + x = rearrange(stft_repr, "b f t c -> b t (f c)") + + if self.use_torch_checkpoint: + x = checkpoint(self.band_split, x, use_reentrant=False) + else: + x = self.band_split(x) + + # axial / hierarchical attention + + store = [None] * len(self.layers) + for i, transformer_block in enumerate(self.layers): + + if len(transformer_block) == 3: + linear_transformer, time_transformer, freq_transformer = ( + transformer_block + ) + + x, ft_ps = pack([x], "b * d") + if self.use_torch_checkpoint: + x = checkpoint(linear_transformer, x, use_reentrant=False) + else: + x = linear_transformer(x) + (x,) = unpack(x, ft_ps, "b * d") + else: + time_transformer, freq_transformer = transformer_block + + if self.skip_connection: + # Sum all previous + for j in range(i): + x = x + store[j] + + x = rearrange(x, "b t f d -> b f t d") + x, ps = pack([x], "* t d") + + if self.use_torch_checkpoint: + x = checkpoint(time_transformer, x, use_reentrant=False) + else: + x = time_transformer(x) + + (x,) = unpack(x, ps, "* t d") + x = rearrange(x, "b f t d -> b t f d") + x, ps = pack([x], "* f d") + + if self.use_torch_checkpoint: + x = checkpoint(freq_transformer, x, use_reentrant=False) + else: + x = freq_transformer(x) + + (x,) = unpack(x, ps, "* f d") + + if self.skip_connection: + store[i] = x + + x = self.final_norm(x) + + num_stems = len(self.mask_estimators) + + if self.use_torch_checkpoint: + mask = torch.stack( + [checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], + dim=1, + ) + else: + mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1) + mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2) + + # modulate frequency representation + + stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c") + + # complex number multiplication + + stft_repr = torch.view_as_complex(stft_repr) + mask = torch.view_as_complex(mask) + + stft_repr = stft_repr * mask + + # istft + + stft_repr = rearrange( + stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels + ) + + # same as torch.stft() fix for MacOS MPS above + try: + recon_audio = torch.istft( + stft_repr, + **self.stft_kwargs, + window=stft_window, + return_complex=False, + length=raw_audio.shape[-1], + ) + except: + recon_audio = torch.istft( + stft_repr.cpu() if x_is_mps else stft_repr, + **self.stft_kwargs, + window=stft_window.cpu() if x_is_mps else stft_window, + return_complex=False, + length=raw_audio.shape[-1], + ).to(device) + + recon_audio = rearrange( + recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems + ) + + if num_stems == 1: + recon_audio = rearrange(recon_audio, "b 1 s t -> b s t") + + # if a target is passed in, calculate loss for learning + + if not exists(target): + return recon_audio + + if self.num_stems > 1: + assert target.ndim == 4 and target.shape[1] == self.num_stems + + if target.ndim == 2: + target = rearrange(target, "... t -> ... 1 t") + + target = target[ + ..., : recon_audio.shape[-1] + ] # protect against lost length on istft + + loss = F.l1_loss(recon_audio, target) + + multi_stft_resolution_loss = 0.0 + + for window_size in self.multi_stft_resolutions_window_sizes: + res_stft_kwargs = dict( + n_fft=max( + window_size, self.multi_stft_n_fft + ), # not sure what n_fft is across multi resolution stft + win_length=window_size, + return_complex=True, + window=self.multi_stft_window_fn(window_size, device=device), + **self.multi_stft_kwargs, + ) + + recon_Y = torch.stft( + rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs + ) + target_Y = torch.stft( + rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs + ) + + multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss( + recon_Y, target_Y + ) + + weighted_multi_resolution_loss = ( + multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight + ) + + total_loss = loss + weighted_multi_resolution_loss + + if not return_loss_breakdown: + return total_loss + + return total_loss, (loss, multi_stft_resolution_loss) diff --git a/mvsepless/models/bs_roformer/bs_roformer_sw.py b/mvsepless/models/bs_roformer/bs_roformer_sw.py new file mode 100644 index 0000000000000000000000000000000000000000..7b04ce518df91bf28126b76c96bc04910fbeb1de --- /dev/null +++ b/mvsepless/models/bs_roformer/bs_roformer_sw.py @@ -0,0 +1,724 @@ +from __future__ import annotations + +from functools import partial + +import torch +import torch.nn.functional as F +from beartype import beartype +from beartype.typing import Callable +from einops import pack, rearrange, unpack +from einops.layers.torch import Rearrange +from torch import nn +from torch.nn import Module, ModuleList +from torch.utils.checkpoint import checkpoint + +from .attend_sw import Attend + +try: + from .attend_sage_sw import AttendSage +except ImportError: + pass + + +def l2norm(t): + return F.normalize(t, dim=-1, p=2) + + +class CustomNorm(Module): + def __init__(self, dim, eps: float = 5.960464477539063e-08): # 0x1p-24 + super().__init__() + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + self.eps = eps + + def forward(self, x): + l2_norm = torch.linalg.norm(x, dim=-1, keepdim=True) + denom = torch.maximum(l2_norm, torch.full_like(l2_norm, self.eps)) + normalized_x = x / denom + return normalized_x * self.scale * self.gamma + + +# attention + + +class RotaryEmbedding(nn.Module): + def __init__(self, cos_emb, sin_emb): + super().__init__() + # both (seq_len_for_rotation, dim_head) + self.cos_emb = cos_emb + self.sin_emb = sin_emb + + def rotate_half(self, x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + def forward(self, x): + # x is (batch_eff, heads, seq_len_for_rotation, dim_head) + cos_b = self.cos_emb.unsqueeze(0).unsqueeze(0).to(x.device, x.dtype) + sin_b = self.sin_emb.unsqueeze(0).unsqueeze(0).to(x.device, x.dtype) + + term1 = x * cos_b + term2 = self.rotate_half(x) * sin_b + + sum = term1.to(torch.float32) + term2.to(torch.float32) + return sum.to(x.dtype) + + +class FeedForward(Module): + def __init__(self, dim, mult=4, dropout=0.0): + super().__init__() + dim_inner = int(dim * mult) + self.net = nn.Sequential( + CustomNorm(dim), + nn.Linear(dim, dim_inner), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim_inner, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Attention(Module): + def __init__( + self, + dim, + heads=8, + dim_head=64, + dropout=0.0, + shared_qkv_bias=None, + shared_out_bias=None, + rotary_embed: RotaryEmbedding | None = None, + flash=True, + sage_attention=False, + ): + super().__init__() + self.heads = heads + self.scale = dim_head**-0.5 + dim_inner = heads * dim_head + + self.rotary_embed = rotary_embed + + if sage_attention: + self.attend = AttendSage(flash=flash, dropout=dropout) # type: ignore + else: + self.attend = Attend(flash=flash, dropout=dropout) # type: ignore + + self.norm = CustomNorm(dim) + self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=(shared_qkv_bias is not None)) + if shared_qkv_bias is not None: + self.to_qkv.bias = shared_qkv_bias + + self.to_gates = nn.Linear(dim, heads) + + self.to_out = nn.Sequential( + nn.Linear(dim_inner, dim, bias=(shared_out_bias is not None)), + nn.Dropout(dropout), + ) + if shared_out_bias is not None: + self.to_out[0].bias = shared_out_bias + + def forward(self, x): + x = self.norm(x) + + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads) + + if self.rotary_embed is not None: + q = self.rotary_embed(q) + k = self.rotary_embed(k) + + out = self.attend(q, k, v) + + gates = self.to_gates(x) + gate_act = gates.sigmoid() + + out = out * rearrange(gate_act, "b n h -> b h n 1") + + out = rearrange(out, "b h n d -> b n (h d)") + out = self.to_out(out) + return out + + +class LinearAttention(Module): + """ + this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al. + """ + + @beartype + def __init__( + self, + *, + dim, + dim_head=32, + heads=8, + scale=8, + flash=False, + dropout=0.0, + sage_attention=False, + ): + super().__init__() + dim_inner = dim_head * heads + self.norm = CustomNorm(dim) + + self.to_qkv = nn.Sequential( + nn.Linear(dim, dim_inner * 3, bias=False), + Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads), + ) + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + if sage_attention: + self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash) # type: ignore + else: + self.attend = Attend(scale=scale, dropout=dropout, flash=flash) + + self.to_out = nn.Sequential( + Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False) + ) + + def forward(self, x): + x = self.norm(x) + + q, k, v = self.to_qkv(x) + + q, k = map(l2norm, (q, k)) + q = q * self.temperature.exp() + + out = self.attend(q, k, v) + + return self.to_out(out) + + +class Transformer(Module): + def __init__( + self, + *, + dim, + depth, + dim_head=64, + heads=8, + attn_dropout=0.0, + ff_dropout=0.0, + ff_mult=4, + norm_output=True, + rotary_embed: RotaryEmbedding | None = None, + flash_attn=True, + linear_attn=False, + sage_attention=False, + shared_qkv_bias=None, + shared_out_bias=None, + ): + super().__init__() + self.layers = ModuleList([]) + + for _ in range(depth): + attn: LinearAttention | Attention + if linear_attn: + attn = LinearAttention( + dim=dim, + dim_head=dim_head, + heads=heads, + dropout=attn_dropout, + flash=flash_attn, + sage_attention=sage_attention, + ) + else: + attn = Attention( + dim=dim, + dim_head=dim_head, + heads=heads, + dropout=attn_dropout, + shared_qkv_bias=shared_qkv_bias, + shared_out_bias=shared_out_bias, + rotary_embed=rotary_embed, + flash=flash_attn, + sage_attention=sage_attention, + ) + + self.layers.append( + ModuleList( + [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)] + ) + ) + + self.norm = CustomNorm(dim) if norm_output else nn.Identity() + + def forward(self, x): + for attn, ff in self.layers: # type: ignore + x = attn(x) + x + x = ff(x) + x + return self.norm(x) + + +# bandsplit module + + +class BandSplit(Module): + @beartype + def __init__(self, dim, dim_inputs: tuple[int, ...]): + super().__init__() + self.dim_inputs = dim_inputs + self.to_features = ModuleList([]) + + for dim_in in dim_inputs: + net = nn.Sequential(CustomNorm(dim_in), nn.Linear(dim_in, dim)) + + self.to_features.append(net) + + def forward(self, x): + x = x.split(self.dim_inputs, dim=-1) + + outs = [] + for split_input, to_feature in zip(x, self.to_features): + split_output = to_feature(split_input) + outs.append(split_output) + + return torch.stack(outs, dim=-2) + + +def MLP( + dim_in: int, + dim_out: int, + dim_hidden: int | None = None, + depth: int = 1, + activation=nn.Tanh, +): + dim_hidden = dim_hidden or dim_in + + net = [] + dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out) + + for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])): + is_last = ind == (len(dims) - 2) + + net.append(nn.Linear(layer_dim_in, layer_dim_out)) + + if is_last: + continue + + net.append(activation()) + + return nn.Sequential(*net) + + +class MaskEstimator(Module): + @beartype + def __init__(self, dim, dim_inputs: tuple[int, ...], depth, mlp_expansion_factor=4): + super().__init__() + self.dim_inputs = dim_inputs + self.to_freqs = ModuleList([]) + dim_hidden = dim * mlp_expansion_factor + + for dim_in in dim_inputs: + mlp = nn.Sequential( + MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1) + ) + + self.to_freqs.append(mlp) + + def forward(self, x): + x = x.unbind(dim=-2) + + outs = [] + + for band_features, mlp in zip(x, self.to_freqs): + freq_out = mlp(band_features) + outs.append(freq_out) + + return torch.cat(outs, dim=-1) + + +# fmt: off +DEFAULT_FREQS_PER_BANDS = ( + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 12, 12, 12, 12, 12, 12, 12, 12, + 24, 24, 24, 24, 24, 24, 24, 24, + 48, 48, 48, 48, 48, 48, 48, 48, + 128, 129 +) +# fmt: on + + +class BSRoformer_SW(Module): + @beartype + def __init__( + self, + dim, + *, + depth, + stereo=False, + num_stems=1, + time_transformer_depth=2, + freq_transformer_depth=2, + linear_transformer_depth=0, + freqs_per_bands: tuple[ + int, ... + ] = DEFAULT_FREQS_PER_BANDS, # in the paper, they divide into ~60 bands, test with 1 for starters + dim_head=64, + heads=8, + attn_dropout=0.0, + ff_dropout=0.0, + flash_attn=True, + stft_n_fft=2048, + stft_hop_length=512, # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction + stft_win_length=2048, + stft_normalized=False, + stft_window_fn: Callable | None = None, + mask_estimator_depth=2, + multi_stft_resolution_loss_weight=1.0, + multi_stft_resolutions_window_sizes: tuple[int, ...] = ( + 4096, + 2048, + 1024, + 512, + 256, + ), + multi_stft_hop_size=147, + multi_stft_normalized=False, + multi_stft_window_fn: Callable = torch.hann_window, + mlp_expansion_factor=4, + use_torch_checkpoint=False, + skip_connection=False, + sage_attention=False, + use_shared_bias=False, + chunk_size: int = 588800, + ): + super().__init__() + + self.stereo = stereo + self.audio_channels = 2 if stereo else 1 + self.num_stems = num_stems + self.use_torch_checkpoint = use_torch_checkpoint + self.skip_connection = skip_connection + + self.layers = ModuleList([]) + + if sage_attention: + print("Use Sage Attention") + + if use_shared_bias: + dim_inner = heads * dim_head + self.shared_qkv_bias = nn.Parameter(torch.ones(dim_inner * 3)) + self.shared_out_bias = nn.Parameter(torch.ones(dim)) + + transformer_kwargs = dict( + dim=dim, + heads=heads, + dim_head=dim_head, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + flash_attn=flash_attn, + norm_output=False, + sage_attention=sage_attention, + shared_qkv_bias=self.shared_qkv_bias, + shared_out_bias=self.shared_out_bias, + ) + + t_frames = chunk_size // stft_hop_length + 1 # e.g. 588800 // 512 + 1 = 1151 + self.cos_emb_time = nn.Parameter(torch.zeros(t_frames, dim_head)) + self.sin_emb_time = nn.Parameter(torch.zeros(t_frames, dim_head)) + time_rotary_embed = RotaryEmbedding( + cos_emb=self.cos_emb_time, sin_emb=self.sin_emb_time + ) + + num_bands = len(freqs_per_bands) # e.g. 62 + self.cos_emb_freq = nn.Parameter(torch.zeros(num_bands, dim_head)) + self.sin_emb_freq = nn.Parameter(torch.zeros(num_bands, dim_head)) + freq_rotary_embed = RotaryEmbedding( + cos_emb=self.cos_emb_freq, sin_emb=self.sin_emb_freq + ) + + for _ in range(depth): + tran_modules = [] + if linear_transformer_depth > 0: + tran_modules.append( + Transformer( + depth=linear_transformer_depth, + linear_attn=True, + **transformer_kwargs, + ) + ) + tran_modules.append( + Transformer( + depth=time_transformer_depth, + rotary_embed=time_rotary_embed, + **transformer_kwargs, + ) + ) + tran_modules.append( + Transformer( + depth=freq_transformer_depth, + rotary_embed=freq_rotary_embed, + **transformer_kwargs, + ) + ) + self.layers.append(nn.ModuleList(tran_modules)) + + self.final_norm = CustomNorm(dim) + + self.stft_kwargs = dict( + n_fft=stft_n_fft, + hop_length=stft_hop_length, + win_length=stft_win_length, + normalized=stft_normalized, + ) + + self.stft_window_fn = partial( + stft_window_fn or torch.hann_window, stft_win_length + ) + + freqs_per_bands_with_complex = tuple( + 2 * f * self.audio_channels for f in freqs_per_bands + ) + + self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex) + + self.mask_estimators = nn.ModuleList([]) + + for _ in range(num_stems): + mask_estimator = MaskEstimator( + dim=dim, + dim_inputs=freqs_per_bands_with_complex, + depth=mask_estimator_depth, + mlp_expansion_factor=mlp_expansion_factor, + ) + + self.mask_estimators.append(mask_estimator) + + # for the multi-resolution stft loss + + self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight + self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes + self.multi_stft_n_fft = stft_n_fft + self.multi_stft_window_fn = multi_stft_window_fn + + self.multi_stft_kwargs = dict( + hop_length=multi_stft_hop_size, normalized=multi_stft_normalized + ) + + def forward(self, raw_audio, target=None, return_loss_breakdown=False): + """ + einops + + b - batch + f - freq + t - time + s - audio channel (1 for mono, 2 for stereo) + n - number of 'stems' + c - complex (2) + d - feature dimension + """ + + device = raw_audio.device + + # defining whether model is loaded on MPS (MacOS GPU accelerator) + x_is_mps = True if device.type == "mps" else False + + if raw_audio.ndim == 2: + raw_audio = rearrange(raw_audio, "b t -> b 1 t") + + channels = raw_audio.shape[1] + assert (not self.stereo and channels == 1) or ( + self.stereo and channels == 2 + ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)" + + # to stft + + raw_audio, batch_audio_channel_packed_shape = pack([raw_audio], "* t") + + stft_window = self.stft_window_fn(device=device) + + # RuntimeError: FFT operations are only supported on MacOS 14+ + # Since it's tedious to define whether we're on correct MacOS version - simple try-catch is used + try: + stft_repr = torch.stft( + raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True + ) + except Exception: + stft_repr = torch.stft( + raw_audio.cpu() if x_is_mps else raw_audio, + **self.stft_kwargs, + window=stft_window.cpu() if x_is_mps else stft_window, + return_complex=True, + ).to(device) + stft_repr = torch.view_as_real(stft_repr) + + stft_repr = unpack(stft_repr, batch_audio_channel_packed_shape, "* f t c")[0] + + # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting + stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c") + + x = rearrange(stft_repr, "b f t c -> b t (f c)") + + if torch.isnan(x).any() or torch.isinf(x).any(): + raise RuntimeError( + f"NaN/Inf in x after stft: {x.isnan().sum()} NaNs, {x.isinf().sum()} Infs" + ) + + if self.use_torch_checkpoint: + x = checkpoint(self.band_split, x, use_reentrant=False) + else: + x = self.band_split(x) + + if torch.isnan(x).any() or torch.isinf(x).any(): + raise RuntimeError( + f"NaN/Inf in x after band_split: {x.isnan().sum()} NaNs, {x.isinf().sum()} Infs" + ) + + # axial / hierarchical attention + + store = [None] * len(self.layers) + for i, transformer_block in enumerate(self.layers): + if len(transformer_block) == 3: + linear_transformer, time_transformer, freq_transformer = ( + transformer_block + ) + + x, ft_ps = pack([x], "b * d") + if self.use_torch_checkpoint: + x = checkpoint(linear_transformer, x, use_reentrant=False) + else: + x = linear_transformer(x) + (x,) = unpack(x, ft_ps, "b * d") + else: + time_transformer, freq_transformer = transformer_block + + if self.skip_connection: + # Sum all previous + for j in range(i): + x = x + store[j] + + x = rearrange(x, "b t f d -> b f t d") + x, ps = pack([x], "* t d") + + if self.use_torch_checkpoint: + x = checkpoint(time_transformer, x, use_reentrant=False) + else: + x = time_transformer(x) + + (x,) = unpack(x, ps, "* t d") + x = rearrange(x, "b f t d -> b t f d") + x, ps = pack([x], "* f d") + + if self.use_torch_checkpoint: + x = checkpoint(freq_transformer, x, use_reentrant=False) + else: + x = freq_transformer(x) + + (x,) = unpack(x, ps, "* f d") + + if self.skip_connection: + store[i] = x + + x = self.final_norm(x) + + num_stems = len(self.mask_estimators) + + if self.use_torch_checkpoint: + mask = torch.stack( + [checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], + dim=1, + ) + else: + mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1) + mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2) + + # modulate frequency representation + + stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c") + + # complex number multiplication + + stft_repr = torch.view_as_complex(stft_repr) + mask = torch.view_as_complex(mask) + + stft_repr = stft_repr * mask + + # istft + + stft_repr = rearrange( + stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels + ) + + # same as torch.stft() fix for MacOS MPS above + try: + recon_audio = torch.istft( + stft_repr, + **self.stft_kwargs, + window=stft_window, + return_complex=False, + length=raw_audio.shape[-1], + ) + except Exception: + recon_audio = torch.istft( + stft_repr.cpu() if x_is_mps else stft_repr, + **self.stft_kwargs, + window=stft_window.cpu() if x_is_mps else stft_window, + return_complex=False, + length=raw_audio.shape[-1], + ).to(device) + + recon_audio = rearrange( + recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems + ) + + if num_stems == 1: + recon_audio = rearrange(recon_audio, "b 1 s t -> b s t") + + # if a target is passed in, calculate loss for learning + + if target is None: + return recon_audio + + if self.num_stems > 1: + assert target.ndim == 4 and target.shape[1] == self.num_stems + + if target.ndim == 2: + target = rearrange(target, "... t -> ... 1 t") + + target = target[ + ..., : recon_audio.shape[-1] + ] # protect against lost length on istft + + loss = F.l1_loss(recon_audio, target) + + multi_stft_resolution_loss = 0.0 + + for window_size in self.multi_stft_resolutions_window_sizes: + res_stft_kwargs = dict( + n_fft=max( + window_size, self.multi_stft_n_fft + ), # not sure what n_fft is across multi resolution stft + win_length=window_size, + return_complex=True, + window=self.multi_stft_window_fn(window_size, device=device), + **self.multi_stft_kwargs, + ) + + recon_Y = torch.stft( + rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs + ) + target_Y = torch.stft( + rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs + ) + + multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss( + recon_Y, target_Y + ) + + weighted_multi_resolution_loss = ( + multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight + ) + + total_loss = loss + weighted_multi_resolution_loss + + if not return_loss_breakdown: + return total_loss + + return total_loss, (loss, multi_stft_resolution_loss) diff --git a/mvsepless/models/bs_roformer/mel_band_roformer.py b/mvsepless/models/bs_roformer/mel_band_roformer.py new file mode 100644 index 0000000000000000000000000000000000000000..833aa39c44c5f7b7e8cf9dd60a7cccaead36ef24 --- /dev/null +++ b/mvsepless/models/bs_roformer/mel_band_roformer.py @@ -0,0 +1,704 @@ +from functools import partial + +import torch +from torch import nn, einsum, Tensor +from torch.nn import Module, ModuleList +import torch.nn.functional as F + +from .attend import Attend +from torch.utils.checkpoint import checkpoint + +from beartype.typing import Tuple, Optional, List, Callable +from beartype import beartype + +from rotary_embedding_torch import RotaryEmbedding + +from einops import rearrange, pack, unpack, reduce, repeat +from einops.layers.torch import Rearrange + +from librosa import filters + + +# helper functions + + +def exists(val): + return val is not None + + +def default(v, d): + return v if exists(v) else d + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +def pad_at_dim(t, pad, dim=-1, value=0.0): + dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1) + zeros = (0, 0) * dims_from_right + return F.pad(t, (*zeros, *pad), value=value) + + +def l2norm(t): + return F.normalize(t, dim=-1, p=2) + + +# norm + + +class RMSNorm(Module): + def __init__(self, dim): + super().__init__() + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return F.normalize(x, dim=-1) * self.scale * self.gamma + + +# attention + + +class FeedForward(Module): + def __init__(self, dim, mult=4, dropout=0.0): + super().__init__() + dim_inner = int(dim * mult) + self.net = nn.Sequential( + RMSNorm(dim), + nn.Linear(dim, dim_inner), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim_inner, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Attention(Module): + def __init__( + self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True + ): + super().__init__() + self.heads = heads + self.scale = dim_head**-0.5 + dim_inner = heads * dim_head + + self.rotary_embed = rotary_embed + + self.attend = Attend(flash=flash, dropout=dropout) + + self.norm = RMSNorm(dim) + self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False) + + self.to_gates = nn.Linear(dim, heads) + + self.to_out = nn.Sequential( + nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout) + ) + + def forward(self, x): + x = self.norm(x) + + q, k, v = rearrange( + self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads + ) + + if exists(self.rotary_embed): + q = self.rotary_embed.rotate_queries_or_keys(q) + k = self.rotary_embed.rotate_queries_or_keys(k) + + out = self.attend(q, k, v) + + gates = self.to_gates(x) + out = out * rearrange(gates, "b n h -> b h n 1").sigmoid() + + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +class LinearAttention(Module): + """ + this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al. + """ + + @beartype + def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0): + super().__init__() + dim_inner = dim_head * heads + self.norm = RMSNorm(dim) + + self.to_qkv = nn.Sequential( + nn.Linear(dim, dim_inner * 3, bias=False), + Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads), + ) + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.attend = Attend(scale=scale, dropout=dropout, flash=flash) + + self.to_out = nn.Sequential( + Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False) + ) + + def forward(self, x): + x = self.norm(x) + + q, k, v = self.to_qkv(x) + + q, k = map(l2norm, (q, k)) + q = q * self.temperature.exp() + + out = self.attend(q, k, v) + + return self.to_out(out) + + +class Transformer(Module): + def __init__( + self, + *, + dim, + depth, + dim_head=64, + heads=8, + attn_dropout=0.0, + ff_dropout=0.0, + ff_mult=4, + norm_output=True, + rotary_embed=None, + flash_attn=True, + linear_attn=False, + ): + super().__init__() + self.layers = ModuleList([]) + + for _ in range(depth): + if linear_attn: + attn = LinearAttention( + dim=dim, + dim_head=dim_head, + heads=heads, + dropout=attn_dropout, + flash=flash_attn, + ) + else: + attn = Attention( + dim=dim, + dim_head=dim_head, + heads=heads, + dropout=attn_dropout, + rotary_embed=rotary_embed, + flash=flash_attn, + ) + + self.layers.append( + ModuleList( + [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)] + ) + ) + + self.norm = RMSNorm(dim) if norm_output else nn.Identity() + + def forward(self, x): + + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + + return self.norm(x) + + +# bandsplit module + + +class BandSplit(Module): + @beartype + def __init__(self, dim, dim_inputs: Tuple[int, ...]): + super().__init__() + self.dim_inputs = dim_inputs + self.to_features = ModuleList([]) + + for dim_in in dim_inputs: + net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim)) + + self.to_features.append(net) + + def forward(self, x): + x = x.split(self.dim_inputs, dim=-1) + + outs = [] + for split_input, to_feature in zip(x, self.to_features): + split_output = to_feature(split_input) + outs.append(split_output) + + return torch.stack(outs, dim=-2) + + +def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh): + dim_hidden = default(dim_hidden, dim_in) + + net = [] + dims = (dim_in, *((dim_hidden,) * depth), dim_out) + + for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])): + is_last = ind == (len(dims) - 2) + + net.append(nn.Linear(layer_dim_in, layer_dim_out)) + + if is_last: + continue + + net.append(activation()) + + return nn.Sequential(*net) + + +class MaskEstimator(Module): + @beartype + def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4): + super().__init__() + self.dim_inputs = dim_inputs + self.to_freqs = ModuleList([]) + dim_hidden = dim * mlp_expansion_factor + + for dim_in in dim_inputs: + net = [] + + mlp = nn.Sequential( + MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1) + ) + + self.to_freqs.append(mlp) + + def forward(self, x): + x = x.unbind(dim=-2) + + outs = [] + + for band_features, mlp in zip(x, self.to_freqs): + freq_out = mlp(band_features) + outs.append(freq_out) + + return torch.cat(outs, dim=-1) + + +# main class + + +class MelBandRoformer(Module): + + @beartype + def __init__( + self, + dim, + *, + depth, + stereo=False, + num_stems=1, + time_transformer_depth=2, + freq_transformer_depth=2, + linear_transformer_depth=0, + num_bands=60, + dim_head=64, + heads=8, + attn_dropout=0.1, + ff_dropout=0.1, + flash_attn=True, + dim_freqs_in=1025, + sample_rate=44100, # needed for mel filter bank from librosa + stft_n_fft=2048, + stft_hop_length=512, + # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction + stft_win_length=2048, + stft_normalized=False, + stft_window_fn: Optional[Callable] = None, + mask_estimator_depth=1, + multi_stft_resolution_loss_weight=1.0, + multi_stft_resolutions_window_sizes: Tuple[int, ...] = ( + 4096, + 2048, + 1024, + 512, + 256, + ), + multi_stft_hop_size=147, + multi_stft_normalized=False, + multi_stft_window_fn: Callable = torch.hann_window, + match_input_audio_length=False, # if True, pad output tensor to match length of input tensor + mlp_expansion_factor=4, + use_torch_checkpoint=False, + skip_connection=False, + ): + super().__init__() + + self.stereo = stereo + self.audio_channels = 2 if stereo else 1 + self.num_stems = num_stems + self.use_torch_checkpoint = use_torch_checkpoint + self.skip_connection = skip_connection + + self.layers = ModuleList([]) + + transformer_kwargs = dict( + dim=dim, + heads=heads, + dim_head=dim_head, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + flash_attn=flash_attn, + ) + + time_rotary_embed = RotaryEmbedding(dim=dim_head) + freq_rotary_embed = RotaryEmbedding(dim=dim_head) + + for _ in range(depth): + tran_modules = [] + if linear_transformer_depth > 0: + tran_modules.append( + Transformer( + depth=linear_transformer_depth, + linear_attn=True, + **transformer_kwargs, + ) + ) + tran_modules.append( + Transformer( + depth=time_transformer_depth, + rotary_embed=time_rotary_embed, + **transformer_kwargs, + ) + ) + tran_modules.append( + Transformer( + depth=freq_transformer_depth, + rotary_embed=freq_rotary_embed, + **transformer_kwargs, + ) + ) + self.layers.append(nn.ModuleList(tran_modules)) + + self.stft_window_fn = partial( + default(stft_window_fn, torch.hann_window), stft_win_length + ) + + self.stft_kwargs = dict( + n_fft=stft_n_fft, + hop_length=stft_hop_length, + win_length=stft_win_length, + normalized=stft_normalized, + ) + + freqs = torch.stft( + torch.randn(1, 4096), + **self.stft_kwargs, + window=torch.ones(stft_n_fft), + return_complex=True, + ).shape[1] + + # create mel filter bank + # with librosa.filters.mel as in section 2 of paper + + mel_filter_bank_numpy = filters.mel( + sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands + ) + + mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy) + + # for some reason, it doesn't include the first freq? just force a value for now + + mel_filter_bank[0][0] = 1.0 + + # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position, + # so let's force a positive value + + mel_filter_bank[-1, -1] = 1.0 + + # binary as in paper (then estimated masks are averaged for overlapping regions) + + freqs_per_band = mel_filter_bank > 0 + assert freqs_per_band.any( + dim=0 + ).all(), "all frequencies need to be covered by all bands for now" + + repeated_freq_indices = repeat(torch.arange(freqs), "f -> b f", b=num_bands) + freq_indices = repeated_freq_indices[freqs_per_band] + + if stereo: + freq_indices = repeat(freq_indices, "f -> f s", s=2) + freq_indices = freq_indices * 2 + torch.arange(2) + freq_indices = rearrange(freq_indices, "f s -> (f s)") + + self.register_buffer("freq_indices", freq_indices, persistent=False) + self.register_buffer("freqs_per_band", freqs_per_band, persistent=False) + + num_freqs_per_band = reduce(freqs_per_band, "b f -> b", "sum") + num_bands_per_freq = reduce(freqs_per_band, "b f -> f", "sum") + + self.register_buffer("num_freqs_per_band", num_freqs_per_band, persistent=False) + self.register_buffer("num_bands_per_freq", num_bands_per_freq, persistent=False) + + # band split and mask estimator + + freqs_per_bands_with_complex = tuple( + 2 * f * self.audio_channels for f in num_freqs_per_band.tolist() + ) + + self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex) + + self.mask_estimators = nn.ModuleList([]) + + for _ in range(num_stems): + mask_estimator = MaskEstimator( + dim=dim, + dim_inputs=freqs_per_bands_with_complex, + depth=mask_estimator_depth, + mlp_expansion_factor=mlp_expansion_factor, + ) + + self.mask_estimators.append(mask_estimator) + + # for the multi-resolution stft loss + + self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight + self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes + self.multi_stft_n_fft = stft_n_fft + self.multi_stft_window_fn = multi_stft_window_fn + + self.multi_stft_kwargs = dict( + hop_length=multi_stft_hop_size, normalized=multi_stft_normalized + ) + + self.match_input_audio_length = match_input_audio_length + + def forward(self, raw_audio, target=None, return_loss_breakdown=False): + """ + einops + + b - batch + f - freq + t - time + s - audio channel (1 for mono, 2 for stereo) + n - number of 'stems' + c - complex (2) + d - feature dimension + """ + + device = raw_audio.device + + if raw_audio.ndim == 2: + raw_audio = rearrange(raw_audio, "b t -> b 1 t") + + batch, channels, raw_audio_length = raw_audio.shape + + istft_length = raw_audio_length if self.match_input_audio_length else None + + assert (not self.stereo and channels == 1) or ( + self.stereo and channels == 2 + ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)" + + # to stft + + raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t") + + stft_window = self.stft_window_fn(device=device) + + stft_repr = torch.stft( + raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True + ) + stft_repr = torch.view_as_real(stft_repr) + + stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c") + stft_repr = rearrange( + stft_repr, "b s f t c -> b (f s) t c" + ) # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting + + # index out all frequencies for all frequency ranges across bands ascending in one go + + batch_arange = torch.arange(batch, device=device)[..., None] + + # account for stereo + + x = stft_repr[batch_arange, self.freq_indices] + + # fold the complex (real and imag) into the frequencies dimension + + x = rearrange(x, "b f t c -> b t (f c)") + + if self.use_torch_checkpoint: + x = checkpoint(self.band_split, x, use_reentrant=False) + else: + x = self.band_split(x) + + # axial / hierarchical attention + + store = [None] * len(self.layers) + for i, transformer_block in enumerate(self.layers): + + if len(transformer_block) == 3: + linear_transformer, time_transformer, freq_transformer = ( + transformer_block + ) + + x, ft_ps = pack([x], "b * d") + if self.use_torch_checkpoint: + x = checkpoint(linear_transformer, x, use_reentrant=False) + else: + x = linear_transformer(x) + (x,) = unpack(x, ft_ps, "b * d") + else: + time_transformer, freq_transformer = transformer_block + + if self.skip_connection: + # Sum all previous + for j in range(i): + x = x + store[j] + + x = rearrange(x, "b t f d -> b f t d") + x, ps = pack([x], "* t d") + + if self.use_torch_checkpoint: + x = checkpoint(time_transformer, x, use_reentrant=False) + else: + x = time_transformer(x) + + (x,) = unpack(x, ps, "* t d") + x = rearrange(x, "b f t d -> b t f d") + x, ps = pack([x], "* f d") + + if self.use_torch_checkpoint: + x = checkpoint(freq_transformer, x, use_reentrant=False) + else: + x = freq_transformer(x) + + (x,) = unpack(x, ps, "* f d") + + if self.skip_connection: + store[i] = x + + num_stems = len(self.mask_estimators) + if self.use_torch_checkpoint: + masks = torch.stack( + [checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], + dim=1, + ) + else: + masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1) + masks = rearrange(masks, "b n t (f c) -> b n f t c", c=2) + + # modulate frequency representation + + stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c") + + # complex number multiplication + + stft_repr = torch.view_as_complex(stft_repr) + masks = torch.view_as_complex(masks) + + masks = masks.type(stft_repr.dtype) + + # need to average the estimated mask for the overlapped frequencies + + scatter_indices = repeat( + self.freq_indices, + "f -> b n f t", + b=batch, + n=num_stems, + t=stft_repr.shape[-1], + ) + + stft_repr_expanded_stems = repeat(stft_repr, "b 1 ... -> b n ...", n=num_stems) + masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_( + 2, scatter_indices, masks + ) + + denom = repeat(self.num_bands_per_freq, "f -> (f r) 1", r=channels) + + masks_averaged = masks_summed / denom.clamp(min=1e-8) + + # modulate stft repr with estimated mask + + stft_repr = stft_repr * masks_averaged + + # istft + + stft_repr = rearrange( + stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels + ) + + recon_audio = torch.istft( + stft_repr, + **self.stft_kwargs, + window=stft_window, + return_complex=False, + length=istft_length, + ) + + recon_audio = rearrange( + recon_audio, + "(b n s) t -> b n s t", + b=batch, + s=self.audio_channels, + n=num_stems, + ) + + if num_stems == 1: + recon_audio = rearrange(recon_audio, "b 1 s t -> b s t") + + # if a target is passed in, calculate loss for learning + + if not exists(target): + return recon_audio + + if self.num_stems > 1: + assert target.ndim == 4 and target.shape[1] == self.num_stems + + if target.ndim == 2: + target = rearrange(target, "... t -> ... 1 t") + + target = target[ + ..., : recon_audio.shape[-1] + ] # protect against lost length on istft + + loss = F.l1_loss(recon_audio, target) + + multi_stft_resolution_loss = 0.0 + + for window_size in self.multi_stft_resolutions_window_sizes: + res_stft_kwargs = dict( + n_fft=max( + window_size, self.multi_stft_n_fft + ), # not sure what n_fft is across multi resolution stft + win_length=window_size, + return_complex=True, + window=self.multi_stft_window_fn(window_size, device=device), + **self.multi_stft_kwargs, + ) + + recon_Y = torch.stft( + rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs + ) + target_Y = torch.stft( + rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs + ) + + multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss( + recon_Y, target_Y + ) + + weighted_multi_resolution_loss = ( + multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight + ) + + total_loss = loss + weighted_multi_resolution_loss + + if not return_loss_breakdown: + return total_loss + + return total_loss, (loss, multi_stft_resolution_loss) diff --git a/mvsepless/models/demucs4ht.py b/mvsepless/models/demucs4ht.py new file mode 100644 index 0000000000000000000000000000000000000000..888ee82a19cf612b62504dfd8c0d9204f7a906f6 --- /dev/null +++ b/mvsepless/models/demucs4ht.py @@ -0,0 +1,712 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +import numpy as np +import torch +import json +from omegaconf import OmegaConf +from demucs.demucs import Demucs +from demucs.hdemucs import HDemucs + +import math +from openunmix.filtering import wiener +from torch import nn +from torch.nn import functional as F +from fractions import Fraction +from einops import rearrange + +from demucs.transformer import CrossTransformerEncoder + +from demucs.demucs import rescale_module +from demucs.states import capture_init +from demucs.spec import spectro, ispectro +from demucs.hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer + + +class HTDemucs(nn.Module): + """ + Spectrogram and hybrid Demucs model. + The spectrogram model has the same structure as Demucs, except the first few layers are over the + frequency axis, until there is only 1 frequency, and then it moves to time convolutions. + Frequency layers can still access information across time steps thanks to the DConv residual. + + Hybrid model have a parallel time branch. At some layer, the time branch has the same stride + as the frequency branch and then the two are combined. The opposite happens in the decoder. + + Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]), + or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on + Open Unmix implementation [Stoter et al. 2019]. + + The loss is always on the temporal domain, by backpropagating through the above + output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks + a bit Wiener filtering, as doing more iteration at test time will change the spectrogram + contribution, without changing the one from the waveform, which will lead to worse performance. + I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve. + CaC on the other hand provides similar performance for hybrid, and works naturally with + hybrid models. + + This model also uses frequency embeddings are used to improve efficiency on convolutions + over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf). + + Unlike classic Demucs, there is no resampling here, and normalization is always applied. + """ + + @capture_init + def __init__( + self, + sources, + # Channels + audio_channels=2, + channels=48, + channels_time=None, + growth=2, + # STFT + nfft=4096, + num_subbands=1, + wiener_iters=0, + end_iters=0, + wiener_residual=False, + cac=True, + # Main structure + depth=4, + rewrite=True, + # Frequency branch + multi_freqs=None, + multi_freqs_depth=3, + freq_emb=0.2, + emb_scale=10, + emb_smooth=True, + # Convolutions + kernel_size=8, + time_stride=2, + stride=4, + context=1, + context_enc=0, + # Normalization + norm_starts=4, + norm_groups=4, + # DConv residual branch + dconv_mode=1, + dconv_depth=2, + dconv_comp=8, + dconv_init=1e-3, + # Before the Transformer + bottom_channels=0, + # Transformer + t_layers=5, + t_emb="sin", + t_hidden_scale=4.0, + t_heads=8, + t_dropout=0.0, + t_max_positions=10000, + t_norm_in=True, + t_norm_in_group=False, + t_group_norm=False, + t_norm_first=True, + t_norm_out=True, + t_max_period=10000.0, + t_weight_decay=0.0, + t_lr=None, + t_layer_scale=True, + t_gelu=True, + t_weight_pos_embed=1.0, + t_sin_random_shift=0, + t_cape_mean_normalize=True, + t_cape_augment=True, + t_cape_glob_loc_scale=[5000.0, 1.0, 1.4], + t_sparse_self_attn=False, + t_sparse_cross_attn=False, + t_mask_type="diag", + t_mask_random_seed=42, + t_sparse_attn_window=500, + t_global_window=100, + t_sparsity=0.95, + t_auto_sparsity=False, + # ------ Particuliar parameters + t_cross_first=False, + # Weight init + rescale=0.1, + # Metadata + samplerate=44100, + segment=10, + use_train_segment=False, + ): + """ + Args: + sources (list[str]): list of source names. + audio_channels (int): input/output audio channels. + channels (int): initial number of hidden channels. + channels_time: if not None, use a different `channels` value for the time branch. + growth: increase the number of hidden channels by this factor at each layer. + nfft: number of fft bins. Note that changing this require careful computation of + various shape parameters and will not work out of the box for hybrid models. + wiener_iters: when using Wiener filtering, number of iterations at test time. + end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`. + wiener_residual: add residual source before wiener filtering. + cac: uses complex as channels, i.e. complex numbers are 2 channels each + in input and output. no further processing is done before ISTFT. + depth (int): number of layers in the encoder and in the decoder. + rewrite (bool): add 1x1 convolution to each layer. + multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`. + multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost + layers will be wrapped. + freq_emb: add frequency embedding after the first frequency layer if > 0, + the actual value controls the weight of the embedding. + emb_scale: equivalent to scaling the embedding learning rate + emb_smooth: initialize the embedding with a smooth one (with respect to frequencies). + kernel_size: kernel_size for encoder and decoder layers. + stride: stride for encoder and decoder layers. + time_stride: stride for the final time layer, after the merge. + context: context for 1x1 conv in the decoder. + context_enc: context for 1x1 conv in the encoder. + norm_starts: layer at which group norm starts being used. + decoder layers are numbered in reverse order. + norm_groups: number of groups for group norm. + dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. + dconv_depth: depth of residual DConv branch. + dconv_comp: compression of DConv branch. + dconv_attn: adds attention layers in DConv branch starting at this layer. + dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. + dconv_init: initial scale for the DConv branch LayerScale. + bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the + transformer in order to change the number of channels + t_layers: number of layers in each branch (waveform and spec) of the transformer + t_emb: "sin", "cape" or "scaled" + t_hidden_scale: the hidden scale of the Feedforward parts of the transformer + for instance if C = 384 (the number of channels in the transformer) and + t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension + 384 * 4 = 1536 + t_heads: number of heads for the transformer + t_dropout: dropout in the transformer + t_max_positions: max_positions for the "scaled" positional embedding, only + useful if t_emb="scaled" + t_norm_in: (bool) norm before addinf positional embedding and getting into the + transformer layers + t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the + timesteps (GroupNorm with group=1) + t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the + timesteps (GroupNorm with group=1) + t_norm_first: (bool) if True the norm is before the attention and before the FFN + t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer + t_max_period: (float) denominator in the sinusoidal embedding expression + t_weight_decay: (float) weight decay for the transformer + t_lr: (float) specific learning rate for the transformer + t_layer_scale: (bool) Layer Scale for the transformer + t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else + t_weight_pos_embed: (float) weighting of the positional embedding + t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings + see: https://arxiv.org/abs/2106.03143 + t_cape_augment: (bool) if t_emb="cape", must be True during training and False + during the inference, see: https://arxiv.org/abs/2106.03143 + t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters + see: https://arxiv.org/abs/2106.03143 + t_sparse_self_attn: (bool) if True, the self attentions are sparse + t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it + unless you designed really specific masks) + t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination + with '_' between: i.e. "diag_jmask_random" (note that this is permutation + invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag") + t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed + that generated the random part of the mask + t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and + a key (j), the mask is True id |i-j|<=t_sparse_attn_window + t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :] + and mask[:, :t_global_window] will be True + t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity + level of the random part of the mask. + t_cross_first: (bool) if True cross attention is the first layer of the + transformer (False seems to be better) + rescale: weight rescaling trick + use_train_segment: (bool) if True, the actual size that is used during the + training is used during inference. + """ + super().__init__() + self.num_subbands = num_subbands + self.cac = cac + self.wiener_residual = wiener_residual + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.bottom_channels = bottom_channels + self.channels = channels + self.samplerate = samplerate + self.segment = segment + self.use_train_segment = use_train_segment + self.nfft = nfft + self.hop_length = nfft // 4 + self.wiener_iters = wiener_iters + self.end_iters = end_iters + self.freq_emb = None + assert wiener_iters == end_iters + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + self.tencoder = nn.ModuleList() + self.tdecoder = nn.ModuleList() + + chin = audio_channels + chin_z = chin # number of channels for the freq branch + if self.cac: + chin_z *= 2 + if self.num_subbands > 1: + chin_z *= self.num_subbands + chout = channels_time or channels + chout_z = channels + freqs = nfft // 2 + + for index in range(depth): + norm = index >= norm_starts + freq = freqs > 1 + stri = stride + ker = kernel_size + if not freq: + assert freqs == 1 + ker = time_stride * 2 + stri = time_stride + + pad = True + last_freq = False + if freq and freqs <= kernel_size: + ker = freqs + pad = False + last_freq = True + + kw = { + "kernel_size": ker, + "stride": stri, + "freq": freq, + "pad": pad, + "norm": norm, + "rewrite": rewrite, + "norm_groups": norm_groups, + "dconv_kw": { + "depth": dconv_depth, + "compress": dconv_comp, + "init": dconv_init, + "gelu": True, + }, + } + kwt = dict(kw) + kwt["freq"] = 0 + kwt["kernel_size"] = kernel_size + kwt["stride"] = stride + kwt["pad"] = True + kw_dec = dict(kw) + multi = False + if multi_freqs and index < multi_freqs_depth: + multi = True + kw_dec["context_freq"] = False + + if last_freq: + chout_z = max(chout, chout_z) + chout = chout_z + + enc = HEncLayer( + chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw + ) + if freq: + tenc = HEncLayer( + chin, + chout, + dconv=dconv_mode & 1, + context=context_enc, + empty=last_freq, + **kwt, + ) + self.tencoder.append(tenc) + + if multi: + enc = MultiWrap(enc, multi_freqs) + self.encoder.append(enc) + if index == 0: + chin = self.audio_channels * len(self.sources) + chin_z = chin + if self.cac: + chin_z *= 2 + if self.num_subbands > 1: + chin_z *= self.num_subbands + dec = HDecLayer( + chout_z, + chin_z, + dconv=dconv_mode & 2, + last=index == 0, + context=context, + **kw_dec, + ) + if multi: + dec = MultiWrap(dec, multi_freqs) + if freq: + tdec = HDecLayer( + chout, + chin, + dconv=dconv_mode & 2, + empty=last_freq, + last=index == 0, + context=context, + **kwt, + ) + self.tdecoder.insert(0, tdec) + self.decoder.insert(0, dec) + + chin = chout + chin_z = chout_z + chout = int(growth * chout) + chout_z = int(growth * chout_z) + if freq: + if freqs <= kernel_size: + freqs = 1 + else: + freqs //= stride + if index == 0 and freq_emb: + self.freq_emb = ScaledEmbedding( + freqs, chin_z, smooth=emb_smooth, scale=emb_scale + ) + self.freq_emb_scale = freq_emb + + if rescale: + rescale_module(self, reference=rescale) + + transformer_channels = channels * growth ** (depth - 1) + if bottom_channels: + self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1) + self.channel_downsampler = nn.Conv1d( + bottom_channels, transformer_channels, 1 + ) + self.channel_upsampler_t = nn.Conv1d( + transformer_channels, bottom_channels, 1 + ) + self.channel_downsampler_t = nn.Conv1d( + bottom_channels, transformer_channels, 1 + ) + + transformer_channels = bottom_channels + + if t_layers > 0: + self.crosstransformer = CrossTransformerEncoder( + dim=transformer_channels, + emb=t_emb, + hidden_scale=t_hidden_scale, + num_heads=t_heads, + num_layers=t_layers, + cross_first=t_cross_first, + dropout=t_dropout, + max_positions=t_max_positions, + norm_in=t_norm_in, + norm_in_group=t_norm_in_group, + group_norm=t_group_norm, + norm_first=t_norm_first, + norm_out=t_norm_out, + max_period=t_max_period, + weight_decay=t_weight_decay, + lr=t_lr, + layer_scale=t_layer_scale, + gelu=t_gelu, + sin_random_shift=t_sin_random_shift, + weight_pos_embed=t_weight_pos_embed, + cape_mean_normalize=t_cape_mean_normalize, + cape_augment=t_cape_augment, + cape_glob_loc_scale=t_cape_glob_loc_scale, + sparse_self_attn=t_sparse_self_attn, + sparse_cross_attn=t_sparse_cross_attn, + mask_type=t_mask_type, + mask_random_seed=t_mask_random_seed, + sparse_attn_window=t_sparse_attn_window, + global_window=t_global_window, + sparsity=t_sparsity, + auto_sparsity=t_auto_sparsity, + ) + else: + self.crosstransformer = None + + def _spec(self, x): + hl = self.hop_length + nfft = self.nfft + x0 = x # noqa + + # We re-pad the signal in order to keep the property + # that the size of the output is exactly the size of the input + # divided by the stride (here hop_length), when divisible. + # This is achieved by padding by 1/4th of the kernel size (here nfft). + # which is not supported by torch.stft. + # Having all convolution operations follow this convention allow to easily + # align the time and frequency branches later on. + assert hl == nfft // 4 + le = int(math.ceil(x.shape[-1] / hl)) + pad = hl // 2 * 3 + x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect") + + z = spectro(x, nfft, hl)[..., :-1, :] + assert z.shape[-1] == le + 4, (z.shape, x.shape, le) + z = z[..., 2 : 2 + le] + return z + + def _ispec(self, z, length=None, scale=0): + hl = self.hop_length // (4**scale) + z = F.pad(z, (0, 0, 0, 1)) + z = F.pad(z, (2, 2)) + pad = hl // 2 * 3 + le = hl * int(math.ceil(length / hl)) + 2 * pad + x = ispectro(z, hl, length=le) + x = x[..., pad : pad + length] + return x + + def _magnitude(self, z): + # return the magnitude of the spectrogram, except when cac is True, + # in which case we just move the complex dimension to the channel one. + if self.cac: + B, C, Fr, T = z.shape + m = torch.view_as_real(z).permute(0, 1, 4, 2, 3) + m = m.reshape(B, C * 2, Fr, T) + else: + m = z.abs() + return m + + def _mask(self, z, m): + # Apply masking given the mixture spectrogram `z` and the estimated mask `m`. + # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored. + niters = self.wiener_iters + if self.cac: + B, S, C, Fr, T = m.shape + out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3) + out = torch.view_as_complex(out.contiguous()) + return out + if self.training: + niters = self.end_iters + if niters < 0: + z = z[:, None] + return z / (1e-8 + z.abs()) * m + else: + return self._wiener(m, z, niters) + + def _wiener(self, mag_out, mix_stft, niters): + # apply wiener filtering from OpenUnmix. + init = mix_stft.dtype + wiener_win_len = 300 + residual = self.wiener_residual + + B, S, C, Fq, T = mag_out.shape + mag_out = mag_out.permute(0, 4, 3, 2, 1) + mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1)) + + outs = [] + for sample in range(B): + pos = 0 + out = [] + for pos in range(0, T, wiener_win_len): + frame = slice(pos, pos + wiener_win_len) + z_out = wiener( + mag_out[sample, frame], + mix_stft[sample, frame], + niters, + residual=residual, + ) + out.append(z_out.transpose(-1, -2)) + outs.append(torch.cat(out, dim=0)) + out = torch.view_as_complex(torch.stack(outs, 0)) + out = out.permute(0, 4, 3, 2, 1).contiguous() + if residual: + out = out[:, :-1] + assert list(out.shape) == [B, S, C, Fq, T] + return out.to(init) + + def valid_length(self, length: int): + """ + Return a length that is appropriate for evaluation. + In our case, always return the training length, unless + it is smaller than the given length, in which case this + raises an error. + """ + if not self.use_train_segment: + return length + training_length = int(self.segment * self.samplerate) + if training_length < length: + raise ValueError( + f"Given length {length} is longer than " + f"training length {training_length}" + ) + return training_length + + def cac2cws(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c, k, f // k, t) + x = x.reshape(b, c * k, f // k, t) + return x + + def cws2cac(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c // k, k, f, t) + x = x.reshape(b, c // k, f * k, t) + return x + + def forward(self, mix): + length = mix.shape[-1] + length_pre_pad = None + if self.use_train_segment: + if self.training: + self.segment = Fraction(mix.shape[-1], self.samplerate) + else: + training_length = int(self.segment * self.samplerate) + # print('Training length: {} Segment: {} Sample rate: {}'.format(training_length, self.segment, self.samplerate)) + if mix.shape[-1] < training_length: + length_pre_pad = mix.shape[-1] + mix = F.pad(mix, (0, training_length - length_pre_pad)) + # print("Mix: {}".format(mix.shape)) + # print("Length: {}".format(length)) + z = self._spec(mix) + # print("Z: {} Type: {}".format(z.shape, z.dtype)) + mag = self._magnitude(z) + x = mag + # print("MAG: {} Type: {}".format(x.shape, x.dtype)) + + if self.num_subbands > 1: + x = self.cac2cws(x) + # print("After SUBBANDS: {} Type: {}".format(x.shape, x.dtype)) + + B, C, Fq, T = x.shape + + # unlike previous Demucs, we always normalize because it is easier. + mean = x.mean(dim=(1, 2, 3), keepdim=True) + std = x.std(dim=(1, 2, 3), keepdim=True) + x = (x - mean) / (1e-5 + std) + # x will be the freq. branch input. + + # Prepare the time branch input. + xt = mix + meant = xt.mean(dim=(1, 2), keepdim=True) + stdt = xt.std(dim=(1, 2), keepdim=True) + xt = (xt - meant) / (1e-5 + stdt) + + # print("XT: {}".format(xt.shape)) + + # okay, this is a giant mess I know... + saved = [] # skip connections, freq. + saved_t = [] # skip connections, time. + lengths = [] # saved lengths to properly remove padding, freq branch. + lengths_t = [] # saved lengths for time branch. + for idx, encode in enumerate(self.encoder): + lengths.append(x.shape[-1]) + inject = None + if idx < len(self.tencoder): + # we have not yet merged branches. + lengths_t.append(xt.shape[-1]) + tenc = self.tencoder[idx] + xt = tenc(xt) + # print("Encode XT {}: {}".format(idx, xt.shape)) + if not tenc.empty: + # save for skip connection + saved_t.append(xt) + else: + # tenc contains just the first conv., so that now time and freq. + # branches have the same shape and can be merged. + inject = xt + x = encode(x, inject) + # print("Encode X {}: {}".format(idx, x.shape)) + if idx == 0 and self.freq_emb is not None: + # add frequency embedding to allow for non equivariant convolutions + # over the frequency axis. + frs = torch.arange(x.shape[-2], device=x.device) + emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x) + x = x + self.freq_emb_scale * emb + + saved.append(x) + if self.crosstransformer: + if self.bottom_channels: + b, c, f, t = x.shape + x = rearrange(x, "b c f t-> b c (f t)") + x = self.channel_upsampler(x) + x = rearrange(x, "b c (f t)-> b c f t", f=f) + xt = self.channel_upsampler_t(xt) + + x, xt = self.crosstransformer(x, xt) + # print("Cross Tran X {}, XT: {}".format(x.shape, xt.shape)) + + if self.bottom_channels: + x = rearrange(x, "b c f t-> b c (f t)") + x = self.channel_downsampler(x) + x = rearrange(x, "b c (f t)-> b c f t", f=f) + xt = self.channel_downsampler_t(xt) + + for idx, decode in enumerate(self.decoder): + skip = saved.pop(-1) + x, pre = decode(x, skip, lengths.pop(-1)) + # print('Decode {} X: {}'.format(idx, x.shape)) + # `pre` contains the output just before final transposed convolution, + # which is used when the freq. and time branch separate. + + offset = self.depth - len(self.tdecoder) + if idx >= offset: + tdec = self.tdecoder[idx - offset] + length_t = lengths_t.pop(-1) + if tdec.empty: + assert pre.shape[2] == 1, pre.shape + pre = pre[:, :, 0] + xt, _ = tdec(pre, None, length_t) + else: + skip = saved_t.pop(-1) + xt, _ = tdec(xt, skip, length_t) + # print('Decode {} XT: {}'.format(idx, xt.shape)) + + # Let's make sure we used all stored skip connections. + assert len(saved) == 0 + assert len(lengths_t) == 0 + assert len(saved_t) == 0 + + S = len(self.sources) + + if self.num_subbands > 1: + x = x.view(B, -1, Fq, T) + # print("X view 1: {}".format(x.shape)) + x = self.cws2cac(x) + # print("X view 2: {}".format(x.shape)) + + x = x.view(B, S, -1, Fq * self.num_subbands, T) + x = x * std[:, None] + mean[:, None] + # print("X returned: {}".format(x.shape)) + + zout = self._mask(z, x) + if self.use_train_segment: + if self.training: + x = self._ispec(zout, length) + else: + x = self._ispec(zout, training_length) + else: + x = self._ispec(zout, length) + + if self.use_train_segment: + if self.training: + xt = xt.view(B, S, -1, length) + else: + xt = xt.view(B, S, -1, training_length) + else: + xt = xt.view(B, S, -1, length) + xt = xt * stdt[:, None] + meant[:, None] + x = xt + x + if length_pre_pad: + x = x[..., :length_pre_pad] + return x + + +def get_model(args): + extra = { + "sources": list(args.training.instruments), + "audio_channels": args.training.channels, + "samplerate": args.training.samplerate, + # 'segment': args.model_segment or 4 * args.dset.segment, + "segment": args.training.segment, + } + klass = { + "demucs": Demucs, + "hdemucs": HDemucs, + "htdemucs": HTDemucs, + }[args.model] + kw = OmegaConf.to_container(getattr(args, args.model), resolve=True) + model = klass(**extra, **kw) + return model diff --git a/mvsepless/models/mdx23c_tfc_tdf_v3.py b/mvsepless/models/mdx23c_tfc_tdf_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..cd4637cfee387d1734ee3e33fe2e11ce923ef432 --- /dev/null +++ b/mvsepless/models/mdx23c_tfc_tdf_v3.py @@ -0,0 +1,259 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from infer_utils import prefer_target_instrument + + +class STFT: + def __init__(self, config): + self.n_fft = config.n_fft + self.hop_length = config.hop_length + self.window = torch.hann_window(window_length=self.n_fft, periodic=True) + self.dim_f = config.dim_f + + def __call__(self, x): + window = self.window.to(x.device) + batch_dims = x.shape[:-2] + c, t = x.shape[-2:] + x = x.reshape([-1, t]) + x = torch.stft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + window=window, + center=True, + return_complex=True, + ) + x = torch.view_as_real(x) + x = x.permute([0, 3, 1, 2]) + x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape( + [*batch_dims, c * 2, -1, x.shape[-1]] + ) + return x[..., : self.dim_f, :] + + def inverse(self, x): + window = self.window.to(x.device) + batch_dims = x.shape[:-3] + c, f, t = x.shape[-3:] + n = self.n_fft // 2 + 1 + f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device) + x = torch.cat([x, f_pad], -2) + x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t]) + x = x.permute([0, 2, 3, 1]) + x = x[..., 0] + x[..., 1] * 1.0j + x = torch.istft( + x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True + ) + x = x.reshape([*batch_dims, 2, -1]) + return x + + +def get_norm(norm_type): + def norm(c, norm_type): + if norm_type == "BatchNorm": + return nn.BatchNorm2d(c) + elif norm_type == "InstanceNorm": + return nn.InstanceNorm2d(c, affine=True) + elif "GroupNorm" in norm_type: + g = int(norm_type.replace("GroupNorm", "")) + return nn.GroupNorm(num_groups=g, num_channels=c) + else: + return nn.Identity() + + return partial(norm, norm_type=norm_type) + + +def get_act(act_type): + if act_type == "gelu": + return nn.GELU() + elif act_type == "relu": + return nn.ReLU() + elif act_type[:3] == "elu": + alpha = float(act_type.replace("elu", "")) + return nn.ELU(alpha) + else: + raise Exception + + +class Upscale(nn.Module): + def __init__(self, in_c, out_c, scale, norm, act): + super().__init__() + self.conv = nn.Sequential( + norm(in_c), + act, + nn.ConvTranspose2d( + in_channels=in_c, + out_channels=out_c, + kernel_size=scale, + stride=scale, + bias=False, + ), + ) + + def forward(self, x): + return self.conv(x) + + +class Downscale(nn.Module): + def __init__(self, in_c, out_c, scale, norm, act): + super().__init__() + self.conv = nn.Sequential( + norm(in_c), + act, + nn.Conv2d( + in_channels=in_c, + out_channels=out_c, + kernel_size=scale, + stride=scale, + bias=False, + ), + ) + + def forward(self, x): + return self.conv(x) + + +class TFC_TDF(nn.Module): + def __init__(self, in_c, c, l, f, bn, norm, act): + super().__init__() + + self.blocks = nn.ModuleList() + for i in range(l): + block = nn.Module() + + block.tfc1 = nn.Sequential( + norm(in_c), + act, + nn.Conv2d(in_c, c, 3, 1, 1, bias=False), + ) + block.tdf = nn.Sequential( + norm(c), + act, + nn.Linear(f, f // bn, bias=False), + norm(c), + act, + nn.Linear(f // bn, f, bias=False), + ) + block.tfc2 = nn.Sequential( + norm(c), + act, + nn.Conv2d(c, c, 3, 1, 1, bias=False), + ) + block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False) + + self.blocks.append(block) + in_c = c + + def forward(self, x): + for block in self.blocks: + s = block.shortcut(x) + x = block.tfc1(x) + x = x + block.tdf(x) + x = block.tfc2(x) + x = x + s + return x + + +class TFC_TDF_net(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + norm = get_norm(norm_type=config.model.norm) + act = get_act(act_type=config.model.act) + + self.num_target_instruments = len(prefer_target_instrument(config)) + self.num_subbands = config.model.num_subbands + + dim_c = self.num_subbands * config.audio.num_channels * 2 + n = config.model.num_scales + scale = config.model.scale + l = config.model.num_blocks_per_scale + c = config.model.num_channels + g = config.model.growth + bn = config.model.bottleneck_factor + f = config.audio.dim_f // self.num_subbands + + self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False) + + self.encoder_blocks = nn.ModuleList() + for i in range(n): + block = nn.Module() + block.tfc_tdf = TFC_TDF(c, c, l, f, bn, norm, act) + block.downscale = Downscale(c, c + g, scale, norm, act) + f = f // scale[1] + c += g + self.encoder_blocks.append(block) + + self.bottleneck_block = TFC_TDF(c, c, l, f, bn, norm, act) + + self.decoder_blocks = nn.ModuleList() + for i in range(n): + block = nn.Module() + block.upscale = Upscale(c, c - g, scale, norm, act) + f = f * scale[1] + c -= g + block.tfc_tdf = TFC_TDF(2 * c, c, l, f, bn, norm, act) + self.decoder_blocks.append(block) + + self.final_conv = nn.Sequential( + nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False), + act, + nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False), + ) + + self.stft = STFT(config.audio) + + def cac2cws(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c, k, f // k, t) + x = x.reshape(b, c * k, f // k, t) + return x + + def cws2cac(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c // k, k, f, t) + x = x.reshape(b, c // k, f * k, t) + return x + + def forward(self, x): + + x = self.stft(x) + + mix = x = self.cac2cws(x) + + first_conv_out = x = self.first_conv(x) + + x = x.transpose(-1, -2) + + encoder_outputs = [] + for block in self.encoder_blocks: + x = block.tfc_tdf(x) + encoder_outputs.append(x) + x = block.downscale(x) + + x = self.bottleneck_block(x) + + for block in self.decoder_blocks: + x = block.upscale(x) + x = torch.cat([x, encoder_outputs.pop()], 1) + x = block.tfc_tdf(x) + + x = x.transpose(-1, -2) + + x = x * first_conv_out # reduce artifacts + + x = self.final_conv(torch.cat([mix, x], 1)) + + x = self.cws2cac(x) + + if self.num_target_instruments > 1: + b, c, f, t = x.shape + x = x.reshape(b, self.num_target_instruments, -1, f, t) + + x = self.stft.inverse(x) + + return x diff --git a/mvsepless/models/mdx_net.py b/mvsepless/models/mdx_net.py new file mode 100644 index 0000000000000000000000000000000000000000..fb7a6bc355444ad09f2de15407ca6bee76f25b83 --- /dev/null +++ b/mvsepless/models/mdx_net.py @@ -0,0 +1,208 @@ +import torch +import torch.nn as nn +import onnxruntime as ort +import numpy as np +from typing import Dict, Any, List +import torch.nn.functional as F +import sys +import json + +class MDXNet(nn.Module): + def __init__(self, dim_f: int, dim_t: int, n_fft: int, hop_length: int, primary_stem: str, compensation: float = 1.0): + super().__init__() + self.dim_f = dim_f + self.dim_t = dim_t + self.n_fft = n_fft + self.dim_c = 4 + self.hop_length = hop_length + self.primary_stem = primary_stem + self.compensation = compensation + + # Внутренний chunk_size MDXNet (фиксированный) + self.internal_chunk_size = self.hop_length * (self.dim_t - 1) + self.n_bins = self.n_fft // 2 + 1 + + # ONNX session будет инициализирован позже + self.ort_session = None + + def init_onnx_session(self, onnx_model_path: str, device: str): + """Инициализирует ONNX runtime session""" + providers = ["CUDAExecutionProvider"] if "cuda" in device else ["CPUExecutionProvider"] + self.ort_session = ort.InferenceSession(onnx_model_path, providers=providers) + + # Preload model + self.ort_session.run( + None, + {"input": torch.rand(1, 4, self.dim_f, self.dim_t).numpy()}, + ) + + # Инициализируем оконную функцию на правильном устройстве + self.window = torch.hann_window( + window_length=self.n_fft, periodic=True + ) + + out_c = self.dim_c + + self.freq_pad = torch.zeros( + [1, out_c, self.n_bins - self.dim_f, self.dim_t] + ) + + def stft(self, x: torch.Tensor) -> torch.Tensor: + """STFT преобразование для MDX-Net""" + # Убедимся, что window на том же устройстве, что и x + window = self.window.to(x.device) + + x = x.reshape([-1, self.internal_chunk_size]) + x = torch.stft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + window=window, + center=True, + return_complex=True, + ) + x = torch.view_as_real(x) + x = x.permute([0, 3, 1, 2]) + x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape( + [-1, 4, self.n_bins, self.dim_t] + ) + return x[:, :, :self.dim_f] + + def istft(self, x: torch.Tensor) -> torch.Tensor: + """Обратное STFT преобразование для MDX-Net""" + # Убедимся, что window и freq_pad на том же устройстве, что и x + window = self.window.to(x.device) + freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]).to(x.device) + + x = torch.cat([x, freq_pad], -2) + x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape( + [-1, 2, self.n_bins, self.dim_t] + ) + x = x.permute([0, 2, 3, 1]) + x = x.contiguous() + x = torch.view_as_complex(x) + x = torch.istft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + window=window, + center=True, + ) + return x.reshape([-1, 2, self.internal_chunk_size]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Прямой проход через MDX-Net""" + if self.ort_session is None: + raise ValueError("ONNX session not initialized. Call init_onnx_session first.") + + # Преобразуем в numpy для ONNX инференса + x_np = x.cpu().numpy() + output = self.ort_session.run(None, {"input": x_np})[0] + return torch.from_numpy(output).to(x.device) + + def process_wave(self, wave: torch.Tensor, device: torch.device, num_overlap: int, pbar: bool = False) -> torch.Tensor: + """Обрабатывает аудио волну через MDX-Net с overlap и chunking""" + # Перемещаем wave на нужное устройство + wave = wave.to(device) + + # Используем внутренний chunk_size MDXNet + chunk_size = self.internal_chunk_size + fade_size = chunk_size // 10 + step = chunk_size // num_overlap + border = chunk_size - step + + length_init = wave.shape[-1] + + # Добавляем padding для обработки краев + if length_init > 2 * border and border > 0: + wave = nn.functional.pad(wave, (border, border), mode="reflect") + + # Создаем оконную функцию на правильном устройстве + window = self._get_windowing_array(chunk_size, fade_size).to(device) + + batch_size = 1 # MDXNet обычно использует batch_size=1 + + with torch.no_grad(): + # Инициализируем результат и счетчик на правильном устройстве + result = torch.zeros_like(wave, device=device) + counter = torch.zeros_like(wave, device=device) + + i = 0 + batch_data = [] + batch_locations = [] + + # Подсчитываем общее количество чанков для прогресса + total_chunks = 0 + temp_i = 0 + while temp_i < wave.shape[1]: + total_chunks += 1 + temp_i += step + + processed_chunks = 0 + + while i < wave.shape[1]: + # Извлекаем чанк + part = wave[:, i : i + chunk_size] + chunk_len = part.shape[-1] + + # Добавляем padding если чанк меньше chunk_size + if chunk_len < chunk_size: + pad_mode = "reflect" if chunk_len > chunk_size // 2 else "constant" + part = nn.functional.pad( + part, (0, chunk_size - chunk_len), mode=pad_mode, value=0 + ) + + batch_data.append(part) + batch_locations.append((i, chunk_len)) + i += step + + # Обрабатываем батч + if len(batch_data) >= batch_size or i >= wave.shape[1]: + arr = torch.stack(batch_data, dim=0) + + # Обрабатываем каждый чанк в батче + for j, (start, seg_len) in enumerate(batch_locations): + # STFT -> ONNX inference -> iSTFT + spec = self.stft(arr[j:j+1]) + processed_spec = self(spec) + processed_wav = self.istft(processed_spec) + + # Применяем оконную функцию + window_segment = window[..., :seg_len] + result[:, start : start + seg_len] += processed_wav[0, :, :seg_len] * window_segment + counter[:, start : start + seg_len] += window_segment + + # Обновляем прогресс + processed_chunks += len(batch_data) + if pbar: + progress_data = { + "processing": { + "processed": min(i, wave.shape[1]), + "total": wave.shape[1] + } + } + sys.stdout.write(json.dumps(progress_data, ensure_ascii=False) + '\n') + sys.stdout.flush() + + # Очищаем батч + batch_data.clear() + batch_locations.clear() + + # Вычисляем финальный результат + estimated_sources = result / counter + + # Убираем padding + if length_init > 2 * border and border > 0: + estimated_sources = estimated_sources[..., border:-border] + + return estimated_sources + + def _get_windowing_array(self, window_size: int, fade_size: int) -> torch.Tensor: + """Генерирует оконную функцию с fade-in и fade-out""" + fadein = torch.linspace(0, 1, fade_size) + fadeout = torch.linspace(1, 0, fade_size) + + window = torch.ones(window_size) + window[-fade_size:] = fadeout + window[:fade_size] = fadein + return window \ No newline at end of file diff --git a/mvsepless/models/scnet/__init__.py b/mvsepless/models/scnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f6ecefede9345237623066dd21ebd8253af1c60 --- /dev/null +++ b/mvsepless/models/scnet/__init__.py @@ -0,0 +1 @@ +from .scnet import SCNet diff --git a/mvsepless/models/scnet/scnet.py b/mvsepless/models/scnet/scnet.py new file mode 100644 index 0000000000000000000000000000000000000000..71807479eb10bc803f2fb1f7d7f193bd30416167 --- /dev/null +++ b/mvsepless/models/scnet/scnet.py @@ -0,0 +1,419 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from collections import deque +from .separation import SeparationNet +import typing as tp +import math + + +class Swish(nn.Module): + def forward(self, x): + return x * x.sigmoid() + + +class ConvolutionModule(nn.Module): + """ + Convolution Module in SD block. + + Args: + channels (int): input/output channels. + depth (int): number of layers in the residual branch. Each layer has its own + compress (float): amount of channel compression. + kernel (int): kernel size for the convolutions. + """ + + def __init__(self, channels, depth=2, compress=4, kernel=3): + super().__init__() + assert kernel % 2 == 1 + self.depth = abs(depth) + hidden_size = int(channels / compress) + norm = lambda d: nn.GroupNorm(1, d) + self.layers = nn.ModuleList([]) + for _ in range(self.depth): + padding = kernel // 2 + mods = [ + norm(channels), + nn.Conv1d(channels, hidden_size * 2, kernel, padding=padding), + nn.GLU(1), + nn.Conv1d( + hidden_size, + hidden_size, + kernel, + padding=padding, + groups=hidden_size, + ), + norm(hidden_size), + Swish(), + nn.Conv1d(hidden_size, channels, 1), + ] + layer = nn.Sequential(*mods) + self.layers.append(layer) + + def forward(self, x): + for layer in self.layers: + x = x + layer(x) + return x + + +class FusionLayer(nn.Module): + """ + A FusionLayer within the decoder. + + Args: + - channels (int): Number of input channels. + - kernel_size (int, optional): Kernel size for the convolutional layer, defaults to 3. + - stride (int, optional): Stride for the convolutional layer, defaults to 1. + - padding (int, optional): Padding for the convolutional layer, defaults to 1. + """ + + def __init__(self, channels, kernel_size=3, stride=1, padding=1): + super(FusionLayer, self).__init__() + self.conv = nn.Conv2d( + channels * 2, channels * 2, kernel_size, stride=stride, padding=padding + ) + + def forward(self, x, skip=None): + if skip is not None: + x += skip + x = x.repeat(1, 2, 1, 1) + x = self.conv(x) + x = F.glu(x, dim=1) + return x + + +class SDlayer(nn.Module): + """ + Implements a Sparse Down-sample Layer for processing different frequency bands separately. + + Args: + - channels_in (int): Input channel count. + - channels_out (int): Output channel count. + - band_configs (dict): A dictionary containing configuration for each frequency band. + Keys are 'low', 'mid', 'high' for each band, and values are + dictionaries with keys 'SR', 'stride', and 'kernel' for proportion, + stride, and kernel size, respectively. + """ + + def __init__(self, channels_in, channels_out, band_configs): + super(SDlayer, self).__init__() + + # Initializing convolutional layers for each band + self.convs = nn.ModuleList() + self.strides = [] + self.kernels = [] + for config in band_configs.values(): + self.convs.append( + nn.Conv2d( + channels_in, + channels_out, + (config["kernel"], 1), + (config["stride"], 1), + (0, 0), + ) + ) + self.strides.append(config["stride"]) + self.kernels.append(config["kernel"]) + + # Saving rate proportions for determining splits + self.SR_low = band_configs["low"]["SR"] + self.SR_mid = band_configs["mid"]["SR"] + + def forward(self, x): + B, C, Fr, T = x.shape + # Define splitting points based on sampling rates + splits = [ + (0, math.ceil(Fr * self.SR_low)), + (math.ceil(Fr * self.SR_low), math.ceil(Fr * (self.SR_low + self.SR_mid))), + (math.ceil(Fr * (self.SR_low + self.SR_mid)), Fr), + ] + + # Processing each band with the corresponding convolution + outputs = [] + original_lengths = [] + for conv, stride, kernel, (start, end) in zip( + self.convs, self.strides, self.kernels, splits + ): + extracted = x[:, :, start:end, :] + original_lengths.append(end - start) + current_length = extracted.shape[2] + + # padding + if stride == 1: + total_padding = kernel - stride + else: + total_padding = (stride - current_length % stride) % stride + pad_left = total_padding // 2 + pad_right = total_padding - pad_left + + padded = F.pad(extracted, (0, 0, pad_left, pad_right)) + + output = conv(padded) + outputs.append(output) + + return outputs, original_lengths + + +class SUlayer(nn.Module): + """ + Implements a Sparse Up-sample Layer in decoder. + + Args: + - channels_in: The number of input channels. + - channels_out: The number of output channels. + - convtr_configs: Dictionary containing the configurations for transposed convolutions. + """ + + def __init__(self, channels_in, channels_out, band_configs): + super(SUlayer, self).__init__() + + # Initializing convolutional layers for each band + self.convtrs = nn.ModuleList( + [ + nn.ConvTranspose2d( + channels_in, + channels_out, + [config["kernel"], 1], + [config["stride"], 1], + ) + for _, config in band_configs.items() + ] + ) + + def forward(self, x, lengths, origin_lengths): + B, C, Fr, T = x.shape + # Define splitting points based on input lengths + splits = [ + (0, lengths[0]), + (lengths[0], lengths[0] + lengths[1]), + (lengths[0] + lengths[1], None), + ] + # Processing each band with the corresponding convolution + outputs = [] + for idx, (convtr, (start, end)) in enumerate(zip(self.convtrs, splits)): + out = convtr(x[:, :, start:end, :]) + # Calculate the distance to trim the output symmetrically to original length + current_Fr_length = out.shape[2] + dist = abs(origin_lengths[idx] - current_Fr_length) // 2 + + # Trim the output to the original length symmetrically + trimmed_out = out[:, :, dist : dist + origin_lengths[idx], :] + + outputs.append(trimmed_out) + + # Concatenate trimmed outputs along the frequency dimension to return the final tensor + x = torch.cat(outputs, dim=2) + + return x + + +class SDblock(nn.Module): + """ + Implements a simplified Sparse Down-sample block in encoder. + + Args: + - channels_in (int): Number of input channels. + - channels_out (int): Number of output channels. + - band_config (dict): Configuration for the SDlayer specifying band splits and convolutions. + - conv_config (dict): Configuration for convolution modules applied to each band. + - depths (list of int): List specifying the convolution depths for low, mid, and high frequency bands. + """ + + def __init__( + self, + channels_in, + channels_out, + band_configs={}, + conv_config={}, + depths=[3, 2, 1], + kernel_size=3, + ): + super(SDblock, self).__init__() + self.SDlayer = SDlayer(channels_in, channels_out, band_configs) + + # Dynamically create convolution modules for each band based on depths + self.conv_modules = nn.ModuleList( + [ConvolutionModule(channels_out, depth, **conv_config) for depth in depths] + ) + # Set the kernel_size to an odd number. + self.globalconv = nn.Conv2d( + channels_out, channels_out, kernel_size, 1, (kernel_size - 1) // 2 + ) + + def forward(self, x): + bands, original_lengths = self.SDlayer(x) + # B, C, f, T = band.shape + bands = [ + F.gelu( + conv(band.permute(0, 2, 1, 3).reshape(-1, band.shape[1], band.shape[3])) + .view(band.shape[0], band.shape[2], band.shape[1], band.shape[3]) + .permute(0, 2, 1, 3) + ) + for conv, band in zip(self.conv_modules, bands) + ] + lengths = [band.size(-2) for band in bands] + full_band = torch.cat(bands, dim=2) + skip = full_band + + output = self.globalconv(full_band) + + return output, skip, lengths, original_lengths + + +class SCNet(nn.Module): + """ + The implementation of SCNet: Sparse Compression Network for Music Source Separation. Paper: https://arxiv.org/abs/2401.13276.pdf + + Args: + - sources (List[str]): List of sources to be separated. + - audio_channels (int): Number of audio channels. + - nfft (int): Number of FFTs to determine the frequency dimension of the input. + - hop_size (int): Hop size for the STFT. + - win_size (int): Window size for STFT. + - normalized (bool): Whether to normalize the STFT. + - dims (List[int]): List of channel dimensions for each block. + - band_SR (List[float]): The proportion of each frequency band. + - band_stride (List[int]): The down-sampling ratio of each frequency band. + - band_kernel (List[int]): The kernel sizes for down-sampling convolution in each frequency band + - conv_depths (List[int]): List specifying the number of convolution modules in each SD block. + - compress (int): Compression factor for convolution module. + - conv_kernel (int): Kernel size for convolution layer in convolution module. + - num_dplayer (int): Number of dual-path layers. + - expand (int): Expansion factor in the dual-path RNN, default is 1. + + """ + + def __init__( + self, + sources=["drums", "bass", "other", "vocals"], + audio_channels=2, + # Main structure + dims=[4, 32, 64, 128], # dims = [4, 64, 128, 256] in SCNet-large + # STFT + nfft=4096, + hop_size=1024, + win_size=4096, + normalized=True, + # SD/SU layer + band_SR=[0.175, 0.392, 0.433], + band_stride=[1, 4, 16], + band_kernel=[3, 4, 16], + # Convolution Module + conv_depths=[3, 2, 1], + compress=4, + conv_kernel=3, + # Dual-path RNN + num_dplayer=6, + expand=1, + ): + super().__init__() + self.sources = sources + self.audio_channels = audio_channels + self.dims = dims + band_keys = ["low", "mid", "high"] + self.band_configs = { + band_keys[i]: { + "SR": band_SR[i], + "stride": band_stride[i], + "kernel": band_kernel[i], + } + for i in range(len(band_keys)) + } + self.hop_length = hop_size + self.conv_config = { + "compress": compress, + "kernel": conv_kernel, + } + + self.stft_config = { + "n_fft": nfft, + "hop_length": hop_size, + "win_length": win_size, + "center": True, + "normalized": normalized, + } + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + for index in range(len(dims) - 1): + enc = SDblock( + channels_in=dims[index], + channels_out=dims[index + 1], + band_configs=self.band_configs, + conv_config=self.conv_config, + depths=conv_depths, + ) + self.encoder.append(enc) + + dec = nn.Sequential( + FusionLayer(channels=dims[index + 1]), + SUlayer( + channels_in=dims[index + 1], + channels_out=( + dims[index] if index != 0 else dims[index] * len(sources) + ), + band_configs=self.band_configs, + ), + ) + self.decoder.insert(0, dec) + + self.separation_net = SeparationNet( + channels=dims[-1], + expand=expand, + num_layers=num_dplayer, + ) + + def forward(self, x): + # B, C, L = x.shape + B = x.shape[0] + # In the initial padding, ensure that the number of frames after the STFT (the length of the T dimension) is even, + # so that the RFFT operation can be used in the separation network. + padding = self.hop_length - x.shape[-1] % self.hop_length + if (x.shape[-1] + padding) // self.hop_length % 2 == 0: + padding += self.hop_length + x = F.pad(x, (0, padding)) + + # STFT + L = x.shape[-1] + x = x.reshape(-1, L) + x = torch.stft(x, **self.stft_config, return_complex=True) + x = torch.view_as_real(x) + x = x.permute(0, 3, 1, 2).reshape( + x.shape[0] // self.audio_channels, + x.shape[3] * self.audio_channels, + x.shape[1], + x.shape[2], + ) + + B, C, Fr, T = x.shape + + save_skip = deque() + save_lengths = deque() + save_original_lengths = deque() + # encoder + for sd_layer in self.encoder: + x, skip, lengths, original_lengths = sd_layer(x) + save_skip.append(skip) + save_lengths.append(lengths) + save_original_lengths.append(original_lengths) + + # separation + x = self.separation_net(x) + + # decoder + for fusion_layer, su_layer in self.decoder: + x = fusion_layer(x, save_skip.pop()) + x = su_layer(x, save_lengths.pop(), save_original_lengths.pop()) + + # output + n = self.dims[0] + x = x.view(B, n, -1, Fr, T) + x = x.reshape(-1, 2, Fr, T).permute(0, 2, 3, 1) + x = torch.view_as_complex(x.contiguous()) + x = torch.istft(x, **self.stft_config) + x = x.reshape(B, len(self.sources), self.audio_channels, -1) + + x = x[:, :, :, :-padding] + + return x diff --git a/mvsepless/models/scnet/separation.py b/mvsepless/models/scnet/separation.py new file mode 100644 index 0000000000000000000000000000000000000000..8965e2c8b14fa2c1fb6a2766e840c45128e1303d --- /dev/null +++ b/mvsepless/models/scnet/separation.py @@ -0,0 +1,129 @@ +import torch +import torch.nn as nn +from torch.nn.modules.rnn import LSTM + + +class FeatureConversion(nn.Module): + """ + Integrates into the adjacent Dual-Path layer. + + Args: + channels (int): Number of input channels. + inverse (bool): If True, uses ifft; otherwise, uses rfft. + """ + + def __init__(self, channels, inverse): + super().__init__() + self.inverse = inverse + self.channels = channels + + def forward(self, x): + # B, C, F, T = x.shape + if self.inverse: + x = x.float() + x_r = x[:, : self.channels // 2, :, :] + x_i = x[:, self.channels // 2 :, :, :] + x = torch.complex(x_r, x_i) + x = torch.fft.irfft(x, dim=3, norm="ortho") + else: + x = x.float() + x = torch.fft.rfft(x, dim=3, norm="ortho") + x_real = x.real + x_imag = x.imag + x = torch.cat([x_real, x_imag], dim=1) + return x + + +class DualPathRNN(nn.Module): + """ + Dual-Path RNN in Separation Network. + + Args: + d_model (int): The number of expected features in the input (input_size). + expand (int): Expansion factor used to calculate the hidden_size of LSTM. + bidirectional (bool): If True, becomes a bidirectional LSTM. + """ + + def __init__(self, d_model, expand, bidirectional=True): + super(DualPathRNN, self).__init__() + + self.d_model = d_model + self.hidden_size = d_model * expand + self.bidirectional = bidirectional + # Initialize LSTM layers and normalization layers + self.lstm_layers = nn.ModuleList( + [self._init_lstm_layer(self.d_model, self.hidden_size) for _ in range(2)] + ) + self.linear_layers = nn.ModuleList( + [nn.Linear(self.hidden_size * 2, self.d_model) for _ in range(2)] + ) + self.norm_layers = nn.ModuleList([nn.GroupNorm(1, d_model) for _ in range(2)]) + + def _init_lstm_layer(self, d_model, hidden_size): + return LSTM( + d_model, + hidden_size, + num_layers=1, + bidirectional=self.bidirectional, + batch_first=True, + ) + + def forward(self, x): + B, C, F, T = x.shape + + # Process dual-path rnn + original_x = x + # Frequency-path + x = self.norm_layers[0](x) + x = x.transpose(1, 3).contiguous().view(B * T, F, C) + x, _ = self.lstm_layers[0](x) + x = self.linear_layers[0](x) + x = x.view(B, T, F, C).transpose(1, 3) + x = x + original_x + + original_x = x + # Time-path + x = self.norm_layers[1](x) + x = x.transpose(1, 2).contiguous().view(B * F, C, T).transpose(1, 2) + x, _ = self.lstm_layers[1](x) + x = self.linear_layers[1](x) + x = x.transpose(1, 2).contiguous().view(B, F, C, T).transpose(1, 2) + x = x + original_x + + return x + + +class SeparationNet(nn.Module): + """ + Implements a simplified Sparse Down-sample block in an encoder architecture. + + Args: + - channels (int): Number input channels. + - expand (int): Expansion factor used to calculate the hidden_size of LSTM. + - num_layers (int): Number of dual-path layers. + """ + + def __init__(self, channels, expand=1, num_layers=6): + super(SeparationNet, self).__init__() + + self.num_layers = num_layers + + self.dp_modules = nn.ModuleList( + [ + DualPathRNN(channels * (2 if i % 2 == 1 else 1), expand) + for i in range(num_layers) + ] + ) + + self.feature_conversion = nn.ModuleList( + [ + FeatureConversion(channels * 2, inverse=False if i % 2 == 0 else True) + for i in range(num_layers) + ] + ) + + def forward(self, x): + for i in range(self.num_layers): + x = self.dp_modules[i](x) + x = self.feature_conversion[i](x) + return x diff --git a/mvsepless/models/scnet_unofficial/__init__.py b/mvsepless/models/scnet_unofficial/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..298d9939f5177c6b24cca743c83a351a84a6ffce --- /dev/null +++ b/mvsepless/models/scnet_unofficial/__init__.py @@ -0,0 +1 @@ +from models.scnet_unofficial.scnet import SCNet diff --git a/mvsepless/models/scnet_unofficial/modules/__init__.py b/mvsepless/models/scnet_unofficial/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69617bb15044d9bbfd0211fcdfa0fa605b01c048 --- /dev/null +++ b/mvsepless/models/scnet_unofficial/modules/__init__.py @@ -0,0 +1,3 @@ +from models.scnet_unofficial.modules.dualpath_rnn import DualPathRNN +from models.scnet_unofficial.modules.sd_encoder import SDBlock +from models.scnet_unofficial.modules.su_decoder import SUBlock diff --git a/mvsepless/models/scnet_unofficial/modules/dualpath_rnn.py b/mvsepless/models/scnet_unofficial/modules/dualpath_rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..644d05a19cc83798402ee08f269b08a52eaeac09 --- /dev/null +++ b/mvsepless/models/scnet_unofficial/modules/dualpath_rnn.py @@ -0,0 +1,238 @@ +import torch +import torch.nn as nn +import torch.nn.functional as Func + + +class RMSNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return Func.normalize(x, dim=-1) * self.scale * self.gamma + + +class MambaModule(nn.Module): + def __init__(self, d_model, d_state, d_conv, d_expand): + super().__init__() + self.norm = RMSNorm(dim=d_model) + self.mamba = Mamba( + d_model=d_model, d_state=d_state, d_conv=d_conv, d_expand=d_expand + ) + + def forward(self, x): + x = x + self.mamba(self.norm(x)) + return x + + +class RNNModule(nn.Module): + """ + RNNModule class implements a recurrent neural network module with LSTM cells. + + Args: + - input_dim (int): Dimensionality of the input features. + - hidden_dim (int): Dimensionality of the hidden state of the LSTM. + - bidirectional (bool, optional): If True, uses bidirectional LSTM. Defaults to True. + + Shapes: + - Input: (B, T, D) where + B is batch size, + T is sequence length, + D is input dimensionality. + - Output: (B, T, D) where + B is batch size, + T is sequence length, + D is input dimensionality. + """ + + def __init__(self, input_dim: int, hidden_dim: int, bidirectional: bool = True): + """ + Initializes RNNModule with input dimension, hidden dimension, and bidirectional flag. + """ + super().__init__() + self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=input_dim) + self.rnn = nn.LSTM( + input_dim, hidden_dim, batch_first=True, bidirectional=bidirectional + ) + self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, input_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the RNNModule. + + Args: + - x (torch.Tensor): Input tensor of shape (B, T, D). + + Returns: + - torch.Tensor: Output tensor of shape (B, T, D). + """ + x = x.transpose(1, 2) + x = self.groupnorm(x) + x = x.transpose(1, 2) + + x, (hidden, _) = self.rnn(x) + x = self.fc(x) + return x + + +class RFFTModule(nn.Module): + """ + RFFTModule class implements a module for performing real-valued Fast Fourier Transform (FFT) + or its inverse on input tensors. + + Args: + - inverse (bool, optional): If False, performs forward FFT. If True, performs inverse FFT. Defaults to False. + + Shapes: + - Input: (B, F, T, D) where + B is batch size, + F is the number of features, + T is sequence length, + D is input dimensionality. + - Output: (B, F, T // 2 + 1, D * 2) if performing forward FFT. + (B, F, T, D // 2, 2) if performing inverse FFT. + """ + + def __init__(self, inverse: bool = False): + """ + Initializes RFFTModule with inverse flag. + """ + super().__init__() + self.inverse = inverse + + def forward(self, x: torch.Tensor, time_dim: int) -> torch.Tensor: + """ + Performs forward or inverse FFT on the input tensor x. + + Args: + - x (torch.Tensor): Input tensor of shape (B, F, T, D). + - time_dim (int): Input size of time dimension. + + Returns: + - torch.Tensor: Output tensor after FFT or its inverse operation. + """ + dtype = x.dtype + B, F, T, D = x.shape + + # RuntimeError: cuFFT only supports dimensions whose sizes are powers of two when computing in half precision + x = x.float() + + if not self.inverse: + x = torch.fft.rfft(x, dim=2) + x = torch.view_as_real(x) + x = x.reshape(B, F, T // 2 + 1, D * 2) + else: + x = x.reshape(B, F, T, D // 2, 2) + x = torch.view_as_complex(x) + x = torch.fft.irfft(x, n=time_dim, dim=2) + + x = x.to(dtype) + return x + + def extra_repr(self) -> str: + """ + Returns extra representation string with module's configuration. + """ + return f"inverse={self.inverse}" + + +class DualPathRNN(nn.Module): + """ + DualPathRNN class implements a neural network with alternating layers of RNNModule and RFFTModule. + + Args: + - n_layers (int): Number of layers in the network. + - input_dim (int): Dimensionality of the input features. + - hidden_dim (int): Dimensionality of the hidden state of the RNNModule. + + Shapes: + - Input: (B, F, T, D) where + B is batch size, + F is the number of features (frequency dimension), + T is sequence length (time dimension), + D is input dimensionality (channel dimension). + - Output: (B, F, T, D) where + B is batch size, + F is the number of features (frequency dimension), + T is sequence length (time dimension), + D is input dimensionality (channel dimension). + """ + + def __init__( + self, + n_layers: int, + input_dim: int, + hidden_dim: int, + use_mamba: bool = False, + d_state: int = 16, + d_conv: int = 4, + d_expand: int = 2, + ): + """ + Initializes DualPathRNN with the specified number of layers, input dimension, and hidden dimension. + """ + super().__init__() + + if use_mamba: + from mamba_ssm.modules.mamba_simple import Mamba + + net = MambaModule + dkwargs = { + "d_model": input_dim, + "d_state": d_state, + "d_conv": d_conv, + "d_expand": d_expand, + } + ukwargs = { + "d_model": input_dim * 2, + "d_state": d_state, + "d_conv": d_conv, + "d_expand": d_expand * 2, + } + else: + net = RNNModule + dkwargs = {"input_dim": input_dim, "hidden_dim": hidden_dim} + ukwargs = {"input_dim": input_dim * 2, "hidden_dim": hidden_dim * 2} + + self.layers = nn.ModuleList() + for i in range(1, n_layers + 1): + kwargs = dkwargs if i % 2 == 1 else ukwargs + layer = nn.ModuleList( + [ + net(**kwargs), + net(**kwargs), + RFFTModule(inverse=(i % 2 == 0)), + ] + ) + self.layers.append(layer) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the DualPathRNN. + + Args: + - x (torch.Tensor): Input tensor of shape (B, F, T, D). + + Returns: + - torch.Tensor: Output tensor of shape (B, F, T, D). + """ + + time_dim = x.shape[2] + + for time_layer, freq_layer, rfft_layer in self.layers: + B, F, T, D = x.shape + + x = x.reshape((B * F), T, D) + x = time_layer(x) + x = x.reshape(B, F, T, D) + x = x.permute(0, 2, 1, 3) + + x = x.reshape((B * T), F, D) + x = freq_layer(x) + x = x.reshape(B, T, F, D) + x = x.permute(0, 2, 1, 3) + + x = rfft_layer(x, time_dim) + + return x diff --git a/mvsepless/models/scnet_unofficial/modules/sd_encoder.py b/mvsepless/models/scnet_unofficial/modules/sd_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..742577f480693671437dc50358a1a65d251b6e9b --- /dev/null +++ b/mvsepless/models/scnet_unofficial/modules/sd_encoder.py @@ -0,0 +1,285 @@ +from typing import List, Tuple + +import torch +import torch.nn as nn + +from models.scnet_unofficial.utils import create_intervals + + +class Downsample(nn.Module): + """ + Downsample class implements a module for downsampling input tensors using 2D convolution. + + Args: + - input_dim (int): Dimensionality of the input channels. + - output_dim (int): Dimensionality of the output channels. + - stride (int): Stride value for the convolution operation. + + Shapes: + - Input: (B, C_in, F, T) where + B is batch size, + C_in is the number of input channels, + F is the frequency dimension, + T is the time dimension. + - Output: (B, C_out, F // stride, T) where + B is batch size, + C_out is the number of output channels, + F // stride is the downsampled frequency dimension. + + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + stride: int, + ): + """ + Initializes Downsample with input dimension, output dimension, and stride. + """ + super().__init__() + self.conv = nn.Conv2d(input_dim, output_dim, 1, (stride, 1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the Downsample module. + + Args: + - x (torch.Tensor): Input tensor of shape (B, C_in, F, T). + + Returns: + - torch.Tensor: Downsampled tensor of shape (B, C_out, F // stride, T). + """ + return self.conv(x) + + +class ConvolutionModule(nn.Module): + """ + ConvolutionModule class implements a module with a sequence of convolutional layers similar to Conformer. + + Args: + - input_dim (int): Dimensionality of the input features. + - hidden_dim (int): Dimensionality of the hidden features. + - kernel_sizes (List[int]): List of kernel sizes for the convolutional layers. + - bias (bool, optional): If True, adds a learnable bias to the output. Default is False. + + Shapes: + - Input: (B, T, D) where + B is batch size, + T is sequence length, + D is input dimensionality. + - Output: (B, T, D) where + B is batch size, + T is sequence length, + D is input dimensionality. + """ + + def __init__( + self, + input_dim: int, + hidden_dim: int, + kernel_sizes: List[int], + bias: bool = False, + ) -> None: + """ + Initializes ConvolutionModule with input dimension, hidden dimension, kernel sizes, and bias. + """ + super().__init__() + self.sequential = nn.Sequential( + nn.GroupNorm(num_groups=1, num_channels=input_dim), + nn.Conv1d( + input_dim, + 2 * hidden_dim, + kernel_sizes[0], + stride=1, + padding=(kernel_sizes[0] - 1) // 2, + bias=bias, + ), + nn.GLU(dim=1), + nn.Conv1d( + hidden_dim, + hidden_dim, + kernel_sizes[1], + stride=1, + padding=(kernel_sizes[1] - 1) // 2, + groups=hidden_dim, + bias=bias, + ), + nn.GroupNorm(num_groups=1, num_channels=hidden_dim), + nn.SiLU(), + nn.Conv1d( + hidden_dim, + input_dim, + kernel_sizes[2], + stride=1, + padding=(kernel_sizes[2] - 1) // 2, + bias=bias, + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the ConvolutionModule. + + Args: + - x (torch.Tensor): Input tensor of shape (B, T, D). + + Returns: + - torch.Tensor: Output tensor of shape (B, T, D). + """ + x = x.transpose(1, 2) + x = x + self.sequential(x) + x = x.transpose(1, 2) + return x + + +class SDLayer(nn.Module): + """ + SDLayer class implements a subband decomposition layer with downsampling and convolutional modules. + + Args: + - subband_interval (Tuple[float, float]): Tuple representing the frequency interval for subband decomposition. + - input_dim (int): Dimensionality of the input channels. + - output_dim (int): Dimensionality of the output channels after downsampling. + - downsample_stride (int): Stride value for the downsampling operation. + - n_conv_modules (int): Number of convolutional modules. + - kernel_sizes (List[int]): List of kernel sizes for the convolutional layers. + - bias (bool, optional): If True, adds a learnable bias to the convolutional layers. Default is True. + + Shapes: + - Input: (B, Fi, T, Ci) where + B is batch size, + Fi is the number of input subbands, + T is sequence length, and + Ci is the number of input channels. + - Output: (B, Fi+1, T, Ci+1) where + B is batch size, + Fi+1 is the number of output subbands, + T is sequence length, + Ci+1 is the number of output channels. + """ + + def __init__( + self, + subband_interval: Tuple[float, float], + input_dim: int, + output_dim: int, + downsample_stride: int, + n_conv_modules: int, + kernel_sizes: List[int], + bias: bool = True, + ): + """ + Initializes SDLayer with subband interval, input dimension, + output dimension, downsample stride, number of convolutional modules, kernel sizes, and bias. + """ + super().__init__() + self.subband_interval = subband_interval + self.downsample = Downsample(input_dim, output_dim, downsample_stride) + self.activation = nn.GELU() + conv_modules = [ + ConvolutionModule( + input_dim=output_dim, + hidden_dim=output_dim // 4, + kernel_sizes=kernel_sizes, + bias=bias, + ) + for _ in range(n_conv_modules) + ] + self.conv_modules = nn.Sequential(*conv_modules) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the SDLayer. + + Args: + - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci). + + Returns: + - torch.Tensor: Output tensor of shape (B, Fi+1, T, Ci+1). + """ + B, F, T, C = x.shape + x = x[:, int(self.subband_interval[0] * F) : int(self.subband_interval[1] * F)] + x = x.permute(0, 3, 1, 2) + x = self.downsample(x) + x = self.activation(x) + x = x.permute(0, 2, 3, 1) + + B, F, T, C = x.shape + x = x.reshape((B * F), T, C) + x = self.conv_modules(x) + x = x.reshape(B, F, T, C) + + return x + + +class SDBlock(nn.Module): + """ + SDBlock class implements a block with subband decomposition layers and global convolution. + + Args: + - input_dim (int): Dimensionality of the input channels. + - output_dim (int): Dimensionality of the output channels. + - bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands. + - downsample_strides (List[int]): List of stride values for downsampling in each subband layer. + - n_conv_modules (List[int]): List specifying the number of convolutional modules in each subband layer. + - kernel_sizes (List[int], optional): List of kernel sizes for the convolutional layers. Default is None. + + Shapes: + - Input: (B, Fi, T, Ci) where + B is batch size, + Fi is the number of input subbands, + T is sequence length, + Ci is the number of input channels. + - Output: (B, Fi+1, T, Ci+1) where + B is batch size, + Fi+1 is the number of output subbands, + T is sequence length, + Ci+1 is the number of output channels. + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + bandsplit_ratios: List[float], + downsample_strides: List[int], + n_conv_modules: List[int], + kernel_sizes: List[int] = None, + ): + """ + Initializes SDBlock with input dimension, output dimension, band split ratios, downsample strides, number of convolutional modules, and kernel sizes. + """ + super().__init__() + if kernel_sizes is None: + kernel_sizes = [3, 3, 1] + assert sum(bandsplit_ratios) == 1, "The split ratios must sum up to 1." + subband_intervals = create_intervals(bandsplit_ratios) + self.sd_layers = nn.ModuleList( + SDLayer( + input_dim=input_dim, + output_dim=output_dim, + subband_interval=sbi, + downsample_stride=dss, + n_conv_modules=ncm, + kernel_sizes=kernel_sizes, + ) + for sbi, dss, ncm in zip( + subband_intervals, downsample_strides, n_conv_modules + ) + ) + self.global_conv2d = nn.Conv2d(output_dim, output_dim, 1, 1) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Performs forward pass through the SDBlock. + + Args: + - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci). + + Returns: + - Tuple[torch.Tensor, torch.Tensor]: Output tensor and skip connection tensor. + """ + x_skip = torch.concat([layer(x) for layer in self.sd_layers], dim=1) + x = self.global_conv2d(x_skip.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + return x, x_skip diff --git a/mvsepless/models/scnet_unofficial/modules/su_decoder.py b/mvsepless/models/scnet_unofficial/modules/su_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..660c1fa6cbfd9b43bed73204a0bb6593524de272 --- /dev/null +++ b/mvsepless/models/scnet_unofficial/modules/su_decoder.py @@ -0,0 +1,241 @@ +from typing import List, Tuple + +import torch +import torch.nn as nn + +from models.scnet_unofficial.utils import get_convtranspose_output_padding + + +class FusionLayer(nn.Module): + """ + FusionLayer class implements a module for fusing two input tensors using convolutional operations. + + Args: + - input_dim (int): Dimensionality of the input channels. + - kernel_size (int, optional): Kernel size for the convolutional layer. Default is 3. + - stride (int, optional): Stride value for the convolutional layer. Default is 1. + - padding (int, optional): Padding value for the convolutional layer. Default is 1. + + Shapes: + - Input: (B, F, T, C) and (B, F, T, C) where + B is batch size, + F is the number of features, + T is sequence length, + C is input dimensionality. + - Output: (B, F, T, C) where + B is batch size, + F is the number of features, + T is sequence length, + C is input dimensionality. + """ + + def __init__( + self, input_dim: int, kernel_size: int = 3, stride: int = 1, padding: int = 1 + ): + """ + Initializes FusionLayer with input dimension, kernel size, stride, and padding. + """ + super().__init__() + self.conv = nn.Conv2d( + input_dim * 2, + input_dim * 2, + kernel_size=(kernel_size, 1), + stride=(stride, 1), + padding=(padding, 0), + ) + self.activation = nn.GLU() + + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the FusionLayer. + + Args: + - x1 (torch.Tensor): First input tensor of shape (B, F, T, C). + - x2 (torch.Tensor): Second input tensor of shape (B, F, T, C). + + Returns: + - torch.Tensor: Output tensor of shape (B, F, T, C). + """ + x = x1 + x2 + x = x.repeat(1, 1, 1, 2) + x = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + x = self.activation(x) + return x + + +class Upsample(nn.Module): + """ + Upsample class implements a module for upsampling input tensors using transposed 2D convolution. + + Args: + - input_dim (int): Dimensionality of the input channels. + - output_dim (int): Dimensionality of the output channels. + - stride (int): Stride value for the transposed convolution operation. + - output_padding (int): Output padding value for the transposed convolution operation. + + Shapes: + - Input: (B, C_in, F, T) where + B is batch size, + C_in is the number of input channels, + F is the frequency dimension, + T is the time dimension. + - Output: (B, C_out, F * stride + output_padding, T) where + B is batch size, + C_out is the number of output channels, + F * stride + output_padding is the upsampled frequency dimension. + """ + + def __init__( + self, input_dim: int, output_dim: int, stride: int, output_padding: int + ): + """ + Initializes Upsample with input dimension, output dimension, stride, and output padding. + """ + super().__init__() + self.conv = nn.ConvTranspose2d( + input_dim, output_dim, 1, (stride, 1), output_padding=(output_padding, 0) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the Upsample module. + + Args: + - x (torch.Tensor): Input tensor of shape (B, C_in, F, T). + + Returns: + - torch.Tensor: Output tensor of shape (B, C_out, F * stride + output_padding, T). + """ + return self.conv(x) + + +class SULayer(nn.Module): + """ + SULayer class implements a subband upsampling layer using transposed convolution. + + Args: + - input_dim (int): Dimensionality of the input channels. + - output_dim (int): Dimensionality of the output channels. + - upsample_stride (int): Stride value for the upsampling operation. + - subband_shape (int): Shape of the subband. + - sd_interval (Tuple[int, int]): Start and end indices of the subband interval. + + Shapes: + - Input: (B, F, T, C) where + B is batch size, + F is the number of features, + T is sequence length, + C is input dimensionality. + - Output: (B, F, T, C) where + B is batch size, + F is the number of features, + T is sequence length, + C is input dimensionality. + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + upsample_stride: int, + subband_shape: int, + sd_interval: Tuple[int, int], + ): + """ + Initializes SULayer with input dimension, output dimension, upsample stride, subband shape, and subband interval. + """ + super().__init__() + sd_shape = sd_interval[1] - sd_interval[0] + upsample_output_padding = get_convtranspose_output_padding( + input_shape=sd_shape, output_shape=subband_shape, stride=upsample_stride + ) + self.upsample = Upsample( + input_dim=input_dim, + output_dim=output_dim, + stride=upsample_stride, + output_padding=upsample_output_padding, + ) + self.sd_interval = sd_interval + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the SULayer. + + Args: + - x (torch.Tensor): Input tensor of shape (B, F, T, C). + + Returns: + - torch.Tensor: Output tensor of shape (B, F, T, C). + """ + x = x[:, self.sd_interval[0] : self.sd_interval[1]] + x = x.permute(0, 3, 1, 2) + x = self.upsample(x) + x = x.permute(0, 2, 3, 1) + return x + + +class SUBlock(nn.Module): + """ + SUBlock class implements a block with fusion layer and subband upsampling layers. + + Args: + - input_dim (int): Dimensionality of the input channels. + - output_dim (int): Dimensionality of the output channels. + - upsample_strides (List[int]): List of stride values for the upsampling operations. + - subband_shapes (List[int]): List of shapes for the subbands. + - sd_intervals (List[Tuple[int, int]]): List of intervals for subband decomposition. + + Shapes: + - Input: (B, Fi-1, T, Ci-1) and (B, Fi-1, T, Ci-1) where + B is batch size, + Fi-1 is the number of input subbands, + T is sequence length, + Ci-1 is the number of input channels. + - Output: (B, Fi, T, Ci) where + B is batch size, + Fi is the number of output subbands, + T is sequence length, + Ci is the number of output channels. + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + upsample_strides: List[int], + subband_shapes: List[int], + sd_intervals: List[Tuple[int, int]], + ): + """ + Initializes SUBlock with input dimension, output dimension, + upsample strides, subband shapes, and subband intervals. + """ + super().__init__() + self.fusion_layer = FusionLayer(input_dim=input_dim) + self.su_layers = nn.ModuleList( + SULayer( + input_dim=input_dim, + output_dim=output_dim, + upsample_stride=uss, + subband_shape=sbs, + sd_interval=sdi, + ) + for i, (uss, sbs, sdi) in enumerate( + zip(upsample_strides, subband_shapes, sd_intervals) + ) + ) + + def forward(self, x: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the SUBlock. + + Args: + - x (torch.Tensor): Input tensor of shape (B, Fi-1, T, Ci-1). + - x_skip (torch.Tensor): Input skip connection tensor of shape (B, Fi-1, T, Ci-1). + + Returns: + - torch.Tensor: Output tensor of shape (B, Fi, T, Ci). + """ + x = self.fusion_layer(x, x_skip) + x = torch.concat([layer(x) for layer in self.su_layers], dim=1) + return x diff --git a/mvsepless/models/scnet_unofficial/scnet.py b/mvsepless/models/scnet_unofficial/scnet.py new file mode 100644 index 0000000000000000000000000000000000000000..d6dcf7285da2358b6bf2dcd2b9a177fed6022a0d --- /dev/null +++ b/mvsepless/models/scnet_unofficial/scnet.py @@ -0,0 +1,246 @@ +""" +SCNet - great paper, great implementation +https://arxiv.org/pdf/2401.13276.pdf +https://github.com/amanteur/SCNet-PyTorch +""" + +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio + +from models.scnet_unofficial.modules import DualPathRNN, SDBlock, SUBlock +from models.scnet_unofficial.utils import compute_sd_layer_shapes, compute_gcr + +from einops import rearrange, pack, unpack +from functools import partial + +from beartype.typing import Tuple, Optional, List, Callable +from beartype import beartype + + +def exists(val): + return val is not None + + +def default(v, d): + return v if exists(v) else d + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +class RMSNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return F.normalize(x, dim=-1) * self.scale * self.gamma + + +class BandSplit(nn.Module): + @beartype + def __init__(self, dim, dim_inputs: Tuple[int, ...]): + super().__init__() + self.dim_inputs = dim_inputs + self.to_features = ModuleList([]) + + for dim_in in dim_inputs: + net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim)) + + self.to_features.append(net) + + def forward(self, x): + x = x.split(self.dim_inputs, dim=-1) + + outs = [] + for split_input, to_feature in zip(x, self.to_features): + split_output = to_feature(split_input) + outs.append(split_output) + + return torch.stack(outs, dim=-2) + + +class SCNet(nn.Module): + """ + SCNet class implements a source separation network, + which explicitly split the spectrogram of the mixture into several subbands + and introduce a sparsity-based encoder to model different frequency bands. + + Paper: "SCNET: SPARSE COMPRESSION NETWORK FOR MUSIC SOURCE SEPARATION" + Authors: Weinan Tong, Jiaxu Zhu et al. + Link: https://arxiv.org/abs/2401.13276.pdf + + Args: + - n_fft (int): Number of FFTs to determine the frequency dimension of the input. + - dims (List[int]): List of channel dimensions for each block. + - bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands. + - downsample_strides (List[int]): List of stride values for downsampling in each block. + - n_conv_modules (List[int]): List specifying the number of convolutional modules in each block. + - n_rnn_layers (int): Number of recurrent layers in the dual path RNN. + - rnn_hidden_dim (int): Dimensionality of the hidden state in the dual path RNN. + - n_sources (int, optional): Number of sources to be separated. Default is 4. + + Shapes: + - Input: (B, C, T) where + B is batch size, + C is channel dim (mono / stereo), + T is time dim + - Output: (B, N, C, T) where + B is batch size, + N is the number of sources. + C is channel dim (mono / stereo), + T is sequence length, + """ + + @beartype + def __init__( + self, + n_fft: int, + dims: List[int], + bandsplit_ratios: List[float], + downsample_strides: List[int], + n_conv_modules: List[int], + n_rnn_layers: int, + rnn_hidden_dim: int, + n_sources: int = 4, + hop_length: int = 1024, + win_length: int = 4096, + stft_window_fn: Optional[Callable] = None, + stft_normalized: bool = False, + **kwargs, + ): + """ + Initializes SCNet with input parameters. + """ + super().__init__() + self.assert_input_data( + bandsplit_ratios, + downsample_strides, + n_conv_modules, + ) + + n_blocks = len(dims) - 1 + n_freq_bins = n_fft // 2 + 1 + subband_shapes, sd_intervals = compute_sd_layer_shapes( + input_shape=n_freq_bins, + bandsplit_ratios=bandsplit_ratios, + downsample_strides=downsample_strides, + n_layers=n_blocks, + ) + self.sd_blocks = nn.ModuleList( + SDBlock( + input_dim=dims[i], + output_dim=dims[i + 1], + bandsplit_ratios=bandsplit_ratios, + downsample_strides=downsample_strides, + n_conv_modules=n_conv_modules, + ) + for i in range(n_blocks) + ) + self.dualpath_blocks = DualPathRNN( + n_layers=n_rnn_layers, + input_dim=dims[-1], + hidden_dim=rnn_hidden_dim, + **kwargs, + ) + self.su_blocks = nn.ModuleList( + SUBlock( + input_dim=dims[i + 1], + output_dim=dims[i] if i != 0 else dims[i] * n_sources, + subband_shapes=subband_shapes[i], + sd_intervals=sd_intervals[i], + upsample_strides=downsample_strides, + ) + for i in reversed(range(n_blocks)) + ) + self.gcr = compute_gcr(subband_shapes) + + self.stft_kwargs = dict( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + normalized=stft_normalized, + ) + + self.stft_window_fn = partial( + default(stft_window_fn, torch.hann_window), win_length + ) + self.n_sources = n_sources + self.hop_length = hop_length + + @staticmethod + def assert_input_data(*args): + """ + Asserts that the shapes of input features are equal. + """ + for arg1 in args: + for arg2 in args: + if len(arg1) != len(arg2): + raise ValueError( + f"Shapes of input features {arg1} and {arg2} are not equal." + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the SCNet. + + Args: + - x (torch.Tensor): Input tensor of shape (B, C, T). + + Returns: + - torch.Tensor: Output tensor of shape (B, N, C, T). + """ + + device = x.device + stft_window = self.stft_window_fn(device=device) + + if x.ndim == 2: + x = rearrange(x, "b t -> b 1 t") + + c = x.shape[1] + + stft_pad = self.hop_length - x.shape[-1] % self.hop_length + x = F.pad(x, (0, stft_pad)) + + # stft + x, ps = pack_one(x, "* t") + x = torch.stft(x, **self.stft_kwargs, window=stft_window, return_complex=True) + x = torch.view_as_real(x) + x = unpack_one(x, ps, "* c f t") + x = rearrange(x, "b c f t r -> b f t (c r)") + + # encoder part + x_skips = [] + for sd_block in self.sd_blocks: + x, x_skip = sd_block(x) + x_skips.append(x_skip) + + # separation part + x = self.dualpath_blocks(x) + + # decoder part + for su_block, x_skip in zip(self.su_blocks, reversed(x_skips)): + x = su_block(x, x_skip) + + # istft + x = rearrange(x, "b f t (c r n) -> b n c f t r", c=c, n=self.n_sources, r=2) + x = x.contiguous() + + x = torch.view_as_complex(x) + x = rearrange(x, "b n c f t -> (b n c) f t") + x = torch.istft(x, **self.stft_kwargs, window=stft_window, return_complex=False) + x = rearrange(x, "(b n c) t -> b n c t", c=c, n=self.n_sources) + + x = x[..., :-stft_pad] + + return x diff --git a/mvsepless/models/scnet_unofficial/utils.py b/mvsepless/models/scnet_unofficial/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d236d499322a5db4ae4a813e75901dcfe28f7993 --- /dev/null +++ b/mvsepless/models/scnet_unofficial/utils.py @@ -0,0 +1,135 @@ +""" +SCNet - great paper, great implementation +https://arxiv.org/pdf/2401.13276.pdf +https://github.com/amanteur/SCNet-PyTorch +""" + +from typing import List, Tuple, Union + +import torch + + +def create_intervals( + splits: List[Union[float, int]], +) -> List[Union[Tuple[float, float], Tuple[int, int]]]: + """ + Create intervals based on splits provided. + + Args: + - splits (List[Union[float, int]]): List of floats or integers representing splits. + + Returns: + - List[Union[Tuple[float, float], Tuple[int, int]]]: List of tuples representing intervals. + """ + start = 0 + return [(start, start := start + split) for split in splits] + + +def get_conv_output_shape( + input_shape: int, + kernel_size: int = 1, + padding: int = 0, + dilation: int = 1, + stride: int = 1, +) -> int: + """ + Compute the output shape of a convolutional layer. + + Args: + - input_shape (int): Input shape. + - kernel_size (int, optional): Kernel size of the convolution. Default is 1. + - padding (int, optional): Padding size. Default is 0. + - dilation (int, optional): Dilation factor. Default is 1. + - stride (int, optional): Stride value. Default is 1. + + Returns: + - int: Output shape. + """ + return int( + (input_shape + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1 + ) + + +def get_convtranspose_output_padding( + input_shape: int, + output_shape: int, + kernel_size: int = 1, + padding: int = 0, + dilation: int = 1, + stride: int = 1, +) -> int: + """ + Compute the output padding for a convolution transpose operation. + + Args: + - input_shape (int): Input shape. + - output_shape (int): Desired output shape. + - kernel_size (int, optional): Kernel size of the convolution. Default is 1. + - padding (int, optional): Padding size. Default is 0. + - dilation (int, optional): Dilation factor. Default is 1. + - stride (int, optional): Stride value. Default is 1. + + Returns: + - int: Output padding. + """ + return ( + output_shape + - (input_shape - 1) * stride + + 2 * padding + - dilation * (kernel_size - 1) + - 1 + ) + + +def compute_sd_layer_shapes( + input_shape: int, + bandsplit_ratios: List[float], + downsample_strides: List[int], + n_layers: int, +) -> Tuple[List[List[int]], List[List[Tuple[int, int]]]]: + """ + Compute the shapes for the subband layers. + + Args: + - input_shape (int): Input shape. + - bandsplit_ratios (List[float]): Ratios for splitting the frequency bands. + - downsample_strides (List[int]): Strides for downsampling in each layer. + - n_layers (int): Number of layers. + + Returns: + - Tuple[List[List[int]], List[List[Tuple[int, int]]]]: Tuple containing subband shapes and convolution shapes. + """ + bandsplit_shapes_list = [] + conv2d_shapes_list = [] + for _ in range(n_layers): + bandsplit_intervals = create_intervals(bandsplit_ratios) + bandsplit_shapes = [ + int(right * input_shape) - int(left * input_shape) + for left, right in bandsplit_intervals + ] + conv2d_shapes = [ + get_conv_output_shape(bs, stride=ds) + for bs, ds in zip(bandsplit_shapes, downsample_strides) + ] + input_shape = sum(conv2d_shapes) + bandsplit_shapes_list.append(bandsplit_shapes) + conv2d_shapes_list.append(create_intervals(conv2d_shapes)) + + return bandsplit_shapes_list, conv2d_shapes_list + + +def compute_gcr(subband_shapes: List[List[int]]) -> float: + """ + Compute the global compression ratio. + + Args: + - subband_shapes (List[List[int]]): List of subband shapes. + + Returns: + - float: Global compression ratio. + """ + t = torch.Tensor(subband_shapes) + gcr = torch.stack( + [(1 - t[i + 1] / t[i]).mean() for i in range(0, len(t) - 1)] + ).mean() + return float(gcr) diff --git a/mvsepless/models/vr_arch/__init__.py b/mvsepless/models/vr_arch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cc75c42ef220d5787ba8371835a0cd48109806bf --- /dev/null +++ b/mvsepless/models/vr_arch/__init__.py @@ -0,0 +1,321 @@ +import os +import math +import sys +import json +import torch +import torch.nn as nn +import librosa +import numpy as np +from tqdm import tqdm + +from . import spec_utils, nets, nets_new +from .model_param_init import ModelParameters + +VOCAL_STEM = "vocals" +INST_STEM = "instrumental" +OTHER_STEM = "other" +BASS_STEM = "bass" +DRUM_STEM = "drums" +GUITAR_STEM = "guitar" +PIANO_STEM = "piano" +SYNTH_STEM = "synthesizer" +STRINGS_STEM = "strings" +WOODWINDS_STEM = "woodwinds" +BRASS_STEM = "brass" +WIND_INST_STEM = "wind_inst" + +NON_ACCOM_STEMS = ( + VOCAL_STEM, + OTHER_STEM, + BASS_STEM, + DRUM_STEM, + GUITAR_STEM, + PIANO_STEM, + SYNTH_STEM, + STRINGS_STEM, + WOODWINDS_STEM, + BRASS_STEM, + WIND_INST_STEM, +) + +class VRNet: + def __init__( + self, + model_params={}, + nout=None, + nout_lstm=None, + ): + self.torch_device_mps = torch.backends.mps.is_available() + self.enable_post_process = False + self.post_process_threshold = 0.2 + self.batch_size = 1 + self.window_size = 512 + self.high_end_process = False + self.primary_stem = "Instrumental" + self.secondary_stem = "Vocals" + self.model_capacity = 32, 128 + self.is_vr_51_model = False + if nout and nout_lstm: + self.model_capacity = nout, nout_lstm + self.is_vr_51_model = True + self.model_params = ModelParameters(model_params) + self.input_high_end_h = None + self.input_high_end = None + self.enable_tta = False + self.model_samplerate = self.model_params.param["sr"] + self.model_run = lambda *args, **kwargs: print( + "Model run method is not initialised yet." + ) + + def load_checkpoint(self, checkpoint_path: str, device: torch.device): + nn_arch_sizes = [ + 31191, + 33966, + 56817, + 123821, + 123812, + 129605, + 218409, + 537238, + 537227, + ] # default + vr_5_1_models = [56817, 218409] + model_size = math.ceil(os.stat(checkpoint_path).st_size / 1024) + nn_arch_size = min(nn_arch_sizes, key=lambda x: abs(x - model_size)) + + if nn_arch_size in vr_5_1_models or self.is_vr_51_model: + self.model_run = nets_new.CascadedNet( + self.model_params.param["bins"] * 2, + nn_arch_size, + nout=self.model_capacity[0], + nout_lstm=self.model_capacity[1], + ) + self.is_vr_51_model = True + else: + self.model_run = nets.determine_model_capacity( + self.model_params.param["bins"] * 2, nn_arch_size + ) + + self.model_run.load_state_dict( + torch.load(checkpoint_path, map_location=device) + ) + self.model_run.to(device) + + def loading_mix(self, numpy_array, orig_sr=44100): + X_wave, X_spec_s = {}, {} + + bands_n = len(self.model_params.param["band"]) + + audio_file = numpy_array + + for d in tqdm(range(bands_n, 0, -1)): + bp = self.model_params.param["band"][d] + + wav_resolution = bp["res_type"] + + if self.torch_device_mps: + wav_resolution = "polyphase" + + if d == bands_n: # high-end band + X_wave[d], _ = librosa.resample( + y=numpy_array, + orig_sr=orig_sr, + target_sr=bp["sr"], + res_type=wav_resolution, + ) + X_spec_s[d] = spec_utils.wave_to_spectrogram( + X_wave[d], + bp["hl"], + bp["n_fft"], + self.model_params, + band=d, + is_v51_model=self.is_vr_51_model, + ) + + if X_wave[d].ndim == 1: + X_wave[d] = np.asarray([X_wave[d], X_wave[d]]) + else: # lower bands + X_wave[d] = librosa.resample( + X_wave[d + 1], + orig_sr=self.model_params.param["band"][d + 1]["sr"], + target_sr=bp["sr"], + res_type=wav_resolution, + ) + X_spec_s[d] = spec_utils.wave_to_spectrogram( + X_wave[d], + bp["hl"], + bp["n_fft"], + self.model_params, + band=d, + is_v51_model=self.is_vr_51_model, + ) + + if d == bands_n and self.high_end_process: + self.input_high_end_h = (bp["n_fft"] // 2 - bp["crop_stop"]) + ( + self.model_params.param["pre_filter_stop"] + - self.model_params.param["pre_filter_start"] + ) + self.input_high_end = X_spec_s[d][ + :, bp["n_fft"] // 2 - self.input_high_end_h : bp["n_fft"] // 2, : + ] + + X_spec = spec_utils.combine_spectrograms( + X_spec_s, self.model_params, is_v51_model=self.is_vr_51_model + ) + + del X_wave, X_spec_s + + return X_spec + + def inference_vr(self, X_spec, device, aggressiveness): + def _execute(X_mag_pad, roi_size): + X_dataset = [] + patches = (X_mag_pad.shape[2] - 2 * self.model_run.offset) // roi_size + total = patches + for i in tqdm(range(patches)): + processed = min(i + self.batch_size, patches) + sys.stdout.write(json.dumps({"processing": {"processed": processed, "total": total}}, ensure_ascii=False) + '\n') + start = i * roi_size + X_mag_window = X_mag_pad[:, :, start : start + self.window_size] + X_dataset.append(X_mag_window) + + total_iterations = ( + patches // self.batch_size + if not self.enable_tta + else (patches // self.batch_size) * 2 + ) + + X_dataset = np.asarray(X_dataset) + self.model_run.eval() + with torch.no_grad(): + mask = [] + + for i in tqdm(range(0, patches, self.batch_size)): + processed = min(i + self.batch_size, patches) + sys.stdout.write(json.dumps({"processing": {"processed": processed, "total": total}}, ensure_ascii=False) + '\n') + X_batch = X_dataset[i : i + self.batch_size] + X_batch = torch.from_numpy(X_batch).to(device) + pred = self.model_run.predict_mask(X_batch) + if not pred.size()[3] > 0: + raise ValueError( + f"Window size error: h1_shape[3] must be greater than h2_shape[3]" + ) + pred = pred.detach().cpu().numpy() + pred = np.concatenate(pred, axis=2) + mask.append(pred) + if len(mask) == 0: + raise ValueError( + f"Window size error: h1_shape[3] must be greater than h2_shape[3]" + ) + + mask = np.concatenate(mask, axis=2) + return mask + + def postprocess(mask, X_mag, X_phase): + is_non_accom_stem = False + for stem in NON_ACCOM_STEMS: + if stem == self.primary_stem: + is_non_accom_stem = True + + mask = spec_utils.adjust_aggr(mask, is_non_accom_stem, aggressiveness) + + if self.enable_post_process: + mask = spec_utils.merge_artifacts( + mask, thres=self.post_process_threshold + ) + + y_spec = mask * X_mag * np.exp(1.0j * X_phase) + v_spec = (1 - mask) * X_mag * np.exp(1.0j * X_phase) + + return y_spec, v_spec + + X_mag, X_phase = spec_utils.preprocess(X_spec) + n_frame = X_mag.shape[2] + pad_l, pad_r, roi_size = spec_utils.make_padding( + n_frame, self.window_size, self.model_run.offset + ) + X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant") + X_mag_pad /= X_mag_pad.max() + mask = _execute(X_mag_pad, roi_size) + + if self.enable_tta: + pad_l += roi_size // 2 + pad_r += roi_size // 2 + X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant") + X_mag_pad /= X_mag_pad.max() + mask_tta = _execute(X_mag_pad, roi_size) + mask_tta = mask_tta[:, :, roi_size // 2 :] + mask = (mask[:, :, :n_frame] + mask_tta[:, :, :n_frame]) * 0.5 + else: + mask = mask[:, :, :n_frame] + + y_spec, v_spec = postprocess(mask, X_mag, X_phase) + + return y_spec, v_spec + + def spec_to_wav(self, spec): + if ( + self.high_end_process + and isinstance(self.input_high_end, np.ndarray) + and self.input_high_end_h + ): + input_high_end_ = spec_utils.mirroring( + "mirroring", spec, self.input_high_end, self.model_params + ) + wav = spec_utils.cmb_spectrogram_to_wave( + spec, + self.model_params, + self.input_high_end_h, + input_high_end_, + is_v51_model=self.is_vr_51_model, + ) + else: + wav = spec_utils.cmb_spectrogram_to_wave( + spec, self.model_params, is_v51_model=self.is_vr_51_model + ) + + return wav + + def settings(self, + enable_post_process=False, + post_process_threshold=0.2, + batch_size=1, + window_size=512, + high_end_process=False, + primary_stem="Instrumental", + secondary_stem="Vocals" + ): + self.enable_post_process = enable_post_process + self.post_process_threshold = post_process_threshold + self.batch_size = batch_size + self.window_size = window_size + self.high_end_process = high_end_process + self.primary_stem = primary_stem + self.secondary_stem = secondary_stem + + def demix(self, numpy_array, sr, device, aggression): + aggr = float(int(aggression) / 100) + aggressiveness = { # это должно быть в demix + "value": aggr, + "split_bin": self.model_params.param["band"][1]["crop_stop"], + "aggr_correction": self.model_params.param.get("aggr_correction"), + } + y_spec, v_spec = self.inference_vr( + self.loading_mix(numpy_array, sr), device, aggressiveness + ) + y_spec = np.nan_to_num(y_spec, nan=0.0, posinf=0.0, neginf=0.0) + v_spec = np.nan_to_num(v_spec, nan=0.0, posinf=0.0, neginf=0.0) + primary_stem_array = self.spec_to_wav(y_spec).T + primary_stem_array = librosa.resample( + primary_stem_array.T, + orig_sr=self.model_samplerate, + target_sr=44100, + ).T + secondary_stem_array = self.spec_to_wav(v_spec).T + secondary_stem_array = librosa.resample( + secondary_stem_array.T, + orig_sr=self.model_samplerate, + target_sr=44100, + ).T + return {self.primary_stem: primary_stem_array, self.secondary_stem: secondary_stem_array} + diff --git a/mvsepless/models/vr_arch/layers.py b/mvsepless/models/vr_arch/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..fde999c98fc9832b7161a949f27ce4b268d06213 --- /dev/null +++ b/mvsepless/models/vr_arch/layers.py @@ -0,0 +1,329 @@ +import torch +from torch import nn +import torch.nn.functional as F + +from . import spec_utils + + +class Conv2DBNActiv(nn.Module): + """ + This class implements a convolutional layer followed by batch normalization and an activation function. + It is a common pattern in deep learning for processing images or feature maps. The convolutional layer + applies a set of learnable filters to the input. Batch normalization then normalizes the output of the + convolution, and finally, an activation function introduces non-linearity to the model, allowing it to + learn more complex patterns. + + Attributes: + conv (nn.Sequential): A sequential container of Conv2d, BatchNorm2d, and an activation layer. + + Args: + num_input_channels (int): Number of input channels. + num_output_channels (int): Number of output channels. + kernel_size (int, optional): Size of the kernel. Defaults to 3. + stride_length (int, optional): Stride of the convolution. Defaults to 1. + padding_size (int, optional): Padding added to all sides of the input. Defaults to 1. + dilation_rate (int, optional): Spacing between kernel elements. Defaults to 1. + activation_function (callable, optional): The activation function to use. Defaults to nn.ReLU. + """ + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(Conv2DBNActiv, self).__init__() + + # The nn.Sequential container allows us to stack the Conv2d, BatchNorm2d, and activation layers + # into a single module, simplifying the forward pass. + self.conv = nn.Sequential( + nn.Conv2d( + nin, + nout, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(nout), + activ(), + ) + + def __call__(self, input_tensor): + # Defines the computation performed at every call. + # Simply passes the input through the sequential container. + return self.conv(input_tensor) + + +class SeperableConv2DBNActiv(nn.Module): + """ + This class implements a separable convolutional layer followed by batch normalization and an activation function. + Separable convolutions are a type of convolution that splits the convolution operation into two simpler operations: + a depthwise convolution and a pointwise convolution. This can reduce the number of parameters and computational cost, + making the network more efficient while maintaining similar performance. + + The depthwise convolution applies a single filter per input channel (input depth). The pointwise convolution, + which follows, applies a 1x1 convolution to combine the outputs of the depthwise convolution across channels. + Batch normalization is then applied to stabilize learning and reduce internal covariate shift. Finally, + an activation function introduces non-linearity, allowing the network to learn complex patterns. + Attributes: + conv (nn.Sequential): A sequential container of depthwise Conv2d, pointwise Conv2d, BatchNorm2d, and an activation layer. + + Args: + num_input_channels (int): Number of input channels. + num_output_channels (int): Number of output channels. + kernel_size (int, optional): Size of the kernel for the depthwise convolution. Defaults to 3. + stride_length (int, optional): Stride of the convolution. Defaults to 1. + padding_size (int, optional): Padding added to all sides of the input for the depthwise convolution. Defaults to 1. + dilation_rate (int, optional): Spacing between kernel elements for the depthwise convolution. Defaults to 1. + activation_function (callable, optional): The activation function to use. Defaults to nn.ReLU. + """ + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(SeperableConv2DBNActiv, self).__init__() + + # Initialize the sequential container with the depthwise convolution. + # The number of groups in the depthwise convolution is set to num_input_channels, which means each input channel is treated separately. + # The pointwise convolution then combines these separate channels into num_output_channels channels. + # Batch normalization is applied to the output of the pointwise convolution. + # Finally, the activation function is applied to introduce non-linearity. + self.conv = nn.Sequential( + nn.Conv2d( + nin, + nin, # For depthwise convolution, in_channels = out_channels = num_input_channels + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + groups=nin, # This makes it a depthwise convolution + bias=False, # Bias is not used because it will be handled by BatchNorm2d + ), + nn.Conv2d( + nin, + nout, # Pointwise convolution to combine channels + kernel_size=1, # Kernel size of 1 for pointwise convolution + bias=False, # Bias is not used because it will be handled by BatchNorm2d + ), + nn.BatchNorm2d(nout), # Normalize the output of the pointwise convolution + activ(), # Apply the activation function + ) + + def __call__(self, input_tensor): + # Pass the input through the sequential container. + # This performs the depthwise convolution, followed by the pointwise convolution, + # batch normalization, and finally applies the activation function. + return self.conv(input_tensor) + + +class Encoder(nn.Module): + """ + The Encoder class is a part of the neural network architecture that is responsible for processing the input data. + It consists of two convolutional layers, each followed by batch normalization and an activation function. + The purpose of the Encoder is to transform the input data into a higher-level, abstract representation. + This is achieved by applying filters (through convolutions) that can capture patterns or features in the data. + The Encoder can be thought of as a feature extractor that prepares the data for further processing by the network. + Attributes: + conv1 (Conv2DBNActiv): The first convolutional layer in the encoder. + conv2 (Conv2DBNActiv): The second convolutional layer in the encoder. + + Args: + number_of_input_channels (int): Number of input channels for the first convolutional layer. + number_of_output_channels (int): Number of output channels for the convolutional layers. + kernel_size (int): Kernel size for the convolutional layers. + stride_length (int): Stride for the convolutional operations. + padding_size (int): Padding added to all sides of the input for the convolutional layers. + activation_function (callable): The activation function to use after each convolutional layer. + """ + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU): + super(Encoder, self).__init__() + + # The first convolutional layer takes the input and applies a convolution, + # followed by batch normalization and an activation function specified by `activation_function`. + # This layer is responsible for capturing the initial set of features from the input data. + self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + + # The second convolutional layer further processes the output from the first layer, + # applying another set of convolution, batch normalization, and activation. + # This layer helps in capturing more complex patterns in the data by building upon the initial features extracted by conv1. + self.conv2 = Conv2DBNActiv(nout, nout, ksize, stride, pad, activ=activ) + + def __call__(self, input_tensor): + # The input data `input_tensor` is passed through the first convolutional layer. + # The output of this layer serves as a 'skip connection' that can be used later in the network to preserve spatial information. + skip = self.conv1(input_tensor) + + # The output from the first layer is then passed through the second convolutional layer. + # This processed data `hidden` is the final output of the Encoder, representing the abstracted features of the input. + hidden = self.conv2(skip) + + # The Encoder returns two outputs: `hidden`, the abstracted feature representation, and `skip`, the intermediate representation from conv1. + return hidden, skip + + +class Decoder(nn.Module): + """ + The Decoder class is part of the neural network architecture, specifically designed to perform the inverse operation of an encoder. + Its main role is to reconstruct or generate data from encoded representations, which is crucial in tasks like image segmentation or audio processing. + This class uses upsampling, convolution, optional dropout for regularization, and concatenation of skip connections to achieve its goal. + + Attributes: + convolution (Conv2DBNActiv): A convolutional layer with batch normalization and activation function. + dropout_layer (nn.Dropout2d): An optional dropout layer for regularization to prevent overfitting. + + Args: + input_channels (int): Number of input channels for the convolutional layer. + output_channels (int): Number of output channels for the convolutional layer. + kernel_size (int): Kernel size for the convolutional layer. + stride (int): Stride for the convolutional operations. + padding (int): Padding added to all sides of the input for the convolutional layer. + activation_function (callable): The activation function to use after the convolutional layer. + include_dropout (bool): Whether to include a dropout layer for regularization. + """ + + def __init__( + self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False + ): + super(Decoder, self).__init__() + + # Initialize the convolutional layer with specified parameters. + self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + + # Initialize the dropout layer if include_dropout is set to True + self.dropout = nn.Dropout2d(0.1) if dropout else None + + def __call__(self, input_tensor, skip=None): + # Upsample the input tensor to a higher resolution using bilinear interpolation. + input_tensor = F.interpolate( + input_tensor, scale_factor=2, mode="bilinear", align_corners=True + ) + # If a skip connection is provided, crop it to match the size of input_tensor and concatenate them along the channel dimension. + if skip is not None: + skip = spec_utils.crop_center( + skip, input_tensor + ) # Crop skip_connection to match input_tensor's dimensions. + input_tensor = torch.cat( + [input_tensor, skip], dim=1 + ) # Concatenate input_tensor and skip_connection along the channel dimension. + + # Pass the concatenated tensor (or just input_tensor if no skip_connection is provided) through the convolutional layer. + output_tensor = self.conv(input_tensor) + + # If dropout is enabled, apply it to the output of the convolutional layer. + if self.dropout is not None: + output_tensor = self.dropout(output_tensor) + + # Return the final output tensor. + return output_tensor + + +class ASPPModule(nn.Module): + """ + Atrous Spatial Pyramid Pooling (ASPP) Module is designed for capturing multi-scale context by applying + atrous convolution at multiple rates. This is particularly useful in segmentation tasks where capturing + objects at various scales is beneficial. The module applies several parallel dilated convolutions with + different dilation rates to the input feature map, allowing it to efficiently capture information at + multiple scales. + + Attributes: + conv1 (nn.Sequential): Applies adaptive average pooling followed by a 1x1 convolution. + nn_architecture (int): Identifier for the neural network architecture being used. + six_layer (list): List containing architecture identifiers that require six layers. + seven_layer (list): List containing architecture identifiers that require seven layers. + conv2-conv7 (nn.Module): Convolutional layers with varying dilation rates for multi-scale feature extraction. + bottleneck (nn.Sequential): A 1x1 convolutional layer that combines all features followed by dropout for regularization. + """ + + def __init__(self, nn_architecture, nin, nout, dilations=(4, 8, 16), activ=nn.ReLU): + """ + Initializes the ASPP module with specified parameters. + + Args: + nn_architecture (int): Identifier for the neural network architecture. + input_channels (int): Number of input channels. + output_channels (int): Number of output channels. + dilations (tuple): Tuple of dilation rates for the atrous convolutions. + activation (callable): Activation function to use after convolutional layers. + """ + super(ASPPModule, self).__init__() + + # Adaptive average pooling reduces the spatial dimensions to 1x1, focusing on global context, + # followed by a 1x1 convolution to project back to the desired channel dimension. + self.conv1 = nn.Sequential( + nn.AdaptiveAvgPool2d((1, None)), + Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ), + ) + + self.nn_architecture = nn_architecture + # Architecture identifiers for models requiring additional layers. + self.six_layer = [129605] + self.seven_layer = [537238, 537227, 33966] + + # Extra convolutional layer used for six and seven layer configurations. + extra_conv = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[2], dilations[2], activ=activ + ) + + # Standard 1x1 convolution for channel reduction. + self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ) + + # Separable convolutions with different dilation rates for multi-scale feature extraction. + self.conv3 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[0], dilations[0], activ=activ + ) + self.conv4 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[1], dilations[1], activ=activ + ) + self.conv5 = SeperableConv2DBNActiv( + nin, nin, 3, 1, dilations[2], dilations[2], activ=activ + ) + + # Depending on the architecture, include the extra convolutional layers. + if self.nn_architecture in self.six_layer: + self.conv6 = extra_conv + nin_x = 6 + elif self.nn_architecture in self.seven_layer: + self.conv6 = extra_conv + self.conv7 = extra_conv + nin_x = 7 + else: + nin_x = 5 + + # Bottleneck layer combines all the multi-scale features into the desired number of output channels. + self.bottleneck = nn.Sequential( + Conv2DBNActiv(nin * nin_x, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1) + ) + + def forward(self, input_tensor): + """ + Forward pass of the ASPP module. + + Args: + input_tensor (Tensor): Input tensor. + + Returns: + Tensor: Output tensor after applying ASPP. + """ + _, _, h, w = input_tensor.size() + + # Apply the first convolutional sequence and upsample to the original resolution. + feat1 = F.interpolate( + self.conv1(input_tensor), size=(h, w), mode="bilinear", align_corners=True + ) + + # Apply the remaining convolutions directly on the input. + feat2 = self.conv2(input_tensor) + feat3 = self.conv3(input_tensor) + feat4 = self.conv4(input_tensor) + feat5 = self.conv5(input_tensor) + + # Concatenate features from all layers. Depending on the architecture, include the extra features. + if self.nn_architecture in self.six_layer: + feat6 = self.conv6(input_tensor) + out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6), dim=1) + elif self.nn_architecture in self.seven_layer: + feat6 = self.conv6(input_tensor) + feat7 = self.conv7(input_tensor) + out = torch.cat((feat1, feat2, feat3, feat4, feat5, feat6, feat7), dim=1) + else: + out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1) + + # Apply the bottleneck layer to combine and reduce the channel dimensions. + bottleneck_output = self.bottleneck(out) + return bottleneck_output diff --git a/mvsepless/models/vr_arch/layers_new.py b/mvsepless/models/vr_arch/layers_new.py new file mode 100644 index 0000000000000000000000000000000000000000..b245ee25234e4058a0989a1a864aa0d8539d7f63 --- /dev/null +++ b/mvsepless/models/vr_arch/layers_new.py @@ -0,0 +1,180 @@ +import torch +from torch import nn +import torch.nn.functional as F + +from . import spec_utils + + +class Conv2DBNActiv(nn.Module): + """ + Conv2DBNActiv Class: + This class implements a convolutional layer followed by batch normalization and an activation function. + It is a fundamental building block for constructing neural networks, especially useful in image and audio processing tasks. + The class encapsulates the pattern of applying a convolution, normalizing the output, and then applying a non-linear activation. + """ + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU): + super(Conv2DBNActiv, self).__init__() + + # Sequential model combining Conv2D, BatchNorm, and activation function into a single module + self.conv = nn.Sequential( + nn.Conv2d( + nin, + nout, + kernel_size=ksize, + stride=stride, + padding=pad, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(nout), + activ(), + ) + + def __call__(self, input_tensor): + # Forward pass through the sequential model + return self.conv(input_tensor) + + +class Encoder(nn.Module): + """ + Encoder Class: + This class defines an encoder module typically used in autoencoder architectures. + It consists of two convolutional layers, each followed by batch normalization and an activation function. + """ + + def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU): + super(Encoder, self).__init__() + + # First convolutional layer of the encoder + self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ) + # Second convolutional layer of the encoder + self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ) + + def __call__(self, input_tensor): + # Applying the first and then the second convolutional layers + hidden = self.conv1(input_tensor) + hidden = self.conv2(hidden) + + return hidden + + +class Decoder(nn.Module): + """ + Decoder Class: + This class defines a decoder module, which is the counterpart of the Encoder class in autoencoder architectures. + It applies a convolutional layer followed by batch normalization and an activation function, with an optional dropout layer for regularization. + """ + + def __init__( + self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False + ): + super(Decoder, self).__init__() + # Convolutional layer with optional dropout for regularization + self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ) + # self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ) + self.dropout = nn.Dropout2d(0.1) if dropout else None + + def __call__(self, input_tensor, skip=None): + # Forward pass through the convolutional layer and optional dropout + input_tensor = F.interpolate( + input_tensor, scale_factor=2, mode="bilinear", align_corners=True + ) + + if skip is not None: + skip = spec_utils.crop_center(skip, input_tensor) + input_tensor = torch.cat([input_tensor, skip], dim=1) + + hidden = self.conv1(input_tensor) + # hidden = self.conv2(hidden) + + if self.dropout is not None: + hidden = self.dropout(hidden) + + return hidden + + +class ASPPModule(nn.Module): + """ + ASPPModule Class: + This class implements the Atrous Spatial Pyramid Pooling (ASPP) module, which is useful for semantic image segmentation tasks. + It captures multi-scale contextual information by applying convolutions at multiple dilation rates. + """ + + def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False): + super(ASPPModule, self).__init__() + + # Global context convolution captures the overall context + self.conv1 = nn.Sequential( + nn.AdaptiveAvgPool2d((1, None)), + Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ), + ) + self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ) + self.conv3 = Conv2DBNActiv( + nin, nout, 3, 1, dilations[0], dilations[0], activ=activ + ) + self.conv4 = Conv2DBNActiv( + nin, nout, 3, 1, dilations[1], dilations[1], activ=activ + ) + self.conv5 = Conv2DBNActiv( + nin, nout, 3, 1, dilations[2], dilations[2], activ=activ + ) + self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ) + self.dropout = nn.Dropout2d(0.1) if dropout else None + + def forward(self, input_tensor): + _, _, h, w = input_tensor.size() + + # Upsample global context to match input size and combine with local and multi-scale features + feat1 = F.interpolate( + self.conv1(input_tensor), size=(h, w), mode="bilinear", align_corners=True + ) + feat2 = self.conv2(input_tensor) + feat3 = self.conv3(input_tensor) + feat4 = self.conv4(input_tensor) + feat5 = self.conv5(input_tensor) + out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1) + out = self.bottleneck(out) + + if self.dropout is not None: + out = self.dropout(out) + + return out + + +class LSTMModule(nn.Module): + """ + LSTMModule Class: + This class defines a module that combines convolutional feature extraction with a bidirectional LSTM for sequence modeling. + It is useful for tasks that require understanding temporal dynamics in data, such as speech and audio processing. + """ + + def __init__(self, nin_conv, nin_lstm, nout_lstm): + super(LSTMModule, self).__init__() + # Convolutional layer for initial feature extraction + self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0) + + # Bidirectional LSTM for capturing temporal dynamics + self.lstm = nn.LSTM( + input_size=nin_lstm, hidden_size=nout_lstm // 2, bidirectional=True + ) + + # Dense layer for output dimensionality matching + self.dense = nn.Sequential( + nn.Linear(nout_lstm, nin_lstm), nn.BatchNorm1d(nin_lstm), nn.ReLU() + ) + + def forward(self, input_tensor): + N, _, nbins, nframes = input_tensor.size() + + # Extract features and prepare for LSTM + hidden = self.conv(input_tensor)[:, 0] # N, nbins, nframes + hidden = hidden.permute(2, 0, 1) # nframes, N, nbins + hidden, _ = self.lstm(hidden) + + # Apply dense layer and reshape to match expected output format + hidden = self.dense(hidden.reshape(-1, hidden.size()[-1])) # nframes * N, nbins + hidden = hidden.reshape(nframes, N, 1, nbins) + hidden = hidden.permute(1, 2, 3, 0) + + return hidden diff --git a/mvsepless/models/vr_arch/model_param_init.py b/mvsepless/models/vr_arch/model_param_init.py new file mode 100644 index 0000000000000000000000000000000000000000..1cfde44a8c958e989854e71613bfecc8c139bdfd --- /dev/null +++ b/mvsepless/models/vr_arch/model_param_init.py @@ -0,0 +1,76 @@ +import json + +default_param = {} +default_param["bins"] = -1 +default_param["unstable_bins"] = -1 # training only +default_param["stable_bins"] = -1 # training only +default_param["sr"] = 44100 +default_param["pre_filter_start"] = -1 +default_param["pre_filter_stop"] = -1 +default_param["band"] = {} + +N_BINS = "n_bins" + + +def int_keys(d): + """ + Converts string keys that represent integers into actual integer keys in a list. + + This function is particularly useful when dealing with JSON data that may represent + integer keys as strings due to the nature of JSON encoding. By converting these keys + back to integers, it ensures that the data can be used in a manner consistent with + its original representation, especially in contexts where the distinction between + string and integer keys is important. + + Args: + input_list (list of tuples): A list of (key, value) pairs where keys are strings + that may represent integers. + + Returns: + dict: A dictionary with keys converted to integers where applicable. + """ + # Initialize an empty dictionary to hold the converted key-value pairs. + result_dict = {} + # Iterate through each key-value pair in the input list. + for key, value in d: + # Check if the key is a digit (i.e., represents an integer). + if key.isdigit(): + # Convert the key from a string to an integer. + key = int(key) + result_dict[key] = value + return result_dict + + +class ModelParameters(object): + """ + A class to manage model parameters, including loading from a configuration file. + + Attributes: + param (dict): Dictionary holding all parameters for the model. + """ + + def __init__(self, model_params=None): + """ + Initializes the ModelParameters object by loading parameters from a JSON configuration file. + + Args: + config_path (str): Path to the JSON configuration file. + """ + + self.param = model_params + + # Ensure certain parameters are set to False if not specified in the configuration. + for k in [ + "mid_side", + "mid_side_b", + "mid_side_b2", + "stereo_w", + "stereo_n", + "reverse", + ]: + if not k in self.param: + self.param[k] = False + + # If 'n_bins' is specified in the parameters, it's used as the value for 'bins'. + if N_BINS in self.param: + self.param["bins"] = self.param[N_BINS] diff --git a/mvsepless/models/vr_arch/nets.py b/mvsepless/models/vr_arch/nets.py new file mode 100644 index 0000000000000000000000000000000000000000..53fbf64c5d2b8dcae60ac75017a4144030c0ec81 --- /dev/null +++ b/mvsepless/models/vr_arch/nets.py @@ -0,0 +1,223 @@ +import torch +from torch import nn +import torch.nn.functional as F + +from . import layers + + +class BaseASPPNet(nn.Module): + """ + BaseASPPNet Class: + This class defines the base architecture for an Atrous Spatial Pyramid Pooling (ASPP) network. + It is designed to extract features from input data at multiple scales by using dilated convolutions. + This is particularly useful for tasks that benefit from understanding context at different resolutions, + such as semantic segmentation. The network consists of a series of encoder layers for downsampling and feature extraction, + followed by an ASPP module for multi-scale feature extraction, and finally a series of decoder layers for upsampling. + """ + + def __init__(self, nn_architecture, nin, ch, dilations=(4, 8, 16)): + super(BaseASPPNet, self).__init__() + self.nn_architecture = nn_architecture + + # Encoder layers progressively increase the number of channels while reducing spatial dimensions. + self.enc1 = layers.Encoder(nin, ch, 3, 2, 1) + self.enc2 = layers.Encoder(ch, ch * 2, 3, 2, 1) + self.enc3 = layers.Encoder(ch * 2, ch * 4, 3, 2, 1) + self.enc4 = layers.Encoder(ch * 4, ch * 8, 3, 2, 1) + + # Depending on the network architecture, an additional encoder layer and a specific ASPP module are initialized. + if self.nn_architecture == 129605: + self.enc5 = layers.Encoder(ch * 8, ch * 16, 3, 2, 1) + self.aspp = layers.ASPPModule(nn_architecture, ch * 16, ch * 32, dilations) + self.dec5 = layers.Decoder(ch * (16 + 32), ch * 16, 3, 1, 1) + else: + self.aspp = layers.ASPPModule(nn_architecture, ch * 8, ch * 16, dilations) + + # Decoder layers progressively decrease the number of channels while increasing spatial dimensions. + self.dec4 = layers.Decoder(ch * (8 + 16), ch * 8, 3, 1, 1) + self.dec3 = layers.Decoder(ch * (4 + 8), ch * 4, 3, 1, 1) + self.dec2 = layers.Decoder(ch * (2 + 4), ch * 2, 3, 1, 1) + self.dec1 = layers.Decoder(ch * (1 + 2), ch, 3, 1, 1) + + def __call__(self, input_tensor): + # The input tensor is passed through a series of encoder layers. + hidden_state, encoder_output1 = self.enc1(input_tensor) + hidden_state, encoder_output2 = self.enc2(hidden_state) + hidden_state, encoder_output3 = self.enc3(hidden_state) + hidden_state, encoder_output4 = self.enc4(hidden_state) + + # Depending on the network architecture, the hidden state is processed by an additional encoder layer and the ASPP module. + if self.nn_architecture == 129605: + hidden_state, encoder_output5 = self.enc5(hidden_state) + hidden_state = self.aspp(hidden_state) + # The decoder layers use skip connections from the encoder layers for better feature integration. + hidden_state = self.dec5(hidden_state, encoder_output5) + else: + hidden_state = self.aspp(hidden_state) + + # The hidden state is further processed by the decoder layers, using skip connections for feature integration. + hidden_state = self.dec4(hidden_state, encoder_output4) + hidden_state = self.dec3(hidden_state, encoder_output3) + hidden_state = self.dec2(hidden_state, encoder_output2) + hidden_state = self.dec1(hidden_state, encoder_output1) + + return hidden_state + + +def determine_model_capacity(n_fft_bins, nn_architecture): + """ + The determine_model_capacity function is designed to select the appropriate model configuration + based on the frequency bins and network architecture. It maps specific architectures to predefined + model capacities, which dictate the structure and parameters of the CascadedASPPNet model. + """ + + # Predefined model architectures categorized by their precision level. + sp_model_arch = [31191, 33966, 129605] + hp_model_arch = [123821, 123812] + hp2_model_arch = [537238, 537227] + + # Mapping network architectures to their corresponding model capacity data. + if nn_architecture in sp_model_arch: + model_capacity_data = [ + (2, 16), + (2, 16), + (18, 8, 1, 1, 0), + (8, 16), + (34, 16, 1, 1, 0), + (16, 32), + (32, 2, 1), + (16, 2, 1), + (16, 2, 1), + ] + + if nn_architecture in hp_model_arch: + model_capacity_data = [ + (2, 32), + (2, 32), + (34, 16, 1, 1, 0), + (16, 32), + (66, 32, 1, 1, 0), + (32, 64), + (64, 2, 1), + (32, 2, 1), + (32, 2, 1), + ] + + if nn_architecture in hp2_model_arch: + model_capacity_data = [ + (2, 64), + (2, 64), + (66, 32, 1, 1, 0), + (32, 64), + (130, 64, 1, 1, 0), + (64, 128), + (128, 2, 1), + (64, 2, 1), + (64, 2, 1), + ] + + # Initializing the CascadedASPPNet model with the selected model capacity data. + cascaded = CascadedASPPNet + model = cascaded(n_fft_bins, model_capacity_data, nn_architecture) + + return model + + +class CascadedASPPNet(nn.Module): + """ + CascadedASPPNet Class: + This class implements a cascaded version of the ASPP network, designed for processing audio signals + for tasks such as vocal removal. It consists of multiple stages, each with its own ASPP network, + to process different frequency bands of the input signal. This allows the model to effectively + handle the full spectrum of audio frequencies by focusing on different frequency bands separately. + """ + + def __init__(self, n_fft, model_capacity_data, nn_architecture): + super(CascadedASPPNet, self).__init__() + # The first stage processes the low and high frequency bands separately. + self.stg1_low_band_net = BaseASPPNet(nn_architecture, *model_capacity_data[0]) + self.stg1_high_band_net = BaseASPPNet(nn_architecture, *model_capacity_data[1]) + + # Bridge layers connect different stages of the network. + self.stg2_bridge = layers.Conv2DBNActiv(*model_capacity_data[2]) + self.stg2_full_band_net = BaseASPPNet(nn_architecture, *model_capacity_data[3]) + + self.stg3_bridge = layers.Conv2DBNActiv(*model_capacity_data[4]) + self.stg3_full_band_net = BaseASPPNet(nn_architecture, *model_capacity_data[5]) + + # Output layers for the final mask prediction and auxiliary outputs. + self.out = nn.Conv2d(*model_capacity_data[6], bias=False) + self.aux1_out = nn.Conv2d(*model_capacity_data[7], bias=False) + self.aux2_out = nn.Conv2d(*model_capacity_data[8], bias=False) + + # Parameters for handling the frequency bins of the input signal. + self.max_bin = n_fft // 2 + self.output_bin = n_fft // 2 + 1 + + self.offset = 128 + + def forward(self, input_tensor): + # The forward pass processes the input tensor through each stage of the network, + # combining the outputs of different frequency bands and stages to produce the final mask. + mix = input_tensor.detach() + input_tensor = input_tensor.clone() + + # Preparing the input tensor by selecting the mainput_tensorimum frequency bin. + input_tensor = input_tensor[:, :, : self.max_bin] + + # Processing the low and high frequency bands separately in the first stage. + bandwidth = input_tensor.size()[2] // 2 + aux1 = torch.cat( + [ + self.stg1_low_band_net(input_tensor[:, :, :bandwidth]), + self.stg1_high_band_net(input_tensor[:, :, bandwidth:]), + ], + dim=2, + ) + + # Combining the outputs of the first stage and passing through the second stage. + hidden_state = torch.cat([input_tensor, aux1], dim=1) + aux2 = self.stg2_full_band_net(self.stg2_bridge(hidden_state)) + + # Further processing the combined outputs through the third stage. + hidden_state = torch.cat([input_tensor, aux1, aux2], dim=1) + hidden_state = self.stg3_full_band_net(self.stg3_bridge(hidden_state)) + + # Applying the final output layer to produce the mask. + mask = torch.sigmoid(self.out(hidden_state)) + + # Padding the mask to match the output frequency bin size. + mask = F.pad( + input=mask, + pad=(0, 0, 0, self.output_bin - mask.size()[2]), + mode="replicate", + ) + + # During training, auxiliary outputs are also produced and padded accordingly. + if self.training: + aux1 = torch.sigmoid(self.aux1_out(aux1)) + aux1 = F.pad( + input=aux1, + pad=(0, 0, 0, self.output_bin - aux1.size()[2]), + mode="replicate", + ) + aux2 = torch.sigmoid(self.aux2_out(aux2)) + aux2 = F.pad( + input=aux2, + pad=(0, 0, 0, self.output_bin - aux2.size()[2]), + mode="replicate", + ) + return mask * mix, aux1 * mix, aux2 * mix + else: + return mask # * mix + + def predict_mask(self, input_tensor): + # This method predicts the mask for the input tensor by calling the forward method + # and applying any necessary padding adjustments. + mask = self.forward(input_tensor) + + # Adjusting the mask by removing padding offsets if present. + if self.offset > 0: + mask = mask[:, :, :, self.offset : -self.offset] + + return mask diff --git a/mvsepless/models/vr_arch/nets_new.py b/mvsepless/models/vr_arch/nets_new.py new file mode 100644 index 0000000000000000000000000000000000000000..3f20cebac69542d085211d4ca02b6822a08ff34e --- /dev/null +++ b/mvsepless/models/vr_arch/nets_new.py @@ -0,0 +1,182 @@ +import torch +from torch import nn +import torch.nn.functional as F +from . import layers_new as layers + + +class BaseNet(nn.Module): + """ + BaseNet Class: + This class defines the base network architecture for vocal removal. It includes a series of encoders for feature extraction, + an ASPP module for capturing multi-scale context, and a series of decoders for reconstructing the output. Additionally, + it incorporates an LSTM module for capturing temporal dependencies. + """ + + def __init__( + self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6)) + ): + super(BaseNet, self).__init__() + # Initialize the encoder layers with increasing output channels for hierarchical feature extraction. + self.enc1 = layers.Conv2DBNActiv(nin, nout, 3, 1, 1) + self.enc2 = layers.Encoder(nout, nout * 2, 3, 2, 1) + self.enc3 = layers.Encoder(nout * 2, nout * 4, 3, 2, 1) + self.enc4 = layers.Encoder(nout * 4, nout * 6, 3, 2, 1) + self.enc5 = layers.Encoder(nout * 6, nout * 8, 3, 2, 1) + + # ASPP module for capturing multi-scale features with different dilation rates. + self.aspp = layers.ASPPModule(nout * 8, nout * 8, dilations, dropout=True) + + # Decoder layers for upscaling and merging features from different levels of the encoder and ASPP module. + self.dec4 = layers.Decoder(nout * (6 + 8), nout * 6, 3, 1, 1) + self.dec3 = layers.Decoder(nout * (4 + 6), nout * 4, 3, 1, 1) + self.dec2 = layers.Decoder(nout * (2 + 4), nout * 2, 3, 1, 1) + + # LSTM module for capturing temporal dependencies in the sequence of features. + self.lstm_dec2 = layers.LSTMModule(nout * 2, nin_lstm, nout_lstm) + self.dec1 = layers.Decoder(nout * (1 + 2) + 1, nout * 1, 3, 1, 1) + + def __call__(self, input_tensor): + # Sequentially pass the input through the encoder layers. + encoded1 = self.enc1(input_tensor) + encoded2 = self.enc2(encoded1) + encoded3 = self.enc3(encoded2) + encoded4 = self.enc4(encoded3) + encoded5 = self.enc5(encoded4) + + # Pass the deepest encoder output through the ASPP module. + bottleneck = self.aspp(encoded5) + + # Sequentially upscale and merge the features using the decoder layers. + bottleneck = self.dec4(bottleneck, encoded4) + bottleneck = self.dec3(bottleneck, encoded3) + bottleneck = self.dec2(bottleneck, encoded2) + # Concatenate the LSTM module output for temporal feature enhancement. + bottleneck = torch.cat([bottleneck, self.lstm_dec2(bottleneck)], dim=1) + bottleneck = self.dec1(bottleneck, encoded1) + + return bottleneck + + +class CascadedNet(nn.Module): + """ + CascadedNet Class: + This class defines a cascaded network architecture that processes input in multiple stages, each stage focusing on different frequency bands. + It utilizes the BaseNet for processing, and combines outputs from different stages to produce the final mask for vocal removal. + """ + + def __init__(self, n_fft, nn_arch_size=51000, nout=32, nout_lstm=128): + super(CascadedNet, self).__init__() + # Calculate frequency bins based on FFT size. + self.max_bin = n_fft // 2 + self.output_bin = n_fft // 2 + 1 + self.nin_lstm = self.max_bin // 2 + self.offset = 64 + # Adjust output channels based on the architecture size. + nout = 64 if nn_arch_size == 218409 else nout + + # print(nout, nout_lstm, n_fft) + + # Initialize the network stages, each focusing on different frequency bands and progressively refining the output. + self.stg1_low_band_net = nn.Sequential( + BaseNet(2, nout // 2, self.nin_lstm // 2, nout_lstm), + layers.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0), + ) + self.stg1_high_band_net = BaseNet( + 2, nout // 4, self.nin_lstm // 2, nout_lstm // 2 + ) + + self.stg2_low_band_net = nn.Sequential( + BaseNet(nout // 4 + 2, nout, self.nin_lstm // 2, nout_lstm), + layers.Conv2DBNActiv(nout, nout // 2, 1, 1, 0), + ) + self.stg2_high_band_net = BaseNet( + nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2 + ) + + self.stg3_full_band_net = BaseNet( + 3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm + ) + + # Output layer for generating the final mask. + self.out = nn.Conv2d(nout, 2, 1, bias=False) + # Auxiliary output layer for intermediate supervision during training. + self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False) + + def forward(self, input_tensor): + # Preprocess input tensor to match the maximum frequency bin. + input_tensor = input_tensor[:, :, : self.max_bin] + + # Split the input into low and high frequency bands. + bandw = input_tensor.size()[2] // 2 + l1_in = input_tensor[:, :, :bandw] + h1_in = input_tensor[:, :, bandw:] + + # Process each band through the first stage networks. + l1 = self.stg1_low_band_net(l1_in) + h1 = self.stg1_high_band_net(h1_in) + + # Combine the outputs for auxiliary supervision. + aux1 = torch.cat([l1, h1], dim=2) + + # Prepare inputs for the second stage by concatenating the original and processed bands. + l2_in = torch.cat([l1_in, l1], dim=1) + h2_in = torch.cat([h1_in, h1], dim=1) + + # Process through the second stage networks. + l2 = self.stg2_low_band_net(l2_in) + h2 = self.stg2_high_band_net(h2_in) + + # Combine the outputs for auxiliary supervision. + aux2 = torch.cat([l2, h2], dim=2) + + # Prepare input for the third stage by concatenating all previous outputs with the original input. + f3_in = torch.cat([input_tensor, aux1, aux2], dim=1) + + # Process through the third stage network. + f3 = self.stg3_full_band_net(f3_in) + + # Apply the output layer to generate the final mask and apply sigmoid for normalization. + mask = torch.sigmoid(self.out(f3)) + + # Pad the mask to match the output frequency bin size. + mask = F.pad( + input=mask, + pad=(0, 0, 0, self.output_bin - mask.size()[2]), + mode="replicate", + ) + + # During training, generate and pad the auxiliary output for additional supervision. + if self.training: + aux = torch.cat([aux1, aux2], dim=1) + aux = torch.sigmoid(self.aux_out(aux)) + aux = F.pad( + input=aux, + pad=(0, 0, 0, self.output_bin - aux.size()[2]), + mode="replicate", + ) + return mask, aux + else: + return mask + + # Method for predicting the mask given an input tensor. + def predict_mask(self, input_tensor): + mask = self.forward(input_tensor) + + # If an offset is specified, crop the mask to remove edge artifacts. + if self.offset > 0: + mask = mask[:, :, :, self.offset : -self.offset] + assert mask.size()[3] > 0 + + return mask + + # Method for applying the predicted mask to the input tensor to obtain the predicted magnitude. + def predict(self, input_tensor): + mask = self.forward(input_tensor) + pred_mag = input_tensor * mask + + # If an offset is specified, crop the predicted magnitude to remove edge artifacts. + if self.offset > 0: + pred_mag = pred_mag[:, :, :, self.offset : -self.offset] + assert pred_mag.size()[3] > 0 + + return pred_mag diff --git a/mvsepless/models/vr_arch/spec_utils.py b/mvsepless/models/vr_arch/spec_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3cdf421b38fb437cc412780987c19c257b71e5ce --- /dev/null +++ b/mvsepless/models/vr_arch/spec_utils.py @@ -0,0 +1,1494 @@ +import audioread +import librosa +import numpy as np +import soundfile as sf +import math +import platform +import traceback +from scipy.signal import correlate, hilbert +import io + +OPERATING_SYSTEM = platform.system() +SYSTEM_ARCH = platform.platform() +SYSTEM_PROC = platform.processor() +ARM = "arm" + +AUTO_PHASE = "Automatic" +POSITIVE_PHASE = "Positive Phase" +NEGATIVE_PHASE = "Negative Phase" +NONE_P = ("None",) +LOW_P = ("Shifts: Low",) +MED_P = ("Shifts: Medium",) +HIGH_P = ("Shifts: High",) +VHIGH_P = "Shifts: Very High" +MAXIMUM_P = "Shifts: Maximum" + +progress_value = 0 +last_update_time = 0 +is_macos = False + + +if OPERATING_SYSTEM == "Darwin": + wav_resolution = ( + "polyphase" if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else "sinc_fastest" + ) + wav_resolution_float_resampling = ( + "kaiser_best" if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else wav_resolution + ) + is_macos = True +else: + wav_resolution = "sinc_fastest" + wav_resolution_float_resampling = wav_resolution + +MAX_SPEC = "Max Spec" +MIN_SPEC = "Min Spec" +LIN_ENSE = "Linear Ensemble" + +MAX_WAV = MAX_SPEC +MIN_WAV = MIN_SPEC + +AVERAGE = "Average" + + +def crop_center(h1, h2): + """ + This function crops the center of the first input tensor to match the size of the second input tensor. + It is used to ensure that the two tensors have the same size in the time dimension. + """ + h1_shape = h1.size() + h2_shape = h2.size() + + # If the time dimensions are already equal, return the first tensor as is + if h1_shape[3] == h2_shape[3]: + return h1 + # If the time dimension of the first tensor is smaller, raise an error + elif h1_shape[3] < h2_shape[3]: + raise ValueError("h1_shape[3] must be greater than h2_shape[3]") + + # Calculate the start and end indices for cropping + s_time = (h1_shape[3] - h2_shape[3]) // 2 + e_time = s_time + h2_shape[3] + # Crop the first tensor + h1 = h1[:, :, :, s_time:e_time] + + return h1 + + +def preprocess(X_spec): + """ + This function preprocesses a spectrogram by separating it into magnitude and phase components. + This is a common preprocessing step in audio processing tasks. + """ + X_mag = np.abs(X_spec) + X_phase = np.angle(X_spec) + + return X_mag, X_phase + + +def make_padding(width, cropsize, offset): + """ + This function calculates the padding needed to make the width of an image divisible by the crop size. + It is used in the process of splitting an image into smaller patches. + """ + left = offset + roi_size = cropsize - offset * 2 + if roi_size == 0: + roi_size = cropsize + right = roi_size - (width % roi_size) + left + + return left, right, roi_size + + +def normalize(wave, max_peak=1.0, min_peak=None): + """Normalize (or amplify) audio waveform to a specified peak value. + + Args: + wave (array-like): Audio waveform. + max_peak (float): Maximum peak value for normalization. + + Returns: + array-like: Normalized or original waveform. + """ + maxv = np.abs(wave).max() + if maxv > max_peak: + wave *= max_peak / maxv + elif min_peak is not None and maxv < min_peak: + wave *= min_peak / maxv + + return wave + + +def auto_transpose(audio_array: np.ndarray): + """ + Ensure that the audio array is in the (channels, samples) format. + + Parameters: + audio_array (ndarray): Input audio array. + + Returns: + ndarray: Transposed audio array if necessary. + """ + + # If the second dimension is 2 (indicating stereo channels), transpose the array + if audio_array.shape[1] == 2: + return audio_array.T + return audio_array + + +def write_array_to_mem(audio_data, subtype): + if isinstance(audio_data, np.ndarray): + audio_buffer = io.BytesIO() + sf.write(audio_buffer, audio_data, 44100, subtype=subtype, format="WAV") + audio_buffer.seek(0) + return audio_buffer + else: + return audio_data + + +def spectrogram_to_image(spec, mode="magnitude"): + if mode == "magnitude": + if np.iscomplexobj(spec): + y = np.abs(spec) + else: + y = spec + y = np.log10(y**2 + 1e-8) + elif mode == "phase": + if np.iscomplexobj(spec): + y = np.angle(spec) + else: + y = spec + + y -= y.min() + y *= 255 / y.max() + img = np.uint8(y) + + if y.ndim == 3: + img = img.transpose(1, 2, 0) + img = np.concatenate([np.max(img, axis=2, keepdims=True), img], axis=2) + + return img + + +def reduce_vocal_aggressively(X, y, softmask): + v = X - y + y_mag_tmp = np.abs(y) + v_mag_tmp = np.abs(v) + + v_mask = v_mag_tmp > y_mag_tmp + y_mag = np.clip(y_mag_tmp - v_mag_tmp * v_mask * softmask, 0, np.inf) + + return y_mag * np.exp(1.0j * np.angle(y)) + + +def merge_artifacts(y_mask, thres=0.01, min_range=64, fade_size=32): + mask = y_mask + + try: + if min_range < fade_size * 2: + raise ValueError("min_range must be >= fade_size * 2") + + idx = np.where(y_mask.min(axis=(0, 1)) > thres)[0] + start_idx = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0]) + end_idx = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1]) + artifact_idx = np.where(end_idx - start_idx > min_range)[0] + weight = np.zeros_like(y_mask) + if len(artifact_idx) > 0: + start_idx = start_idx[artifact_idx] + end_idx = end_idx[artifact_idx] + old_e = None + for s, e in zip(start_idx, end_idx): + if old_e is not None and s - old_e < fade_size: + s = old_e - fade_size * 2 + + if s != 0: + weight[:, :, s : s + fade_size] = np.linspace(0, 1, fade_size) + else: + s -= fade_size + + if e != y_mask.shape[2]: + weight[:, :, e - fade_size : e] = np.linspace(1, 0, fade_size) + else: + e += fade_size + + weight[:, :, s + fade_size : e - fade_size] = 1 + old_e = e + + v_mask = 1 - y_mask + y_mask += weight * v_mask + + mask = y_mask + except Exception as e: + error_name = f"{type(e).__name__}" + traceback_text = "".join(traceback.format_tb(e.__traceback__)) + message = f'{error_name}: "{e}"\n{traceback_text}"' + print("Post Process Failed: ", message) + + return mask + + +def align_wave_head_and_tail(a, b): + l = min([a[0].size, b[0].size]) + + return a[:l, :l], b[:l, :l] + + +def convert_channels(spec, mp, band): + cc = mp.param["band"][band].get("convert_channels") + + if "mid_side_c" == cc: + spec_left = np.add(spec[0], spec[1] * 0.25) + spec_right = np.subtract(spec[1], spec[0] * 0.25) + elif "mid_side" == cc: + spec_left = np.add(spec[0], spec[1]) / 2 + spec_right = np.subtract(spec[0], spec[1]) + elif "stereo_n" == cc: + spec_left = np.add(spec[0], spec[1] * 0.25) / 0.9375 + spec_right = np.add(spec[1], spec[0] * 0.25) / 0.9375 + else: + return spec + + return np.asfortranarray([spec_left, spec_right]) + + +def combine_spectrograms(specs, mp, is_v51_model=False): + l = min([specs[i].shape[2] for i in specs]) + spec_c = np.zeros(shape=(2, mp.param["bins"] + 1, l), dtype=np.complex64) + offset = 0 + bands_n = len(mp.param["band"]) + + for d in range(1, bands_n + 1): + h = mp.param["band"][d]["crop_stop"] - mp.param["band"][d]["crop_start"] + spec_c[:, offset : offset + h, :l] = specs[d][ + :, mp.param["band"][d]["crop_start"] : mp.param["band"][d]["crop_stop"], :l + ] + offset += h + + if offset > mp.param["bins"]: + raise ValueError("Too much bins") + + # lowpass fiter + + if mp.param["pre_filter_start"] > 0: + if is_v51_model: + spec_c *= get_lp_filter_mask( + spec_c.shape[1], + mp.param["pre_filter_start"], + mp.param["pre_filter_stop"], + ) + else: + if bands_n == 1: + spec_c = fft_lp_filter( + spec_c, mp.param["pre_filter_start"], mp.param["pre_filter_stop"] + ) + else: + gp = 1 + for b in range( + mp.param["pre_filter_start"] + 1, mp.param["pre_filter_stop"] + ): + g = math.pow( + 10, -(b - mp.param["pre_filter_start"]) * (3.5 - gp) / 20.0 + ) + gp = g + spec_c[:, b, :] *= g + + return np.asfortranarray(spec_c) + + +def wave_to_spectrogram(wave, hop_length, n_fft, mp, band, is_v51_model=False): + + if wave.ndim == 1: + wave = np.asfortranarray([wave, wave]) + + if not is_v51_model: + if mp.param["reverse"]: + wave_left = np.flip(np.asfortranarray(wave[0])) + wave_right = np.flip(np.asfortranarray(wave[1])) + elif mp.param["mid_side"]: + wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2) + wave_right = np.asfortranarray(np.subtract(wave[0], wave[1])) + elif mp.param["mid_side_b2"]: + wave_left = np.asfortranarray(np.add(wave[1], wave[0] * 0.5)) + wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * 0.5)) + else: + wave_left = np.asfortranarray(wave[0]) + wave_right = np.asfortranarray(wave[1]) + else: + wave_left = np.asfortranarray(wave[0]) + wave_right = np.asfortranarray(wave[1]) + + spec_left = librosa.stft(wave_left, n_fft=n_fft, hop_length=hop_length) + spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length) + + spec = np.asfortranarray([spec_left, spec_right]) + + if is_v51_model: + spec = convert_channels(spec, mp, band) + + return spec + + +def spectrogram_to_wave(spec, hop_length=1024, mp={}, band=0, is_v51_model=True): + spec_left = np.asfortranarray(spec[0]) + spec_right = np.asfortranarray(spec[1]) + + wave_left = librosa.istft(spec_left, hop_length=hop_length) + wave_right = librosa.istft(spec_right, hop_length=hop_length) + + if is_v51_model: + cc = mp.param["band"][band].get("convert_channels") + if "mid_side_c" == cc: + return np.asfortranarray( + [ + np.subtract(wave_left / 1.0625, wave_right / 4.25), + np.add(wave_right / 1.0625, wave_left / 4.25), + ] + ) + elif "mid_side" == cc: + return np.asfortranarray( + [ + np.add(wave_left, wave_right / 2), + np.subtract(wave_left, wave_right / 2), + ] + ) + elif "stereo_n" == cc: + return np.asfortranarray( + [ + np.subtract(wave_left, wave_right * 0.25), + np.subtract(wave_right, wave_left * 0.25), + ] + ) + else: + if mp.param["reverse"]: + return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)]) + elif mp.param["mid_side"]: + return np.asfortranarray( + [ + np.add(wave_left, wave_right / 2), + np.subtract(wave_left, wave_right / 2), + ] + ) + elif mp.param["mid_side_b2"]: + return np.asfortranarray( + [ + np.add(wave_right / 1.25, 0.4 * wave_left), + np.subtract(wave_left / 1.25, 0.4 * wave_right), + ] + ) + + return np.asfortranarray([wave_left, wave_right]) + + +def cmb_spectrogram_to_wave( + spec_m, mp, extra_bins_h=None, extra_bins=None, is_v51_model=False +): + bands_n = len(mp.param["band"]) + offset = 0 + + for d in range(1, bands_n + 1): + bp = mp.param["band"][d] + spec_s = np.zeros( + shape=(2, bp["n_fft"] // 2 + 1, spec_m.shape[2]), dtype=complex + ) + h = bp["crop_stop"] - bp["crop_start"] + spec_s[:, bp["crop_start"] : bp["crop_stop"], :] = spec_m[ + :, offset : offset + h, : + ] + + offset += h + if d == bands_n: # higher + if extra_bins_h: # if --high_end_process bypass + max_bin = bp["n_fft"] // 2 + spec_s[:, max_bin - extra_bins_h : max_bin, :] = extra_bins[ + :, :extra_bins_h, : + ] + if bp["hpf_start"] > 0: + if is_v51_model: + spec_s *= get_hp_filter_mask( + spec_s.shape[1], bp["hpf_start"], bp["hpf_stop"] - 1 + ) + else: + spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1) + if bands_n == 1: + wave = spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model) + else: + wave = np.add( + wave, spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model) + ) + else: + sr = mp.param["band"][d + 1]["sr"] + if d == 1: # lower + if is_v51_model: + spec_s *= get_lp_filter_mask( + spec_s.shape[1], bp["lpf_start"], bp["lpf_stop"] + ) + else: + spec_s = fft_lp_filter(spec_s, bp["lpf_start"], bp["lpf_stop"]) + + try: + wave = librosa.resample( + spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model), + orig_sr=bp["sr"], + target_sr=sr, + res_type=wav_resolution, + ) + except ValueError as e: + print(f"Error during resampling: {e}") + print( + f"Spec_s shape: {spec_s.shape}, SR: {sr}, Res type: {wav_resolution}" + ) + + else: # mid + if is_v51_model: + spec_s *= get_hp_filter_mask( + spec_s.shape[1], bp["hpf_start"], bp["hpf_stop"] - 1 + ) + spec_s *= get_lp_filter_mask( + spec_s.shape[1], bp["lpf_start"], bp["lpf_stop"] + ) + else: + spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1) + spec_s = fft_lp_filter(spec_s, bp["lpf_start"], bp["lpf_stop"]) + + wave2 = np.add( + wave, spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model) + ) + + try: + wave = librosa.resample( + wave2, orig_sr=bp["sr"], target_sr=sr, res_type=wav_resolution + ) + except ValueError as e: + print(f"Error during resampling: {e}") + print( + f"Spec_s shape: {spec_s.shape}, SR: {sr}, Res type: {wav_resolution}" + ) + + return wave + + +def get_lp_filter_mask(n_bins, bin_start, bin_stop): + mask = np.concatenate( + [ + np.ones((bin_start - 1, 1)), + np.linspace(1, 0, bin_stop - bin_start + 1)[:, None], + np.zeros((n_bins - bin_stop, 1)), + ], + axis=0, + ) + + return mask + + +def get_hp_filter_mask(n_bins, bin_start, bin_stop): + mask = np.concatenate( + [ + np.zeros((bin_stop + 1, 1)), + np.linspace(0, 1, 1 + bin_start - bin_stop)[:, None], + np.ones((n_bins - bin_start - 2, 1)), + ], + axis=0, + ) + + return mask + + +def fft_lp_filter(spec, bin_start, bin_stop): + g = 1.0 + for b in range(bin_start, bin_stop): + g -= 1 / (bin_stop - bin_start) + spec[:, b, :] = g * spec[:, b, :] + + spec[:, bin_stop:, :] *= 0 + + return spec + + +def fft_hp_filter(spec, bin_start, bin_stop): + g = 1.0 + for b in range(bin_start, bin_stop, -1): + g -= 1 / (bin_start - bin_stop) + spec[:, b, :] = g * spec[:, b, :] + + spec[:, 0 : bin_stop + 1, :] *= 0 + + return spec + + +def spectrogram_to_wave_old(spec, hop_length=1024): + if spec.ndim == 2: + wave = librosa.istft(spec, hop_length=hop_length) + elif spec.ndim == 3: + spec_left = np.asfortranarray(spec[0]) + spec_right = np.asfortranarray(spec[1]) + + wave_left = librosa.istft(spec_left, hop_length=hop_length) + wave_right = librosa.istft(spec_right, hop_length=hop_length) + wave = np.asfortranarray([wave_left, wave_right]) + + return wave + + +def wave_to_spectrogram_old(wave, hop_length, n_fft): + wave_left = np.asfortranarray(wave[0]) + wave_right = np.asfortranarray(wave[1]) + + spec_left = librosa.stft(wave_left, n_fft=n_fft, hop_length=hop_length) + spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length) + + spec = np.asfortranarray([spec_left, spec_right]) + + return spec + + +def mirroring(a, spec_m, input_high_end, mp): + if "mirroring" == a: + mirror = np.flip( + np.abs( + spec_m[ + :, + mp.param["pre_filter_start"] + - 10 + - input_high_end.shape[1] : mp.param["pre_filter_start"] + - 10, + :, + ] + ), + 1, + ) + mirror = mirror * np.exp(1.0j * np.angle(input_high_end)) + + return np.where( + np.abs(input_high_end) <= np.abs(mirror), input_high_end, mirror + ) + + if "mirroring2" == a: + mirror = np.flip( + np.abs( + spec_m[ + :, + mp.param["pre_filter_start"] + - 10 + - input_high_end.shape[1] : mp.param["pre_filter_start"] + - 10, + :, + ] + ), + 1, + ) + mi = np.multiply(mirror, input_high_end * 1.7) + + return np.where(np.abs(input_high_end) <= np.abs(mi), input_high_end, mi) + + +def adjust_aggr(mask, is_non_accom_stem, aggressiveness): + aggr = aggressiveness["value"] * 2 + + if aggr != 0: + if is_non_accom_stem: + aggr = 1 - aggr + + if np.any(aggr > 10) or np.any(aggr < -10): + print(f"Warning: Extreme aggressiveness values detected: {aggr}") + + aggr = [aggr, aggr] + + if aggressiveness["aggr_correction"] is not None: + aggr[0] += aggressiveness["aggr_correction"]["left"] + aggr[1] += aggressiveness["aggr_correction"]["right"] + + for ch in range(2): + mask[ch, : aggressiveness["split_bin"]] = np.power( + mask[ch, : aggressiveness["split_bin"]], 1 + aggr[ch] / 3 + ) + mask[ch, aggressiveness["split_bin"] :] = np.power( + mask[ch, aggressiveness["split_bin"] :], 1 + aggr[ch] + ) + + return mask + + +def stft(wave, nfft, hl): + wave_left = np.asfortranarray(wave[0]) + wave_right = np.asfortranarray(wave[1]) + spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl) + spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl) + spec = np.asfortranarray([spec_left, spec_right]) + + return spec + + +def istft(spec, hl): + spec_left = np.asfortranarray(spec[0]) + spec_right = np.asfortranarray(spec[1]) + wave_left = librosa.istft(spec_left, hop_length=hl) + wave_right = librosa.istft(spec_right, hop_length=hl) + wave = np.asfortranarray([wave_left, wave_right]) + + return wave + + +def spec_effects(wave, algorithm="Default", value=None): + if np.isnan(wave).any() or np.isinf(wave).any(): + print( + f"Warning: Detected NaN or infinite values in wave input. Shape: {wave.shape}" + ) + + spec = [stft(wave[0], 2048, 1024), stft(wave[1], 2048, 1024)] + if algorithm == "Min_Mag": + v_spec_m = np.where(np.abs(spec[1]) <= np.abs(spec[0]), spec[1], spec[0]) + wave = istft(v_spec_m, 1024) + elif algorithm == "Max_Mag": + v_spec_m = np.where(np.abs(spec[1]) >= np.abs(spec[0]), spec[1], spec[0]) + wave = istft(v_spec_m, 1024) + elif algorithm == "Default": + wave = (wave[1] * value) + (wave[0] * (1 - value)) + elif algorithm == "Invert_p": + X_mag = np.abs(spec[0]) + y_mag = np.abs(spec[1]) + max_mag = np.where(X_mag >= y_mag, X_mag, y_mag) + v_spec = spec[1] - max_mag * np.exp(1.0j * np.angle(spec[0])) + wave = istft(v_spec, 1024) + + return wave + + +def spectrogram_to_wave_no_mp(spec, n_fft=2048, hop_length=1024): + wave = librosa.istft(spec, n_fft=n_fft, hop_length=hop_length) + + if wave.ndim == 1: + wave = np.asfortranarray([wave, wave]) + + return wave + + +def wave_to_spectrogram_no_mp(wave): + + spec = librosa.stft(wave, n_fft=2048, hop_length=1024) + + if spec.ndim == 1: + spec = np.asfortranarray([spec, spec]) + + return spec + + +def invert_audio(specs, invert_p=True): + + ln = min([specs[0].shape[2], specs[1].shape[2]]) + specs[0] = specs[0][:, :, :ln] + specs[1] = specs[1][:, :, :ln] + + if invert_p: + X_mag = np.abs(specs[0]) + y_mag = np.abs(specs[1]) + max_mag = np.where(X_mag >= y_mag, X_mag, y_mag) + v_spec = specs[1] - max_mag * np.exp(1.0j * np.angle(specs[0])) + else: + specs[1] = reduce_vocal_aggressively(specs[0], specs[1], 0.2) + v_spec = specs[0] - specs[1] + + return v_spec + + +def invert_stem(mixture, stem): + mixture = wave_to_spectrogram_no_mp(mixture) + stem = wave_to_spectrogram_no_mp(stem) + output = spectrogram_to_wave_no_mp(invert_audio([mixture, stem])) + + return -output.T + + +def ensembling(a, inputs, is_wavs=False): + + for i in range(1, len(inputs)): + if i == 1: + input = inputs[0] + + if is_wavs: + ln = min([input.shape[1], inputs[i].shape[1]]) + input = input[:, :ln] + inputs[i] = inputs[i][:, :ln] + else: + ln = min([input.shape[2], inputs[i].shape[2]]) + input = input[:, :, :ln] + inputs[i] = inputs[i][:, :, :ln] + + if MIN_SPEC == a: + input = np.where(np.abs(inputs[i]) <= np.abs(input), inputs[i], input) + if MAX_SPEC == a: + # input = np.array(np.where(np.greater_equal(np.abs(inputs[i]), np.abs(input)), inputs[i], input), dtype=object) + input = np.where(np.abs(inputs[i]) >= np.abs(input), inputs[i], input) + # max_spec = np.array([np.where(np.greater_equal(np.abs(inputs[i]), np.abs(input)), s, specs[0]) for s in specs[1:]], dtype=object)[-1] + + # linear_ensemble + # input = ensemble_wav(inputs, split_size=1) + + return input + + +def ensemble_for_align(waves): + + specs = [] + + for wav in waves: + spec = wave_to_spectrogram_no_mp(wav.T) + specs.append(spec) + + wav_aligned = spectrogram_to_wave_no_mp(ensembling(MIN_SPEC, specs)).T + wav_aligned = match_array_shapes(wav_aligned, waves[1], is_swap=True) + + return wav_aligned + + +def ensemble_inputs( + audio_input, + algorithm, + is_normalization, + wav_type_set, + save_path, + is_wave=False, + is_array=False, +): + + wavs_ = [] + + if algorithm == AVERAGE: + output = average_audio(audio_input) + samplerate = 44100 + else: + specs = [] + + for i in range(len(audio_input)): + wave, samplerate = librosa.load(audio_input[i], mono=False, sr=44100) + wavs_.append(wave) + spec = wave if is_wave else wave_to_spectrogram_no_mp(wave) + specs.append(spec) + + wave_shapes = [w.shape[1] for w in wavs_] + target_shape = wavs_[wave_shapes.index(max(wave_shapes))] + + if is_wave: + output = ensembling(algorithm, specs, is_wavs=True) + else: + output = spectrogram_to_wave_no_mp(ensembling(algorithm, specs)) + + output = to_shape(output, target_shape.shape) + + sf.write( + save_path, + normalize(output.T, is_normalization), + samplerate, + subtype=wav_type_set, + ) + + +def to_shape(x, target_shape): + padding_list = [] + for x_dim, target_dim in zip(x.shape, target_shape): + pad_value = target_dim - x_dim + pad_tuple = (0, pad_value) + padding_list.append(pad_tuple) + + return np.pad(x, tuple(padding_list), mode="constant") + + +def to_shape_minimize(x: np.ndarray, target_shape): + + padding_list = [] + for x_dim, target_dim in zip(x.shape, target_shape): + pad_value = target_dim - x_dim + pad_tuple = (0, pad_value) + padding_list.append(pad_tuple) + + return np.pad(x, tuple(padding_list), mode="constant") + + +def detect_leading_silence(audio, sr, silence_threshold=0.007, frame_length=1024): + """ + Detect silence at the beginning of an audio signal. + + :param audio: np.array, audio signal + :param sr: int, sample rate + :param silence_threshold: float, magnitude threshold below which is considered silence + :param frame_length: int, the number of samples to consider for each check + + :return: float, duration of the leading silence in milliseconds + """ + + if len(audio.shape) == 2: + # If stereo, pick the channel with more energy to determine the silence + channel = np.argmax(np.sum(np.abs(audio), axis=1)) + audio = audio[channel] + + for i in range(0, len(audio), frame_length): + if np.max(np.abs(audio[i : i + frame_length])) > silence_threshold: + return (i / sr) * 1000 + + return (len(audio) / sr) * 1000 + + +def adjust_leading_silence( + target_audio, reference_audio, silence_threshold=0.01, frame_length=1024 +): + """ + Adjust the leading silence of the target_audio to match the leading silence of the reference_audio. + + :param target_audio: np.array, audio signal that will have its silence adjusted + :param reference_audio: np.array, audio signal used as a reference + :param sr: int, sample rate + :param silence_threshold: float, magnitude threshold below which is considered silence + :param frame_length: int, the number of samples to consider for each check + + :return: np.array, target_audio adjusted to have the same leading silence as reference_audio + """ + + def find_silence_end(audio): + if len(audio.shape) == 2: + # If stereo, pick the channel with more energy to determine the silence + channel = np.argmax(np.sum(np.abs(audio), axis=1)) + audio_mono = audio[channel] + else: + audio_mono = audio + + for i in range(0, len(audio_mono), frame_length): + if np.max(np.abs(audio_mono[i : i + frame_length])) > silence_threshold: + return i + return len(audio_mono) + + ref_silence_end = find_silence_end(reference_audio) + target_silence_end = find_silence_end(target_audio) + silence_difference = ref_silence_end - target_silence_end + + try: + ref_silence_end_p = (ref_silence_end / 44100) * 1000 + target_silence_end_p = (target_silence_end / 44100) * 1000 + silence_difference_p = ref_silence_end_p - target_silence_end_p + print("silence_difference: ", silence_difference_p) + except Exception as e: + pass + + if silence_difference > 0: # Add silence to target_audio + if len(target_audio.shape) == 2: # stereo + silence_to_add = np.zeros((target_audio.shape[0], silence_difference)) + else: # mono + silence_to_add = np.zeros(silence_difference) + return np.hstack((silence_to_add, target_audio)) + elif silence_difference < 0: # Remove silence from target_audio + if len(target_audio.shape) == 2: # stereo + return target_audio[:, -silence_difference:] + else: # mono + return target_audio[-silence_difference:] + else: # No adjustment needed + return target_audio + + +def match_array_shapes(array_1: np.ndarray, array_2: np.ndarray, is_swap=False): + + if is_swap: + array_1, array_2 = array_1.T, array_2.T + + # print("before", array_1.shape, array_2.shape) + if array_1.shape[1] > array_2.shape[1]: + array_1 = array_1[:, : array_2.shape[1]] + elif array_1.shape[1] < array_2.shape[1]: + padding = array_2.shape[1] - array_1.shape[1] + array_1 = np.pad(array_1, ((0, 0), (0, padding)), "constant", constant_values=0) + + # print("after", array_1.shape, array_2.shape) + + if is_swap: + array_1, array_2 = array_1.T, array_2.T + + return array_1 + + +def match_mono_array_shapes(array_1: np.ndarray, array_2: np.ndarray): + + if len(array_1) > len(array_2): + array_1 = array_1[: len(array_2)] + elif len(array_1) < len(array_2): + padding = len(array_2) - len(array_1) + array_1 = np.pad(array_1, (0, padding), "constant", constant_values=0) + + return array_1 + + +def change_pitch_semitones(y, sr, semitone_shift): + factor = 2 ** ( + semitone_shift / 12 + ) # Convert semitone shift to factor for resampling + y_pitch_tuned = [] + for y_channel in y: + y_pitch_tuned.append( + librosa.resample( + y_channel, + orig_sr=sr, + target_sr=sr * factor, + res_type=wav_resolution_float_resampling, + ) + ) + y_pitch_tuned = np.array(y_pitch_tuned) + new_sr = sr * factor + return y_pitch_tuned, new_sr + +def average_audio(audio): + + waves = [] + wave_shapes = [] + final_waves = [] + + for i in range(len(audio)): + wave = librosa.load(audio[i], sr=44100, mono=False) + waves.append(wave[0]) + wave_shapes.append(wave[0].shape[1]) + + wave_shapes_index = wave_shapes.index(max(wave_shapes)) + target_shape = waves[wave_shapes_index] + waves.pop(wave_shapes_index) + final_waves.append(target_shape) + + for n_array in waves: + wav_target = to_shape(n_array, target_shape.shape) + final_waves.append(wav_target) + + waves = sum(final_waves) + waves = waves / len(audio) + + return waves + + +def average_dual_sources(wav_1, wav_2, value): + + if wav_1.shape > wav_2.shape: + wav_2 = to_shape(wav_2, wav_1.shape) + if wav_1.shape < wav_2.shape: + wav_1 = to_shape(wav_1, wav_2.shape) + + wave = (wav_1 * value) + (wav_2 * (1 - value)) + + return wave + + +def reshape_sources(wav_1: np.ndarray, wav_2: np.ndarray): + + if wav_1.shape > wav_2.shape: + wav_2 = to_shape(wav_2, wav_1.shape) + if wav_1.shape < wav_2.shape: + ln = min([wav_1.shape[1], wav_2.shape[1]]) + wav_2 = wav_2[:, :ln] + + ln = min([wav_1.shape[1], wav_2.shape[1]]) + wav_1 = wav_1[:, :ln] + wav_2 = wav_2[:, :ln] + + return wav_2 + + +def reshape_sources_ref(wav_1_shape, wav_2: np.ndarray): + + if wav_1_shape > wav_2.shape: + wav_2 = to_shape(wav_2, wav_1_shape) + + return wav_2 + + +def combine_arrarys(audio_sources, is_swap=False): + source = np.zeros_like(max(audio_sources, key=np.size)) + + for v in audio_sources: + v = match_array_shapes(v, source, is_swap=is_swap) + source += v + + return source + + +def combine_audio( + paths: list, audio_file_base=None, wav_type_set="FLOAT", save_format=None +): + + source = combine_arrarys([load_audio(i) for i in paths]) + save_path = f"{audio_file_base}_combined.wav" + sf.write(save_path, source.T, 44100, subtype=wav_type_set) + save_format(save_path) + + +def reduce_mix_bv(inst_source, voc_source, reduction_rate=0.9): + # Reduce the volume + inst_source = inst_source * (1 - reduction_rate) + + mix_reduced = combine_arrarys([inst_source, voc_source], is_swap=True) + + return mix_reduced + + +def organize_inputs(inputs): + input_list = {"target": None, "reference": None, "reverb": None, "inst": None} + + for i in inputs: + if i.endswith("_(Vocals).wav"): + input_list["reference"] = i + elif "_RVC_" in i: + input_list["target"] = i + elif i.endswith("reverbed_stem.wav"): + input_list["reverb"] = i + elif i.endswith("_(Instrumental).wav"): + input_list["inst"] = i + + return input_list + + +def check_if_phase_inverted(wav1, wav2, is_mono=False): + # Load the audio files + if not is_mono: + wav1 = np.mean(wav1, axis=0) + wav2 = np.mean(wav2, axis=0) + + # Compute the correlation + correlation = np.corrcoef(wav1[:1000], wav2[:1000]) + + return correlation[0, 1] < 0 + + +def align_audio( + file1, + file2, + file2_aligned, + file_subtracted, + wav_type_set, + is_save_aligned, + command_Text, + save_format, + align_window: list, + align_intro_val: list, + db_analysis: tuple, + set_progress_bar, + phase_option, + phase_shifts, + is_match_silence, + is_spec_match, +): + + global progress_value + progress_value = 0 + is_mono = False + + def get_diff(a, b): + corr = np.correlate(a, b, "full") + diff = corr.argmax() - (b.shape[0] - 1) + + return diff + + def progress_bar(length): + global progress_value + progress_value += 1 + + if (0.90 / length * progress_value) >= 0.9: + length = progress_value + 1 + + set_progress_bar(0.1, (0.9 / length * progress_value)) + + # read tracks + + if file1.endswith(".mp3") and is_macos: + length1 = rerun_mp3(file1) + wav1, sr1 = librosa.load(file1, duration=length1, sr=44100, mono=False) + else: + wav1, sr1 = librosa.load(file1, sr=44100, mono=False) + + if file2.endswith(".mp3") and is_macos: + length2 = rerun_mp3(file2) + wav2, sr2 = librosa.load(file2, duration=length2, sr=44100, mono=False) + else: + wav2, sr2 = librosa.load(file2, sr=44100, mono=False) + + if wav1.ndim == 1 and wav2.ndim == 1: + is_mono = True + elif wav1.ndim == 1: + wav1 = np.asfortranarray([wav1, wav1]) + elif wav2.ndim == 1: + wav2 = np.asfortranarray([wav2, wav2]) + + # Check if phase is inverted + if phase_option == AUTO_PHASE: + if check_if_phase_inverted(wav1, wav2, is_mono=is_mono): + wav2 = -wav2 + elif phase_option == POSITIVE_PHASE: + wav2 = +wav2 + elif phase_option == NEGATIVE_PHASE: + wav2 = -wav2 + + if is_match_silence: + wav2 = adjust_leading_silence(wav2, wav1) + + wav1_length = int(librosa.get_duration(y=wav1, sr=44100)) + wav2_length = int(librosa.get_duration(y=wav2, sr=44100)) + + if not is_mono: + wav1 = wav1.transpose() + wav2 = wav2.transpose() + + wav2_org = wav2.copy() + + command_Text("Processing files... \n") + seconds_length = min(wav1_length, wav2_length) + + wav2_aligned_sources = [] + + for sec_len in align_intro_val: + # pick a position at 1 second in and get diff + sec_seg = 1 if sec_len == 1 else int(seconds_length // sec_len) + index = sr1 * sec_seg # 1 second in, assuming sr1 = sr2 = 44100 + + if is_mono: + samp1, samp2 = wav1[index : index + sr1], wav2[index : index + sr1] + diff = get_diff(samp1, samp2) + # print(f"Estimated difference: {diff}\n") + else: + index = sr1 * sec_seg # 1 second in, assuming sr1 = sr2 = 44100 + samp1, samp2 = wav1[index : index + sr1, 0], wav2[index : index + sr1, 0] + samp1_r, samp2_r = ( + wav1[index : index + sr1, 1], + wav2[index : index + sr1, 1], + ) + diff, diff_r = get_diff(samp1, samp2), get_diff(samp1_r, samp2_r) + # print(f"Estimated difference Left Channel: {diff}\nEstimated difference Right Channel: {diff_r}\n") + + # make aligned track 2 + if diff > 0: + zeros_to_append = np.zeros(diff) if is_mono else np.zeros((diff, 2)) + wav2_aligned = np.append(zeros_to_append, wav2_org, axis=0) + elif diff < 0: + wav2_aligned = wav2_org[-diff:] + else: + wav2_aligned = wav2_org + # command_Text(f"Audio files already aligned.\n") + + if not any( + np.array_equal(wav2_aligned, source) for source in wav2_aligned_sources + ): + wav2_aligned_sources.append(wav2_aligned) + + # print("Unique Sources: ", len(wav2_aligned_sources)) + + unique_sources = len(wav2_aligned_sources) + + sub_mapper_big_mapper = {} + + for s in wav2_aligned_sources: + wav2_aligned = ( + match_mono_array_shapes(s, wav1) + if is_mono + else match_array_shapes(s, wav1, is_swap=True) + ) + + if align_window: + wav_sub = time_correction( + wav1, + wav2_aligned, + seconds_length, + align_window=align_window, + db_analysis=db_analysis, + progress_bar=progress_bar, + unique_sources=unique_sources, + phase_shifts=phase_shifts, + ) + wav_sub_size = np.abs(wav_sub).mean() + sub_mapper_big_mapper = {**sub_mapper_big_mapper, **{wav_sub_size: wav_sub}} + else: + wav2_aligned = wav2_aligned * np.power(10, db_analysis[0] / 20) + db_range = db_analysis[1] + + for db_adjustment in db_range: + # Adjust the dB of track2 + s_adjusted = wav2_aligned * (10 ** (db_adjustment / 20)) + wav_sub = wav1 - s_adjusted + wav_sub_size = np.abs(wav_sub).mean() + sub_mapper_big_mapper = { + **sub_mapper_big_mapper, + **{wav_sub_size: wav_sub}, + } + + # print(sub_mapper_big_mapper.keys(), min(sub_mapper_big_mapper.keys())) + + sub_mapper_value_list = list(sub_mapper_big_mapper.values()) + + if is_spec_match and len(sub_mapper_value_list) >= 2: + # print("using spec ensemble with align") + wav_sub = ensemble_for_align(list(sub_mapper_big_mapper.values())) + else: + # print("using linear ensemble with align") + wav_sub = ensemble_wav(list(sub_mapper_big_mapper.values())) + + # print(f"Mix Mean: {np.abs(wav1).mean()}\nInst Mean: {np.abs(wav2).mean()}") + # print('Final: ', np.abs(wav_sub).mean()) + wav_sub = np.clip(wav_sub, -1, +1) + + command_Text(f"Saving inverted track... ") + + if is_save_aligned or is_spec_match: + wav1 = ( + match_mono_array_shapes(wav1, wav_sub) + if is_mono + else match_array_shapes(wav1, wav_sub, is_swap=True) + ) + wav2_aligned = wav1 - wav_sub + + if is_spec_match: + if wav1.ndim == 1 and wav2.ndim == 1: + wav2_aligned = np.asfortranarray([wav2_aligned, wav2_aligned]).T + wav1 = np.asfortranarray([wav1, wav1]).T + + wav2_aligned = ensemble_for_align([wav2_aligned, wav1]) + wav_sub = wav1 - wav2_aligned + + if is_save_aligned: + sf.write(file2_aligned, wav2_aligned, sr1, subtype=wav_type_set) + save_format(file2_aligned) + + sf.write(file_subtracted, wav_sub, sr1, subtype=wav_type_set) + save_format(file_subtracted) + + +def phase_shift_hilbert(signal, degree): + analytic_signal = hilbert(signal) + return ( + np.cos(np.radians(degree)) * analytic_signal.real + - np.sin(np.radians(degree)) * analytic_signal.imag + ) + + +def get_phase_shifted_tracks(track, phase_shift): + if phase_shift == 180: + return [track, -track] + + step = phase_shift + end = 180 - (180 % step) if 180 % step == 0 else 181 + phase_range = range(step, end, step) + + flipped_list = [track, -track] + for i in phase_range: + flipped_list.extend( + [phase_shift_hilbert(track, i), phase_shift_hilbert(track, -i)] + ) + + return flipped_list + + +def time_correction( + mix: np.ndarray, + instrumental: np.ndarray, + seconds_length, + align_window, + db_analysis, + sr=44100, + progress_bar=None, + unique_sources=None, + phase_shifts=NONE_P, +): + # Function to align two tracks using cross-correlation + + def align_tracks(track1, track2): + # A dictionary to store each version of track2_shifted and its mean absolute value + shifted_tracks = {} + + # Loop to adjust dB of track2 + track2 = track2 * np.power(10, db_analysis[0] / 20) + db_range = db_analysis[1] + + if phase_shifts == 190: + track2_flipped = [track2] + else: + track2_flipped = get_phase_shifted_tracks(track2, phase_shifts) + + for db_adjustment in db_range: + for t in track2_flipped: + # Adjust the dB of track2 + track2_adjusted = t * (10 ** (db_adjustment / 20)) + corr = correlate(track1, track2_adjusted) + delay = np.argmax(np.abs(corr)) - (len(track1) - 1) + track2_shifted = np.roll(track2_adjusted, shift=delay) + + # Compute the mean absolute value of track2_shifted + track2_shifted_sub = track1 - track2_shifted + mean_abs_value = np.abs(track2_shifted_sub).mean() + + # Store track2_shifted and its mean absolute value in the dictionary + shifted_tracks[mean_abs_value] = track2_shifted + + # Return the version of track2_shifted with the smallest mean absolute value + + return shifted_tracks[min(shifted_tracks.keys())] + + # Make sure the audio files have the same shape + + assert ( + mix.shape == instrumental.shape + ), f"Audio files must have the same shape - Mix: {mix.shape}, Inst: {instrumental.shape}" + + seconds_length = seconds_length // 2 + + sub_mapper = {} + + progress_update_interval = 120 + total_iterations = 0 + + if len(align_window) > 2: + progress_update_interval = 320 + + for secs in align_window: + step = secs / 2 + window_size = int(sr * secs) + step_size = int(sr * step) + + if len(mix.shape) == 1: + total_mono = ( + len(range(0, len(mix) - window_size, step_size)) + // progress_update_interval + ) * unique_sources + total_iterations += total_mono + else: + total_stereo_ = len(range(0, len(mix[:, 0]) - window_size, step_size)) * 2 + total_stereo = (total_stereo_ // progress_update_interval) * unique_sources + total_iterations += total_stereo + + # print(total_iterations) + + for secs in align_window: + sub = np.zeros_like(mix) + divider = np.zeros_like(mix) + step = secs / 2 + window_size = int(sr * secs) + step_size = int(sr * step) + window = np.hanning(window_size) + + # For the mono case: + if len(mix.shape) == 1: + # The files are mono + counter = 0 + for i in range(0, len(mix) - window_size, step_size): + counter += 1 + if counter % progress_update_interval == 0: + progress_bar(total_iterations) + window_mix = mix[i : i + window_size] * window + window_instrumental = instrumental[i : i + window_size] * window + window_instrumental_aligned = align_tracks( + window_mix, window_instrumental + ) + sub[i : i + window_size] += window_mix - window_instrumental_aligned + divider[i : i + window_size] += window + else: + # The files are stereo + counter = 0 + for ch in range(mix.shape[1]): + for i in range(0, len(mix[:, ch]) - window_size, step_size): + counter += 1 + if counter % progress_update_interval == 0: + progress_bar(total_iterations) + window_mix = mix[i : i + window_size, ch] * window + window_instrumental = instrumental[i : i + window_size, ch] * window + window_instrumental_aligned = align_tracks( + window_mix, window_instrumental + ) + sub[i : i + window_size, ch] += ( + window_mix - window_instrumental_aligned + ) + divider[i : i + window_size, ch] += window + + # Normalize the result by the overlap count + sub = np.where(divider > 1e-6, sub / divider, sub) + sub_size = np.abs(sub).mean() + sub_mapper = {**sub_mapper, **{sub_size: sub}} + + # print("SUB_LEN", len(list(sub_mapper.values()))) + + sub = ensemble_wav(list(sub_mapper.values()), split_size=12) + + return sub + + +def ensemble_wav(waveforms, split_size=240): + # Create a dictionary to hold the thirds of each waveform and their mean absolute values + waveform_thirds = { + i: np.array_split(waveform, split_size) for i, waveform in enumerate(waveforms) + } + + # Initialize the final waveform + final_waveform = [] + + # For chunk + for third_idx in range(split_size): + # Compute the mean absolute value of each third from each waveform + means = [ + np.abs(waveform_thirds[i][third_idx]).mean() for i in range(len(waveforms)) + ] + + # Find the index of the waveform with the lowest mean absolute value for this third + min_index = np.argmin(means) + + # Add the least noisy third to the final waveform + final_waveform.append(waveform_thirds[min_index][third_idx]) + + # Concatenate all the thirds to create the final waveform + final_waveform = np.concatenate(final_waveform) + + return final_waveform + + +def ensemble_wav_min(waveforms): + for i in range(1, len(waveforms)): + if i == 1: + wave = waveforms[0] + + ln = min(len(wave), len(waveforms[i])) + wave = wave[:ln] + waveforms[i] = waveforms[i][:ln] + + wave = np.where(np.abs(waveforms[i]) <= np.abs(wave), waveforms[i], wave) + + return wave + + +def align_audio_test(wav1, wav2, sr1=44100): + def get_diff(a, b): + corr = np.correlate(a, b, "full") + diff = corr.argmax() - (b.shape[0] - 1) + return diff + + # read tracks + wav1 = wav1.transpose() + wav2 = wav2.transpose() + + # print(f"Audio file shapes: {wav1.shape} / {wav2.shape}\n") + + wav2_org = wav2.copy() + + # pick a position at 1 second in and get diff + index = sr1 # *seconds_length # 1 second in, assuming sr1 = sr2 = 44100 + samp1 = wav1[index : index + sr1, 0] # currently use left channel + samp2 = wav2[index : index + sr1, 0] + diff = get_diff(samp1, samp2) + + # make aligned track 2 + if diff > 0: + wav2_aligned = np.append(np.zeros((diff, 1)), wav2_org, axis=0) + elif diff < 0: + wav2_aligned = wav2_org[-diff:] + else: + wav2_aligned = wav2_org + + return wav2_aligned + + +def load_audio(audio_file): + wav, sr = librosa.load(audio_file, sr=44100, mono=False) + + if wav.ndim == 1: + wav = np.asfortranarray([wav, wav]) + + return wav + + +def rerun_mp3(audio_file): + with audioread.audio_open(audio_file) as f: + track_length = int(f.duration) + + return track_length diff --git a/mvsepless/models/windowed_roformer/flex_attention_utils.py b/mvsepless/models/windowed_roformer/flex_attention_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2ef9b0e06ebb2087bf3cf0a4ce965a820eaa74e9 --- /dev/null +++ b/mvsepless/models/windowed_roformer/flex_attention_utils.py @@ -0,0 +1,110 @@ +from functools import lru_cache +from typing import Optional +import torch +import torch.nn as nn +from torch.nn.attention.flex_attention import ( + flex_attention, + create_block_mask, + _mask_mod_signature, +) + +@lru_cache(maxsize=128) +def create_block_mask_cached( + mask_mod: _mask_mod_signature, + B: Optional[int], + H: Optional[int], + Q_LEN: int, + KV_LEN: int, + device: torch.device +): + block_mask = create_block_mask( + mask_mod, + B=B, + H=H, + Q_LEN=Q_LEN, + KV_LEN=KV_LEN, + device=device + ) + return block_mask + + +def get_compiled_flex_attention(compile: bool = True, mode: str = "default"): + if compile: + return torch.compile(flex_attention, dynamic=False, mode=mode) + return flex_attention + +def generate_sliding_window_with_sinks( + window_size: int, + num_sink_tokens: int +) -> _mask_mod_signature: + + half_window = window_size // 2 + + def sliding_window_with_global_sinks(b, h, q_idx, kv_idx): + is_query_sink = q_idx < num_sink_tokens + is_kv_sink = kv_idx < num_sink_tokens + is_in_window = torch.abs(q_idx - kv_idx) <= half_window + + # Query sinks can attend to everything + # Regular queries can attend to: sinks OR tokens within sliding window + return is_query_sink | is_kv_sink | is_in_window + + sliding_window_with_global_sinks.__name__ = f"sliding_window_w{window_size}_sinks{num_sink_tokens}" + return sliding_window_with_global_sinks + +class FlexAttention(nn.Module): + def __init__( + self, + mask_mod: _mask_mod_signature, + dropout: float = 0.0, + scale: Optional[float] = None, + compile: bool = False, + compile_mode: str = "max-autotune" + ): + super().__init__() + self.mask_mod = mask_mod + self.scale = scale + self.dropout = dropout + self.attn_dropout = nn.Dropout(dropout) if dropout > 0 else None + + # Get compiled or uncompiled flex_attention + self.flex_attn_fn = get_compiled_flex_attention(compile, mode=compile_mode) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ) -> torch.Tensor: + + batch_size, num_heads, seq_len, head_dim = q.shape + device = q.device + + # Create block mask (cached based on mask_mod and dimensions) + # Note: We use None for B and H to allow broadcasting across batches/heads + block_mask = create_block_mask_cached( + self.mask_mod, + B=None, + H=None, + Q_LEN=seq_len, + KV_LEN=seq_len, + device=device.type + ) + # Apply scale if specified + if self.scale is not None: + # Adjust query by scale factor + default_scale = head_dim ** -0.5 + q = q * (self.scale / default_scale) + + # Apply flex attention with block mask + out = self.flex_attn_fn( + q, k, v, + block_mask=block_mask, + scale=self.scale + ) + + # Apply dropout if training + if self.training and self.attn_dropout is not None: + out = self.attn_dropout(out) + + return out \ No newline at end of file diff --git a/mvsepless/models/windowed_roformer/model.py b/mvsepless/models/windowed_roformer/model.py new file mode 100644 index 0000000000000000000000000000000000000000..fde4704b643c772dd46e0ee006128d723abbe7a9 --- /dev/null +++ b/mvsepless/models/windowed_roformer/model.py @@ -0,0 +1,340 @@ +import torch +import torch.nn as nn + +import torch + +from rotary_embedding_torch import RotaryEmbedding + +from einops import rearrange, pack, unpack, reduce, repeat +from librosa import filters +from .modules import BandSplit, MaskEstimator, Transformer, pack_one, unpack_one +from functools import partial + +class MelBandRoformerWSA(nn.Module): + def __init__( + self, + dim: int = 384, + depth: int = 6, + # num_stems: int = 1, + num_bands: int = 60, + dim_head: int = 64, + heads: int = 8, + attn_dropout: float = 0.0, + ff_dropout: float = 0.0, + sample_rate: int = 44100, + stft_n_fft: int = 2048, + stft_hop_length: int = 441, + stft_win_length: int = 2048, + stft_normalized: bool = False, + wsa_window_len: int = 10, + n_wsa_sinks: int = 8, + **kwargs + ): + super().__init__() + + # Store configuration parameters + self.audio_channels = 2 + # self.num_stems = num_stems + self.n_wsa_sinks = n_wsa_sinks + self.dim = dim + self.depth = depth + self.num_bands = num_bands + self.sample_rate = sample_rate + self.stft_n_fft = stft_n_fft + self.stft_hop_length = stft_hop_length + self.stft_win_length = stft_win_length + self.stft_normalized = stft_normalized + self.wsa_window_len = wsa_window_len + + # Store transformer configuration + self.time_transformer_kwargs = dict( + dim=dim, + heads=heads, + dim_head=dim_head, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + wsa_window_len=wsa_window_len, + n_wsa_sinks=n_wsa_sinks + ) + self.freq_transformer_kwargs = dict( + dim=dim, + heads=heads, + dim_head=dim_head, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + ) + + # Build components + self._build_stft_config() + self._build_mel_filter_bank() + self._build_transformer_layers() + self._build_band_split_and_mask_estimators() + self._build_sink_tokens() + + # Initialize weights + self.apply(self._init_weights) + + def _init_weights(self, module): + """Initialize weights for linear and conv layers""" + if isinstance(module, (nn.Linear, nn.Conv1d)): + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + if module.bias is not None: + nn.init.zeros_(module.bias) + + def _build_stft_config(self): + """Build STFT configuration and window function""" + self.stft_window_fn = partial(torch.hann_window, self.stft_win_length) + self.stft_kwargs = dict( + n_fft=self.stft_n_fft, + hop_length=self.stft_hop_length, + win_length=self.stft_win_length, + normalized=self.stft_normalized + ) + + def _build_mel_filter_bank(self): + """Build mel filter bank and frequency indices""" + # Get number of frequencies from STFT + freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(self.stft_n_fft), return_complex=True).shape[1] + + # Create mel filter bank with librosa.filters.mel as in section 2 of paper + mel_filter_bank_numpy = filters.mel(sr=self.sample_rate, n_fft=self.stft_n_fft, n_mels=self.num_bands) + mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy) + + # Fix edge cases + mel_filter_bank[0][0] = 1. # First frequency + mel_filter_bank[-1, -1] = 1. # Last frequency + + # Binary mask as in paper (estimated masks are averaged for overlapping regions) + freqs_per_band = mel_filter_bank > 0 + assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now' + + # Create frequency indices for band splitting + repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=self.num_bands) + freq_indices = repeated_freq_indices[freqs_per_band] + + # if self.stereo: + freq_indices = repeat(freq_indices, 'f -> f s', s=2) + freq_indices = freq_indices * 2 + torch.arange(2) + freq_indices = rearrange(freq_indices, 'f s -> (f s)') + + # Register buffers + self.register_buffer('freq_indices', freq_indices, persistent=False) + self.register_buffer('freqs_per_band', freqs_per_band, persistent=False) + + # Calculate frequency statistics + num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum') + num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum') + + self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False) + self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False) + + def _build_transformer_layers(self): + """Build transformer layers with rotary embeddings""" + self.layers = nn.ModuleList([]) + + time_rotary_embed = RotaryEmbedding(dim=self.time_transformer_kwargs['dim_head']) + freq_rotary_embed = RotaryEmbedding(dim=self.freq_transformer_kwargs['dim_head']) + + for _ in range(self.depth): + tran_modules = [] + tran_modules.append( + Transformer(rotary_embed=time_rotary_embed, **self.time_transformer_kwargs) + ) + tran_modules.append( + Transformer(rotary_embed=freq_rotary_embed, **self.freq_transformer_kwargs) + ) + self.layers.append(nn.ModuleList(tran_modules)) + + def _build_band_split_and_mask_estimators(self): + """Build band split module and mask estimators""" + # Calculate input dimensions for each band + freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in self.num_freqs_per_band.tolist()) + + # Band split module + self.band_split = BandSplit( + dim=self.dim, + dim_inputs=freqs_per_bands_with_complex + ) + + # Mask estimators + self.mask_estimators = nn.ModuleList([]) + for _ in range(1): + mask_estimator = MaskEstimator( + dim=self.dim, + dim_inputs=freqs_per_bands_with_complex, + depth=2, + mlp_expansion_factor=4, + ) + self.mask_estimators.append(mask_estimator) + + def _build_sink_tokens(self): + """Build learnable sink tokens for efficient attention""" + if self.n_wsa_sinks > 0: + self.sink_tokens = nn.Parameter( + torch.randn(self.n_wsa_sinks, self.num_bands, self.dim) + ) + print(f"Using {self.n_wsa_sinks} sink tokens for attention") + else: + self.sink_tokens = None + + + def forward(self, raw_audio): + """ + Main forward pass for audio source separation + + einops notation: + b - batch + f - freq + t - time + s - audio channel (2 for stereo) + n - number of 'stems' + c - complex (2) + d - feature dimension + """ + # Preprocess input audio + raw_audio, batch_info = self._preprocess_audio(raw_audio) + + # Convert to STFT representation + stft_repr = self._audio_to_stft(raw_audio, batch_info) + + # Extract features and apply band splitting + features = self._extract_features(stft_repr, batch_info) + + # Apply transformer layers with sink tokens + processed_features = self._apply_transformer_layers(features) + + # Generate masks for each stem + masks = self._generate_masks(processed_features) + + # Apply masks and reconstruct audio + recon_audio = self._reconstruct_audio(stft_repr, masks, batch_info) + + return recon_audio + + def _preprocess_audio(self, raw_audio): + """Preprocess input audio and validate dimensions""" + device = raw_audio.device + + if raw_audio.ndim == 2: + raw_audio = rearrange(raw_audio, 'b t -> b 1 t') + + batch, channels, raw_audio_length = raw_audio.shape + + raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t') + + batch_info = { + 'batch_size': batch, + 'channels': channels, + 'device': device, + 'packed_shape': batch_audio_channel_packed_shape + } + + return raw_audio, batch_info + + def _audio_to_stft(self, raw_audio, batch_info): + """Convert audio to STFT representation""" + stft_window = self.stft_window_fn(device=batch_info['device']) + + stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True) + stft_repr = torch.view_as_real(stft_repr) + + stft_repr = unpack_one(stft_repr, batch_info['packed_shape'], '* f t c') + + # Merge stereo/mono into frequency dimension for band splitting + stft_repr = rearrange(stft_repr, 'b s f t c -> b (f s) t c') + + return stft_repr + + def _extract_features(self, stft_repr, batch_info): + """Extract features using band splitting""" + batch_arange = torch.arange(batch_info['batch_size'], device=batch_info['device'])[..., None] + + # Index frequencies for band splitting + x = stft_repr[batch_arange, self.freq_indices] + + # Fold complex dimensions into frequency dimension + x = rearrange(x, 'b f t c -> b t (f c)') + + # Apply band split + x = self.band_split(x) + + return x + + def _apply_transformer_layers(self, features): + """Apply transformer layers with sink tokens""" + x = features + + # Add sink tokens at the beginning of the time dimension + if self.sink_tokens is not None: + batch_size = x.shape[0] + sinks = repeat(self.sink_tokens, 'n f d -> b n f d', b=batch_size) + x = torch.cat([sinks, x], dim=1) + + # Apply axial/hierarchical attention + for transformer_block in self.layers: + time_transformer, freq_transformer = transformer_block + + # Time transformer + x = rearrange(x, 'b t f d -> b f t d') + x, ps = pack([x], '* t d') + x = time_transformer(x) + x, = unpack(x, ps, '* t d') + x = rearrange(x, 'b f t d -> b t f d') + + # Frequency transformer + x, ps = pack([x], '* f d') + x = freq_transformer(x) + x, = unpack(x, ps, '* f d') + + # Remove sink tokens before mask estimation + if self.sink_tokens is not None: + x = x[:, self.n_wsa_sinks:, :, :] + + return x + + def _generate_masks(self, processed_features): + """Generate masks for each stem""" + masks = torch.stack([fn(processed_features) for fn in self.mask_estimators], dim=1) + masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2) + return masks + + def _reconstruct_audio(self, stft_repr, masks, batch_info): + """Apply masks and reconstruct audio""" + batch = batch_info['batch_size'] + channels = batch_info['channels'] + device = batch_info['device'] + num_stems = len(self.mask_estimators) + + # Prepare STFT for modulation + stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c') + + # Convert to complex representation + stft_repr = torch.view_as_complex(stft_repr) + masks = torch.view_as_complex(masks) + masks = masks.type(stft_repr.dtype) + + # Average masks for overlapping frequencies + scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1]) + stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems) + masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks) + + denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels) + masks_averaged = masks_summed / denom.clamp(min=1e-8) + + # Apply masks to STFT + stft_repr = stft_repr * masks_averaged + + # Convert back to real representation for ISTFT + stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels) + + # Zero out DC component + stft_repr = stft_repr.index_fill(1, torch.tensor(0, device=device), 0.) + + # ISTFT reconstruction + stft_window = self.stft_window_fn(device=device) + recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=None) + + # Reshape output + recon_audio = rearrange(recon_audio, '(b s) t -> b s t', b=batch, s=self.audio_channels) + + return recon_audio \ No newline at end of file diff --git a/mvsepless/models/windowed_roformer/modules.py b/mvsepless/models/windowed_roformer/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..c872af4b98e2bb410ff0f1c385ea709996f1b89e --- /dev/null +++ b/mvsepless/models/windowed_roformer/modules.py @@ -0,0 +1,365 @@ +from functools import partial + +import torch +from torch import nn +from torch.nn import Module, ModuleList +import torch.nn.functional as F + +from einops import rearrange, pack, unpack +from typing import Tuple +from functools import wraps +from packaging import version +from collections import namedtuple +import os +from .flex_attention_utils import ( + FlexAttention, + generate_sliding_window_with_sinks, +) + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + +class RMSNorm(Module): + def __init__(self, dim): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return F.normalize(x, dim=-1) * self.scale * self.gamma + +class FeedForward(Module): + def __init__( + self, + dim, + mult=4, + dropout=0. + ): + super().__init__() + dim_inner = int(dim * mult) + self.net = nn.Sequential( + RMSNorm(dim), + nn.Linear(dim, dim_inner), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim_inner, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + +class Attention(Module): + def __init__( + self, + dim, + heads=8, + dim_head=64, + dropout=0., + rotary_embed=None, + flash=True, + wsa_window_len=None, + n_wsa_sinks=None + ): + super().__init__() + self.heads = heads + self.scale = dim_head ** -0.5 + dim_inner = heads * dim_head + + self.rotary_embed = rotary_embed + + + self.attend = Attend(flash=flash, dropout=dropout, wsa_window_len=wsa_window_len, n_wsa_sinks=n_wsa_sinks) + self.norm = RMSNorm(dim) + self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False) + + self.to_gates = nn.Linear(dim, heads) + + self.to_out = nn.Sequential( + nn.Linear(dim_inner, dim, bias=False), + nn.Dropout(dropout) + ) + + def forward(self, x, return_attn=False): + x = self.norm(x) + + q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads) + + if self.rotary_embed is not None: + q = self.rotary_embed.rotate_queries_or_keys(q) + k = self.rotary_embed.rotate_queries_or_keys(k) + + if return_attn: + out, attn = self.attend(q, k, v, return_attn=True) + else: + out = self.attend(q, k, v) + + gates = self.to_gates(x) + out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid() + + out = rearrange(out, 'b h n d -> b n (h d)') + result = self.to_out(out) + + if return_attn: + return result, attn + return result + + +class Transformer(Module): + def __init__( + self, + *, + dim, + depth:int = 1, + dim_head=64, + heads=8, + attn_dropout=0., + ff_dropout=0., + ff_mult=4, + norm_output=True, + rotary_embed=None, + use_flash=True, + wsa_window_len=None, + n_wsa_sinks=None + ): + super().__init__() + self.layers = ModuleList([]) + + for _ in range(depth): + + attn = Attention( + dim=dim, + dim_head=dim_head, + heads=heads, + dropout=attn_dropout, + rotary_embed=rotary_embed, + flash=use_flash, + wsa_window_len=wsa_window_len, + n_wsa_sinks=n_wsa_sinks + ) + + self.layers.append(ModuleList([ + attn, + FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout) + ])) + + self.norm = RMSNorm(dim) if norm_output else nn.Identity() + + def forward(self, x): + + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return self.norm(x) + +class BandSplit(Module): + def __init__( + self, + dim: int, + dim_inputs: Tuple[int, ...] + ): + super().__init__() + self.dim_inputs = dim_inputs + self.to_features = ModuleList([]) + + for dim_in in dim_inputs: + net = nn.Sequential( + RMSNorm(dim_in), + nn.Linear(dim_in, dim) + ) + + self.to_features.append(net) + + def forward(self, x): + x = x.split(self.dim_inputs, dim=-1) + + outs = [] + for split_input, to_feature in zip(x, self.to_features): + split_output = to_feature(split_input) + outs.append(split_output) + + return torch.stack(outs, dim=-2) + + +def MLP( + dim_in, + dim_out, + dim_hidden=None, + depth=1, + activation=nn.Tanh +): + dim_hidden = dim_hidden if dim_hidden is not None else dim_in + + net = [] + dims = (dim_in, *((dim_hidden,) * depth), dim_out) + + for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])): + is_last = ind == (len(dims) - 2) + + net.append(nn.Linear(layer_dim_in, layer_dim_out)) + + if is_last: + continue + + net.append(activation()) + + return nn.Sequential(*net) + + +class MaskEstimator(Module): + def __init__( + self, + dim, + dim_inputs: Tuple[int, ...], + depth, + mlp_expansion_factor=4 + ): + super().__init__() + self.dim_inputs = dim_inputs + self.to_freqs = ModuleList([]) + dim_hidden = dim * mlp_expansion_factor + + for dim_in in dim_inputs: + mlp = nn.Sequential( + MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), + nn.GLU(dim=-1) + ) + + self.to_freqs.append(mlp) + + def forward(self, x): + x = x.unbind(dim=-2) + + outs = [] + + for band_features, mlp in zip(x, self.to_freqs): + freq_out = mlp(band_features) + outs.append(freq_out) + + return torch.cat(outs, dim=-1) + + + +# constants +FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) + +def once(fn): + called = False + @wraps(fn) + def inner(x): + nonlocal called + if called: + return + called = True + return fn(x) + return inner + +print_once = once(print) + +class Attend(nn.Module): + def __init__( + self, + dropout = 0., + flash = False, + scale = None, + wsa_window_len = None, + n_wsa_sinks = None + ): + super().__init__() + self.scale = scale + self.dropout = dropout + self.attn_dropout = nn.Dropout(dropout) + self.wsa_window_len = wsa_window_len + self.n_wsa_sinks = n_wsa_sinks + self.use_flash = flash + + # Initialize FlexAttention module if enabled + if wsa_window_len is not None and n_wsa_sinks > 0: + assert not (version.parse(torch.__version__) < version.parse('2.5.0')), \ + 'in order to use flex attention, you must be using pytorch 2.5 or above' + # Create the appropriate mask function + mask_mod = generate_sliding_window_with_sinks(wsa_window_len, n_wsa_sinks) + + # Create FlexAttention module with compilation enabled + self.flex_attn = FlexAttention( + mask_mod=mask_mod, + dropout=dropout, + scale=scale, + compile=True, + ) + else: + self.flex_attn = None + + # Setup for standard flash attention + assert not (self.use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), \ + 'in order to use flash attention, you must be using pytorch 2.0 or above' + + # determine efficient attention configs for cuda and cpu + self.cpu_config = FlashAttentionConfig(True, True, True) + self.cuda_config = None + + if not torch.cuda.is_available() or not self.use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + device_version = version.parse(f'{device_properties.major}.{device_properties.minor}') + + if device_version >= version.parse('8.0'): + if os.name == 'nt': + print_once('Windows OS detected, using math or mem efficient attention if input tensor is on cuda') + self.cuda_config = FlashAttentionConfig(False, True, True) + else: + print_once('GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda') + self.cuda_config = FlashAttentionConfig(True, False, False) + else: + print_once('GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda') + self.cuda_config = FlashAttentionConfig(False, True, True) + + def flash_attn(self, q, k, v): + _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device + + if self.scale is not None: + default_scale = q.shape[-1] ** -0.5 + q = q * (self.scale / default_scale) + + # Check if there is a compatible device for flash attention + config = self.cuda_config if is_cuda else self.cpu_config + + # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale + with torch.backends.cuda.sdp_kernel(**config._asdict()): + out = F.scaled_dot_product_attention( + q, k, v, + dropout_p = self.dropout if self.training else 0. + ) + + return out + + def forward(self, q, k, v, return_attn=False): + + # Use FlexAttention module if enabled + if self.flex_attn is not None: + return self.flex_attn(q, k, v) + + # Standard flash attention path + if self.use_flash: + return self.flash_attn(q, k, v) + + # Fallback: Manual attention computation + q_len, k_len, device = q.shape[-2], k.shape[-2], q.device + scale = self.scale if self.scale is not None else q.shape[-1] ** -0.5 + + # similarity + sim = torch.einsum(f"b h i d, b h j d -> b h i j", q, k) * scale + + # attention + attn = sim.softmax(dim=-1) + attn = self.attn_dropout(attn) + + # aggregate values + out = torch.einsum(f"b h i j, b h j d -> b h i d", attn, v) + + return out \ No newline at end of file diff --git a/mvsepless/namer.py b/mvsepless/namer.py new file mode 100644 index 0000000000000000000000000000000000000000..08ba77f0f2621122c8a6c5b6e69e7148c3f1c495 --- /dev/null +++ b/mvsepless/namer.py @@ -0,0 +1,112 @@ +import os +import re + +class Namer: + def __init__(self, max_length: int = 255, offset: int = 10): + if max_length < 40: + self.max_length = 40 + else: + self.max_length = max_length + if offset < max_length: + self.safe_max_length = max_length - offset + else: + self.safe_max_length = max_length + + def sanitize(self, name: str): + """ + Очищает имя файла от запрещенных символов + """ + sanitized = re.sub(r'[<>:"/\\|?*]', '_', name) + sanitized = re.sub(r'_+', '_', sanitized) + sanitized = sanitized.strip('_. ') + return sanitized + + def short(self, name: str, length: int = None): + """ + Укорачивание имени файла до безопасной максимальной длины + """ + if length: + if len(name) > length: + return f"{name[:int(length // 2)]}...{name[-int(length // 2.5):]}" + else: + return name + else: + if len(name) > self.safe_max_length: + return f"{name[:int(self.safe_max_length // 4)]}...{name[-int(self.safe_max_length // 4):]}" + else: + return name + + def iter(self, filepath: str): + """ + Позволяет избежать перезаписи + """ + if not os.path.exists(filepath): + return filepath + + directory, filename = os.path.split(filepath) + name, ext = os.path.splitext(filename) + + counter = 1 + while True: + new_filename = f"{name} ({counter}){ext}" + new_filepath = os.path.join(directory, new_filename) + if not os.path.exists(new_filepath): + return new_filepath + counter += 1 + + def template(self, template: str, **kwargs): + """ + Заменяет ключи на значения (ключи указыаются в именованных аргументах типа KEY="value") + """ + if kwargs: + for key in kwargs: + template = template.replace(str(key), str(kwargs[key])) + return template + + def dedup_template(self, template: str, keys: list = []): + """ + Убирает дубликаты ключей в шаблоне + """ + seen = set() + pattern = r"({})".format("|".join(re.escape(key) for key in keys)) + + def replace(match): + key = match.group(1) + if key in seen: + return "" + seen.add(key) + return key + + result = re.sub(pattern, replace, template) + return result + + def short_input_name_template(self, template: str, **kwargs): + """ + Укорачивает имя входного файла для шаблона + """ + if kwargs: + input_file_name = kwargs.get("NAME", None) + if input_file_name: + merged_keys_value = "" + no_keys_template = template + for key in kwargs: + if key != "NAME": + merged_keys_value += str(kwargs[key]) + for key in kwargs: + no_keys_template = no_keys_template.replace(str(key), "") + len_merged_keys = len(merged_keys_value) + len_no_keys = len(no_keys_template) + free_length = self.safe_max_length - (len_merged_keys + len_no_keys) + len_file_name = len(input_file_name) + start_index = free_length // 2 + end_index = free_length // 2.5 + if len_file_name > free_length: + return f"{input_file_name[:int(start_index)]}...{input_file_name[-int(end_index):]}" + else: + return input_file_name + else: + print("Не был введен ключ NAME") + return "" + else: + print("Сначала введите ключи") + return "" diff --git a/mvsepless/plugins/chainless.py b/mvsepless/plugins/chainless.py new file mode 100644 index 0000000000000000000000000000000000000000..1df3eb9b76c37d1465b83dfb786a45f30ae91851 --- /dev/null +++ b/mvsepless/plugins/chainless.py @@ -0,0 +1,443 @@ +import gradio as gr +import os, sys, subprocess +import pandas as pd +from datetime import datetime +import tempfile +import json +if not __package__: + from __init__ import Separator + from downloader import dw_yt_dlp +else: + from .. import Separator + from ..downloader import dw_yt_dlp + +class Plugin(Separator): + def __init__(self): + self.name = "Авто-цепочка разделений" + self.requirements = [] + self.install_requirements(self.requirements) + + def install_requirements(self, requirements: list): + if requirements: + cmd = [os.sys.executable, "-m", "pip", "install"] + for pkg in requirements: + cmd.append(pkg) + result = subprocess.run(cmd, text=True, capture_output=True) + + + class ModelManager(Separator): + def __init__(self): + self.data = [] + self.dir_presets = os.path.join(tempfile.tempdir, "presets") + os.makedirs(self.dir_presets, exist_ok=True) + + def save(self, name): + if not name: + name = "chainless_preset" + filepath = os.path.join(self.dir_presets, f"{self.namer.short(self.namer.sanitize(name), length=50)}.json") + with open(filepath, "w") as f: + json.dump(self.data, f, indent=4, ensure_ascii=False) + return filepath + + def load(self, filepath): + with open(filepath, "r") as f: + self.data = json.load(f) + + def add(self, mt, mn, s_stem, out_stem, int_stem): + if int_stem: + self.data.append((mt, mn, s_stem, out_stem, int_stem)) + + def replace(self, mt, mn, s_stem, out_stem, int_stem, index=1): + if self.data: + len_data = len(self.data) + if index >= 1: + if index <= len_data: + self.data[index - 1] = (mt, mn, s_stem, out_stem, int_stem) + elif index == 0: + self.data[0] = (mt, mn, s_stem, out_stem, int_stem) + + def remove(self, index=1): + if self.data: + len_data = len(self.data) + if index >= 1: + if index <= len_data: + del self.data[index - 1] + elif index == 0: + del self.data[0] + + def clear(self): + self.data = [] + + def get_df(self): + if not self.data: + columns = ["#", "Имя модели", "Выбранные стемы", "Остаток", "Промежуточный стем"] + return pd.DataFrame(columns=columns) + + data = [] + for i, model in enumerate(self.data): + data.append( + [ + f"{i+1}", + model[1], + str(model[2]), + str(model[3]), + model[4], + ] + ) + columns = ["#", "Имя модели", "Выбранные стемы", "Остаток", "Промежуточный стем"] + return pd.DataFrame(data, columns=columns) + + def UI(self): + def get_output_stems(mt, mn, s_stem): + output_stems = [] + stems = self.model_manager.get_stems(mt, mn) + if not s_stem: + for stem in stems: + output_stems.append(stem) + if set(stems) == {"bass", "drums", "vocals", "other"} or set(stems) == {"bass", "drums", "vocals", "other", "guitar", "piano"}: + output_stems.append("instrumental +") + output_stems.append("instrumental -") + elif s_stem: + if len(stems) > 2: + for stem in stems: + if stem in s_stem: + output_stems.append(stem) + output_stems.append("inverted -") + if len(stems) - len(s_stem) > 1: + output_stems.append("inverted +") + + return output_stems + + def get_invert_output_stems(mt, mn, s_stem, out_stem): + output_stems = [] + stems = get_output_stems(mt, mn, s_stem) + for stem in stems: + if stem not in out_stem: + output_stems.append(stem) + return output_stems + + default_model = { + "mt": self.model_manager.get_mt(), + "mn": self.model_manager.get_mn(self.model_manager.get_mt()[0]), + "stems": self.model_manager.get_stems( + self.model_manager.get_mt()[0], + self.model_manager.get_mn(self.model_manager.get_mt()[0])[0], + ), + "output_stems": get_output_stems( + self.model_manager.get_mt()[0], + self.model_manager.get_mn(self.model_manager.get_mt()[0])[0], + [] + ), + "int_stems": get_invert_output_stems( + self.model_manager.get_mt()[0], + self.model_manager.get_mn(self.model_manager.get_mt()[0])[0], + [], + "vocals", + ), + } + + chain_manager = self.ModelManager() + + gr.Markdown("

Пресет

") + with gr.Group(): + with gr.Row(equal_height=True): + export_preset_name = gr.Textbox( + label="Имя пресета", + interactive=True, + value="chainless_preset", scale=9 + ) + export_btn = gr.DownloadButton("Экспорт", variant="secondary", scale=3, interactive=True) + import_btn = gr.UploadButton( + "Импорт", file_types=[".json"], file_count="single", scale=3, interactive=True + ) + gr.Markdown("

Цепочка разделений

") + with gr.Row(): + with gr.Column(scale=3): # логика добавлеия моделей + model_type = gr.Dropdown(label="Тип модели", choices=default_model["mt"], value=default_model["mt"][0], interactive=True, filterable=False) + model_name = gr.Dropdown(label="Имя модели", choices=default_model["mn"], value=default_model["mn"][0], interactive=True, filterable=False) + selected_stems = gr.Dropdown(label="Выбранные стемы", choices=default_model["stems"], value=[], multiselect=True, interactive=False, filterable=False) + output_stems = gr.Dropdown(label="Остаток", choices=default_model["output_stems"], value=[default_model["output_stems"][0]], multiselect=True, interactive=True, filterable=False) + intermediate_stem = gr.Dropdown(label="Проиежуточный стем", choices=default_model["int_stems"], value=default_model["int_stems"][0], interactive=True, filterable=False) + @model_type.change( + inputs=[model_type], + outputs=[model_name] + ) + def update_model_names(mt): + model_names = self.model_manager.get_mn(mt) + new_mn = model_names[0] if model_names else "" + + return gr.update(choices=model_names, value=new_mn) + @model_name.change( + inputs=[model_type, model_name], + outputs=[selected_stems, output_stems, intermediate_stem] + ) + def update_stems_after_model_change(mt, mn): + stems = self.model_manager.get_stems(mt, mn) + _output_stems = get_output_stems(mt, mn, []) + invert_stems = get_invert_output_stems(mt, mn, [], [_output_stems[0]]) + + new_out_stem = [_output_stems[0]] + new_inv_out_stem = invert_stems[0] + + return ( + gr.update(choices=stems, value=[], interactive=False if len(stems) <= 2 else True), + gr.update(choices=_output_stems, value=new_out_stem, max_choices=len(_output_stems) - 1), + gr.update(choices=invert_stems, value=new_inv_out_stem) + ) + + @selected_stems.change( + inputs=[model_type, model_name, selected_stems], + outputs=[output_stems, intermediate_stem] + ) + def update_invert_stems(mt, mn, s_stem): + stems = get_output_stems(mt, mn, s_stem) + new_i_stem = [stems[0]] + invert_stems = get_invert_output_stems(mt, mn, s_stem, new_i_stem) + return gr.update(choices=stems, value=new_i_stem), gr.update(choices=invert_stems, value=invert_stems[0]) + + @output_stems.change( + inputs=[model_type, model_name, selected_stems, output_stems], + outputs=[intermediate_stem] + ) + def update_invert2_stems(mt, mn, s_stem, out_stem): + invert_stems = get_invert_output_stems(mt, mn, s_stem, out_stem) + return gr.update(choices=invert_stems, value=invert_stems[0]) + + model_add_button = gr.Button("Добавить", interactive=True) + + + with gr.Column(scale=10): + df = gr.DataFrame( + value=chain_manager.get_df(), + headers=["#", "Имя модели", "Выбранные стемы", "Остаток", "Промежуточный стем"], + datatype=["number", "str", "str", "str", "str"], + interactive=False + ) + + with gr.Group(): + with gr.Row(equal_height=True): + with gr.Column(): + model_index = gr.Number(label="Индекс модели", value=1, interactive=True) + model_clear_btn = gr.Button("Очистить", variant="stop", interactive=True) + with gr.Column(): + model_replace_btn = gr.Button("Заменить", variant="primary", interactive=True) + model_delete_btn = gr.Button("Удалить", variant="stop", interactive=True) + + @model_add_button.click( + inputs=[model_type, model_name, selected_stems, output_stems, intermediate_stem], + outputs=df + ) + def add_model_to_auto_ensemble(mt, mn, s_stem, out_stem, int_stem): + chain_manager.add(mt, mn, s_stem, out_stem, int_stem) + return chain_manager.get_df() + + @model_replace_btn.click( + inputs=[model_type, model_name, selected_stems, output_stems, intermediate_stem, model_index], + outputs=df + ) + def replace_model_to_auto_ensemble(mt, mn, s_stem, out_stem, int_stem, index): + chain_manager.replace(mt, mn, s_stem, out_stem, int_stem, index) + return chain_manager.get_df() + + @model_delete_btn.click( + inputs=[model_index], + outputs=df + ) + def delete_model_to_auto_ensemble(index): + chain_manager.remove(index) + return chain_manager.get_df() + + @model_clear_btn.click( + outputs=df + ) + def clear_model_to_auto_ensemble(): + chain_manager.clear() + return chain_manager.get_df() + + gr.on(fn=chain_manager.get_df, outputs=df) + + df.change( + fn=chain_manager.save, + inputs=export_preset_name, + outputs=export_btn + ) + + export_preset_name.change( + fn=chain_manager.save, + inputs=export_preset_name, + outputs=export_btn + ) + + @import_btn.upload( + inputs=import_btn, + outputs=df + ) + def load_ensemble_preset(filepath): + chain_manager.load(filepath) + return chain_manager.get_df() + + with gr.Row(): + with gr.Column(): + gr.Markdown("

Входное аудио

") + with gr.Group(): + with gr.Group(visible=False) as add_inputs: + input_path = gr.Textbox(label="Путь к входному файлу", interactive=True) + add_inputs_btn = gr.Button("Загрузить файл", variant="primary") + with gr.Group(visible=False) as add_inputs_from_url: + input_url = gr.Textbox(label="URL входного файла", interactive=True) + with gr.Row(): + inputs_url_format = gr.Dropdown(label="Формат входного файла", interactive=True, + choices=self.audio.output_formats, + value="mp3", filterable=False) + inputs_url_bitrate = gr.Slider(label="Битрейт входного файла", minimum=64, maximum=512, step=32, value=320, interactive=True) + with gr.Row(): + inputs_url_cookie = gr.UploadButton(label="Файл cookie (необязательно)", interactive=True, type="filepath", file_count="single", file_types=[".txt", ".cookies"], variant="secondary") + add_inputs_url_btn = gr.Button("Загрузить файл", variant="primary") + with gr.Row(visible=True) as add_buttons_row: + add_path_btn = gr.Button("Загрузить файл по пути", variant="secondary") + add_url_btn = gr.Button("Загрузить файл по URL", variant="secondary") + with gr.Group(): + input_audio = gr.File(label="Входное аудио", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.audio.input_formats]) + with gr.Column(): + gr.Markdown("

Настройки

") + with gr.Group(): + save_only_last_intermediate_stem_check = gr.Checkbox(label="Сохранить только последний промежуточный стем", interactive=True, value=False) + output_format = gr.Dropdown(label="Формат выходного файла", interactive=True, + choices=self.audio.output_formats, + value="mp3", filterable=False) + run_btn = gr.Button("Создать цепочку разделений", variant="primary", interactive=True) + + with gr.Row(): + with gr.Column(): + gr.Markdown("

Результаты

") + last_intermediate_stem = gr.Audio(label="Последний промежуточный стем", type="filepath", interactive=False, show_download_button=True) + with gr.Group(): + invert_method = gr.Radio( + choices=["waveform", "spectrogram"], + label="Метод создания инверсии", + value="waveform", + ) + invert_btn = gr.Button("Инвертировать") + output_inverted_audio = gr.Audio(label="Инверсия", type="filepath", interactive=False, show_download_button=True) + @invert_btn.click(inputs=[input_audio, last_intermediate_stem, invert_method, output_format], outputs=[output_inverted_audio]) + def invert_result_ensemble(input_file, output_file, method, out_format): + if input_file and output_file: + o_dir = os.path.dirname(output_file) + basename = os.path.splitext(os.path.basename(input_file))[0] + output_path = os.path.join(o_dir, f"chainless_{self.namer.short(basename, length=50)}_{method}_invert.{out_format}") + inverted = self.inverter.process_audio(audio1_path=input_file, audio2_path=output_file, out_format=out_format, method=method, output_path=output_path) + return inverted + else: + return None + + with gr.Column(): + gr.Markdown("

Исходники цепочки

") + output_source_files = gr.Files(type="filepath", interactive=False, show_label=False) + output_source_preview_check = gr.Checkbox(label="Показать плееры для исходников цепочки", interactive=True, value=False) + @gr.render(inputs=[output_source_preview_check, output_source_files]) + def show_output_auto_ensemble_players(preview, audios): + if preview: + if audios: + with gr.Group(): + for file in audios: + gr.Audio(label=os.path.splitext(os.path.basename(file))[0], value=file, interactive=False, show_download_button=False, type="filepath") + + + @run_btn.click( + inputs=[input_audio, output_format, save_only_last_intermediate_stem_check], + outputs=[last_intermediate_stem, output_source_files] + ) + def chain( + input_audio, + out_format, + save_only_last_intermediate_stem, + progress=gr.Progress(track_tqdm=True) + ): + + input_settings = chain_manager.data + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + o = tempfile.mkdtemp(prefix=f"chainless_outputs_{timestamp}_") + os.makedirs(o, exist_ok=True) + + base_name = os.path.splitext(os.path.basename(input_audio))[0] + last_intermediate_stem = None + last_intermediate_stem_name = None + _output_stems = [] + + block_count = len(input_settings) + + for i, model in enumerate(input_settings, start=1): + input_model_type = model[0] + input_model_name = model[1] + selected_stems = model[2] + selected_output_stems = model[3] + intermediate_stem = model[4] + output_p = self.separate( + input=( + input_audio + if not last_intermediate_stem + else last_intermediate_stem + ), + output_dir=os.path.join(o, input_model_name), + model_type=input_model_type, + model_name=input_model_name, + ext_inst=True, + output_format=out_format, + template=f"NAME_{i}_{f'({i - 1}_{str(last_intermediate_stem_name)})_' if last_intermediate_stem_name else ''}MODEL_STEM", + selected_stems=selected_stems, + add_settings={"add_single_sep_text_progress": f"{i} из {block_count}"}, + progress=progress + ) + for stem, file in output_p: + if stem in selected_output_stems: + _output_stems.append(file) + elif stem == intermediate_stem: + last_intermediate_stem = file + last_intermediate_stem_name = stem + _output_stems.append(file) + + return last_intermediate_stem, _output_stems if not save_only_last_intermediate_stem else [] + + @add_inputs_btn.click( + inputs=[input_path, input_audio], + outputs=[add_inputs, input_audio, add_buttons_row]) + def add_inputs_fn(input_p, input_a): + if input_p and os.path.exists(input_p): + if input_a is None: + input_a = None + if self.audio.check(input_p): + input_a = input_p + return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True) + return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True) + + @add_inputs_url_btn.click( + inputs=[input_url, input_audio, inputs_url_format, inputs_url_bitrate, inputs_url_cookie], + outputs=[add_inputs_from_url, input_audio, add_buttons_row]) + def add_inputs_from_url_fn(input_u, input_a, fmt, br, cookie): + if input_u: + if input_a is None: + input_a = None + downloaded_file = dw_yt_dlp( + url=input_u, + output_format=fmt, + output_bitrate=str(int(br)), + cookie=cookie + ) + if downloaded_file and os.path.exists(downloaded_file): + if self.audio.check(downloaded_file): + input_a = downloaded_file + return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True) + return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True) + + add_path_btn.click( + lambda: (gr.update(visible=True), gr.update(visible=False)), + outputs=[add_inputs, add_buttons_row]) + + add_url_btn.click( + lambda: (gr.update(visible=True), gr.update(visible=False)), + outputs=[add_inputs_from_url, add_buttons_row]) + + inputs_url_format.change(lambda x: gr.update(visible=False if x in ["wav", "flac", "aiff"] else True), inputs=inputs_url_format, outputs=inputs_url_bitrate) + + diff --git a/mvsepless/plugins/matchering.py b/mvsepless/plugins/matchering.py new file mode 100644 index 0000000000000000000000000000000000000000..0a9d1cd4afd99236bbbfad7f2bab439b8b450088 --- /dev/null +++ b/mvsepless/plugins/matchering.py @@ -0,0 +1,93 @@ +import gradio as gr +import os, sys, subprocess +if not __package__: + from audio import Audio +else: + from ..audio import Audio +import tempfile + +class Plugin(Audio): + def __init__(self): + super().__init__() + self.name = "Matchering" + self.requirements = ["matchering"] + self.install_requirements(self.requirements) + + def install_requirements(self, requirements: list): + if requirements: + cmd = [os.sys.executable, "-m", "pip", "install"] + for pkg in requirements: + cmd.append(pkg) + result = subprocess.run(cmd, text=True, capture_output=True) + + def match(self, target_audio: str = None, reference_audio: str = None, output_path: str = "matchered.mp3", output_format: str = "mp3"): + from matchering.log import Code, info, debug, debug_line, ModuleError + from matchering import Config, Result + from matchering.stages import main + from matchering.checker import check, check_equality + from matchering.dsp import channel_count, size + config = Config() + target, target_sample_rate, _ = self.read(target_audio, mono=False, sr=44100) + target = target.T + target, target_sample_rate = check(target, target_sample_rate, config, "target") + reference, reference_sample_rate, _ = self.read(reference_audio, mono=False, sr=44100) + reference = reference.T + reference, reference_sample_rate = check( + reference, reference_sample_rate, config, "reference" + ) + + if not config.allow_equality: + check_equality(target, reference) + + if ( + not (target_sample_rate == reference_sample_rate == config.internal_sample_rate) + or not (channel_count(target) == channel_count(reference) == 2) + or not (size(target) > config.fft_size and size(reference) > config.fft_size) + ): + raise ModuleError(Code.ERROR_VALIDATION) + + result, result_no_limiter, result_no_limiter_normalized = main( + target, + reference, + config, + need_default=True, + need_no_limiter=False, + need_no_limiter_normalized=False, + ) + + output_path = self.write(o=output_path, array=result, of=output_format, sr=config.internal_sample_rate) + return output_path + + def UI(self): + with gr.Group(): + with gr.Row(): + target = gr.File(label="Целевое аудио", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.input_formats]) + reference = gr.File(label="Референс", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.input_formats]) + @gr.render(inputs=[target, reference]) + def check_audio_files(tgt, rfr): + if tgt and rfr: + info_tgt = self.get_info(i=tgt) + info_rfr = self.get_info(i=rfr) + sr_tgt = info_tgt[0]["sample_rate"] + sr_rfr = info_rfr[0]["sample_rate"] + status = f"{sr_tgt} = {sr_rfr}" if sr_tgt == sr_rfr else f"{sr_tgt} != {sr_rfr}" + gr.Markdown(f"

{status}

") + + output_format = gr.Dropdown(label="Формат выходного файла", interactive=True, + choices=self.output_formats, + value="mp3", filterable=False) + matchered_audio = gr.Audio(label="Результат", type="filepath", show_download_button=True, interactive=False) + match_btn = gr.Button("Сопоставить") + @match_btn.click( + inputs=[target, reference, output_format], + outputs=[matchered_audio] + ) + def match_audio(tgt, rfr, of): + if tgt and rfr: + o_dir = tempfile.mkdtemp(prefix="matchering") + basename = os.path.splitext(os.path.basename(tgt))[0] + output_path = os.path.join(o_dir, f"{self.short(basename, length=120)}_matchered.{of}") + matchered = self.match(target_audio=tgt, reference_audio=rfr, output_path=output_path, output_format=of) + return matchered + else: + return None \ No newline at end of file diff --git a/mvsepless/plugins/remove_center.py b/mvsepless/plugins/remove_center.py new file mode 100644 index 0000000000000000000000000000000000000000..d48763f2bdec1ff6fd09b7157927a170c8baf2d3 --- /dev/null +++ b/mvsepless/plugins/remove_center.py @@ -0,0 +1,259 @@ +import gradio as gr +import os, sys, subprocess +import tempfile +from scipy import signal +import numpy as np +from datetime import datetime +if not __package__: + from audio import Audio +else: + from ..audio import Audio + +class Plugin(Audio): + def __init__(self): + super().__init__() + self.name = "Вычитание фантомного центра" + self.requirements = [] + self.install_requirements(self.requirements) + self.w_types = [ + "boxcar", # Прямоугольное окно + "triang", # Треугольное окно + "blackman", # Окно Блэкмана + "hamming", # Окно Хэмминга + "hann", # Окно Ханна + "bartlett", # Окно Бартлетта + "flattop", # Окно с плоской вершиной + "parzen", # Окно Парзена + "bohman", # Окно Бохмана + "blackmanharris", # Окно Блэкмана-Харриса + "nuttall", # Окно Нуттала + "barthann", # Окно Бартлетта-Ханна + "cosine", # Косинусное окно + "exponential", # Экспоненциальное окно + "tukey", # Окно Туки + "taylor", # Окно Тейлора + "lanczos", # Окно Ланцоша + ] + + def install_requirements(self, requirements: list): + if requirements: + cmd = [os.sys.executable, "-m", "pip", "install"] + for pkg in requirements: + cmd.append(pkg) + result = subprocess.run(cmd, text=True, capture_output=True) + + def remove_center( + self, + input_file, + output_format="flac", + out_center="center.flac", + out_stereo_base="stereo_base.flac", + rdf=0.99999, + window_size=4096, + overlap=2, + window_type="blackman", + stereo_mode="stereo", + ): + output_file = out_stereo_base + output_center_file = out_center + data, samplerate, _ = self.read(i=input_file, mono=False, sr=None) + + if data.ndim != 2 or data.shape[0] != 2: + raise ValueError("Требуется стереофайл (2 канала)") + + left = data[0, :] + right = data[1, :] + mono = left * 0.5 + right * 0.5 + + nperseg = window_size # Размер окна + noverlap = nperseg // overlap # Перекрытие окон + + f, t, Z_left = signal.stft( + left, fs=samplerate, nperseg=nperseg, noverlap=noverlap, window=window_type + ) + f, t, Z_right = signal.stft( + right, fs=samplerate, nperseg=nperseg, noverlap=noverlap, window=window_type + ) + f, t, Z_mono = signal.stft( + mono, fs=samplerate, nperseg=nperseg, noverlap=noverlap, window=window_type + ) + if stereo_mode == "mono": + Z_common_left = np.minimum(np.abs(Z_left), np.abs(Z_right)) * np.exp( + 1j * np.angle(Z_mono) + ) + Z_common_right = np.minimum(np.abs(Z_left), np.abs(Z_right)) * np.exp( + 1j * np.angle(Z_mono) + ) + else: + Z_common_left = np.minimum(np.abs(Z_left), np.abs(Z_right)) * np.exp( + 1j * np.angle(Z_right) + ) + Z_common_right = np.minimum(np.abs(Z_left), np.abs(Z_right)) * np.exp( + 1j * np.angle(Z_left) + ) + + reduction_factor = rdf + + Z_new_left = Z_left - Z_common_left * reduction_factor + Z_new_right = Z_right - Z_common_right * reduction_factor + + _, new_left = signal.istft( + Z_new_left, + fs=samplerate, + nperseg=nperseg, + noverlap=noverlap, + window=window_type, + ) + _, new_right = signal.istft( + Z_new_right, + fs=samplerate, + nperseg=nperseg, + noverlap=noverlap, + window=window_type, + ) + + _, common_signal_left = signal.istft( + Z_common_left, + fs=samplerate, + nperseg=nperseg, + noverlap=noverlap, + window=window_type, + ) + _, common_signal_right = signal.istft( + Z_common_right, + fs=samplerate, + nperseg=nperseg, + noverlap=noverlap, + window=window_type, + ) + + new_left = new_left[: len(left)] + new_right = new_right[: len(right)] + common_signal_left = common_signal_left[: len(left)] + common_signal_right = common_signal_right[: len(right)] + + peak = np.max([np.abs(new_left).max(), np.abs(new_right).max()]) + if peak > 1.0: + new_left = new_left / peak + new_right = new_right / peak + + output_file = self.write( + o=output_file, + array=np.column_stack((new_left, new_right)), + sr=samplerate, + of=output_format, + br="320k", + ) + + inverted_center_left = -common_signal_left + inverted_center_right = -common_signal_right + + mixed_left = left + inverted_center_left + mixed_right = right + inverted_center_right + + peak_mixed = np.max([np.abs(mixed_left).max(), np.abs(mixed_right).max()]) + if peak_mixed > 1.0: + mixed_left = mixed_left / peak_mixed + mixed_right = mixed_right / peak_mixed + + output_center_file = self.write( + o=output_center_file, + array=np.column_stack((common_signal_left, common_signal_right)), + sr=samplerate, + of=output_format, + br="320k", + ) + + return (output_file, output_center_file) + + def UI(self): + with gr.Row(): + rmv_center_ui_input_audio = gr.File(label="Входное аудио", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.input_formats]) + with gr.Group(): + with gr.Row(): + rmv_center_ui_reduction_f = gr.Slider( + 0.1, + 10, + value=1, + step=0.1, + label="Фактор подавления", + interactive=True, + visible=False, + ) + rmv_center_ui_overlap = gr.Slider( + 2, + 30, + value=2, + step=1, + label="Перекрытие", + interactive=True, + visible=True, + ) + rmv_center_ui_window_size = gr.Number( + label="Размер окна", + interactive=True, + visible=True, + minimum=32, + maximum=882000, + precision=1, + value=2048, + ) + with gr.Row(): + rmv_center_ui_format = gr.Dropdown( + self.output_formats, + value=self.output_formats[0], + filterable=False, + label="Формат выходного файла", + interactive=True + ) + rmv_center_ui_window_types = gr.Dropdown( + self.w_types, + value=self.w_types[4], + filterable=False, + label="Тип окна", + interactive=True + ) + + rmv_center_ui_mono_mode = gr.Dropdown( + ["mono", "stereo"], + value="mono", + filterable=False, + label="Стерео-режим", + interactive=True + ) + rmv_center_ui_extract_btn = gr.Button("Разделить") + with gr.Group(): + with gr.Column(): + with gr.Row(): + rmv_center_ui_mid = gr.Audio( + type="filepath", + interactive=False, + label="Фантомный центр", + visible=True, + show_download_button=True, + ) + rmv_center_ui_side = gr.Audio( + type="filepath", + interactive=False, + label="Стерео-база", + visible=True, + show_download_button=True, + ) + @rmv_center_ui_extract_btn.click( + inputs=[ + rmv_center_ui_input_audio, + rmv_center_ui_format, + rmv_center_ui_reduction_f, + rmv_center_ui_window_size, + rmv_center_ui_overlap, + rmv_center_ui_window_types, + rmv_center_ui_mono_mode, + ], + outputs=[rmv_center_ui_side, rmv_center_ui_mid], + ) + def wrap_remove_center(input_audio, output_format, rf, ws, ovlp, wt, mono_mode): + if input_audio: + temp_dir = tempfile.mkdtemp(prefix="remove_center_") + basename = self.short(os.path.splitext(os.path.basename(input_audio))[0], length=80) + side, mid = self.remove_center(input_file=input_audio, output_format=output_format, out_center=os.path.join(temp_dir, f"{basename}_center.{output_format}"), out_stereo_base=os.path.join(temp_dir, f"{basename}_stereo_base.{output_format}"), overlap=ovlp, rdf=rf, stereo_mode=mono_mode, window_size=ws, window_type=wt) + return side, mid \ No newline at end of file diff --git a/mvsepless/plugins/test.py b/mvsepless/plugins/test.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4e2de68a41308108f66458ce3e17896f62fc09 --- /dev/null +++ b/mvsepless/plugins/test.py @@ -0,0 +1,31 @@ +import gradio as gr +import os, sys, subprocess +if not __package__: + from __init__ import Separator +else: + from .. import Separator + +class Plugin(Separator): + def __init__(self): + self.name = "Тестовый плагин" + self.requirements = [] + self.install_requirements(self.requirements) + + def install_requirements(self, requirements: list): + if requirements: + cmd = [os.sys.executable, "-m", "pip", "install"] + for pkg in requirements: + cmd.append(pkg) + result = subprocess.run(cmd, text=True, capture_output=True) + + def test(self): + print("Тест") + print(self.model_manager.get_mt()) + + def UI(self): + with gr.Column(): + gr.Markdown("

Пример рабочего плагина

") + gr.Button("Показать все типы моделей", variant="primary").click(self.test) + + + \ No newline at end of file diff --git a/mvsepless/plugins/whatbpm.py b/mvsepless/plugins/whatbpm.py new file mode 100644 index 0000000000000000000000000000000000000000..9115af91c196209fc0ef11df2d892eb5237b8339 --- /dev/null +++ b/mvsepless/plugins/whatbpm.py @@ -0,0 +1,89 @@ +import gradio as gr +import os, sys, subprocess +import librosa +if not __package__: + from audio import Audio +else: + from ..audio import Audio + +class Plugin(Audio): + def __init__(self): + super().__init__() + self.name = "Узнать темп (через Librosa)" + self.requirements = [] + self.install_requirements(self.requirements) + + def install_requirements(self, requirements: list): + if requirements: + cmd = [os.sys.executable, "-m", "pip", "install"] + for pkg in requirements: + cmd.append(pkg) + result = subprocess.run(cmd, text=True, capture_output=True) + + def get_tempo(self, input_audio) -> float: + y, sr, _ = self.read(i=input_audio, sr=44100, mono=True) + tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sr) + return float(tempo) + + def UI(self): + with gr.Tab("Узнать темп из множества файлов"): + with gr.Row(): + with gr.Column(): + input_audio = gr.File(label="Аудио", interactive=True, type="filepath", file_count="multiple", file_types=[f".{of}" for of in self.input_formats]) + input_preview_check = gr.Checkbox(label="Показать плееры для входных аудио", interactive=True, value=False) + @gr.render(inputs=[input_preview_check, input_audio]) + def show_input_players(preview, audios): + if preview: + if audios: + with gr.Group(): + for file in audios: + gr.Audio(label=os.path.splitext(os.path.basename(file))[0], value=file, interactive=False, show_download_button=False, type="filepath") + + @gr.render(inputs=[input_audio]) + def get_tempos_from_multiple_audios(audio_list): + status = f"Результаты анализа темпа\n---" + if audio_list: + for file in audio_list: + if self.check(file): + basename = os.path.splitext(os.path.basename(file))[0] + status += f"\n {self.short(basename, length=30)} - {self.get_tempo(file)}" + gr.Textbox(container=False, interactive=False, value=status, lines=len(status.split("\n"))) + + with gr.Tab("Сравнить темпы двух файлов"): + with gr.Row(): + audio1 = gr.File(label="Аудио 1", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.input_formats]) + audio2 = gr.File(label="Аудио 2", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.input_formats]) + @gr.render(inputs=[audio1, audio2]) + def compare_tempos(a1, a2): + if a1 and a2: + if self.check(a1) and self.check(a2): + tempo1 = self.get_tempo(a1) + tempo2 = self.get_tempo(a2) + with gr.Group(): + with gr.Row(variant="compact"): + with gr.Column(scale=1, min_width=80): + gr.Markdown("

Темп 1

") + t1 = gr.Number(container=False, value=tempo1, interactive=False) + with gr.Row(): + t1_div = gr.Button("/2", interactive=True, variant="stop", min_width=20) + t1_multiplic = gr.Button("*2", interactive=True, min_width=20) + t1_div.click(lambda x: float(x) / 2, inputs=t1, outputs=t1) + t1_multiplic.click(lambda x: float(x) * 2, inputs=t1, outputs=t1) + with gr.Column(scale=1, min_width=80): + gr.Markdown("

Темп 2

") + t2 = gr.Number(container=False, value=tempo2, interactive=False) + with gr.Row(): + t2_div = gr.Button("/2", interactive=True, variant="stop", min_width=20) + t2_multiplic = gr.Button("*2", interactive=True, min_width=20) + t2_div.click(lambda x: x / 2, inputs=t2, outputs=t2) + t2_multiplic.click(lambda x: x * 2, inputs=t2, outputs=t2) + compare_result = gr.Textbox(container=False, interactive=False, lines=2) + compare_btn = gr.Button("Сравнить", variant="primary", interactive=True) + @compare_btn.click( + inputs=[t1, t2], + outputs=compare_result + ) + def compare_tempo_fn(y1, y2): + result = f""" Темп 1 / Темп 2 = {y1 / y2} +Темп 2 / Темп 1 = {y2 / y1}""" + return result \ No newline at end of file diff --git a/mvsepless/vbach_infer.py b/mvsepless/vbach_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..88cbdd38289ad02c845c704cf343d6170a68b21a --- /dev/null +++ b/mvsepless/vbach_infer.py @@ -0,0 +1,1207 @@ +import os +import gc +import torch +import torch.nn.functional as F +import torchcrepe +import faiss +import librosa +import numpy as np +from scipy import signal +import argparse +script_dir = os.path.dirname(os.path.abspath(__file__)) + +FILTER_ORDER = 5 +CUTOFF_FREQUENCY = 48 +SAMPLE_RATE = 16000 +bh, ah = signal.butter(N=FILTER_ORDER, Wn=CUTOFF_FREQUENCY, btype="high", fs=SAMPLE_RATE) + +import librosa +from multiprocessing import cpu_count + +if not __package__: + from model_manager import VbachModelManager + from audio import Audio + from namer import Namer + from vbach_lib.fairseq import load_model_ensemble_and_task, load_checkpoint_to_cpu + from vbach_lib.algorithm.synthesizers import Synthesizer + from vbach_lib.predictors.FCPE import FCPEF0Predictor + from vbach_lib.predictors.RMVPE import RMVPE0Predictor +else: + from .model_manager import VbachModelManager + from .audio import Audio + from .namer import Namer + from .vbach_lib.fairseq import load_model_ensemble_and_task, load_checkpoint_to_cpu + from .vbach_lib.algorithm.synthesizers import Synthesizer + from .vbach_lib.predictors.FCPE import FCPEF0Predictor + from .vbach_lib.predictors.RMVPE import RMVPE0Predictor + + +model_manager = VbachModelManager() +audio = Audio() +namer = Namer() + +RMVPE_DIR = model_manager.rmvpe_path +FCPE_DIR = model_manager.fcpe_path + +def remove_center(input_array, samplerate, rdf=0.99999, window_size=2048, overlap=2, window_type="blackman"): + + left = input_array[0] + right = input_array[1] + + nperseg = min(window_size, len(left)) + if nperseg < 16: + nperseg = 16 + if len(left) < 16: + import warnings + warnings.warn(f"Input too short ({len(left)} samples), returning original audio") + return left, right, left, right + + noverlap = nperseg // overlap + if noverlap >= nperseg: + noverlap = nperseg - 1 + + f, t, Z_left = signal.stft(left, fs=samplerate, nperseg=nperseg, noverlap=noverlap, window=window_type) + f, t, Z_right = signal.stft(right, fs=samplerate, nperseg=nperseg, noverlap=noverlap, window=window_type) + + Z_common_left = np.minimum(np.abs(Z_left), np.abs(Z_right)) * np.exp(1j*np.angle(Z_right)) + Z_common_right = np.minimum(np.abs(Z_left), np.abs(Z_right)) * np.exp(1j*np.angle(Z_left)) + + reduction_factor = rdf + + Z_new_left = Z_left - Z_common_left * reduction_factor + Z_new_right = Z_right - Z_common_right * reduction_factor + + _, new_left = signal.istft(Z_new_left, fs=samplerate, nperseg=nperseg, noverlap=noverlap, window=window_type) + _, new_right = signal.istft(Z_new_right, fs=samplerate, nperseg=nperseg, noverlap=noverlap, window=window_type) + _, common_signal_left = signal.istft(Z_common_left, fs=samplerate, nperseg=nperseg, noverlap=noverlap, window=window_type) + _, common_signal_right = signal.istft(Z_common_right, fs=samplerate, nperseg=nperseg, noverlap=noverlap, window=window_type) + + new_left = new_left[:len(left)] + new_right = new_right[:len(right)] + common_signal_left = common_signal_left[:len(left)] + common_signal_right = common_signal_right[:len(left)] + + peak = np.max([np.abs(new_left).max(), np.abs(new_right).max()]) + if peak > 1.0: + new_left = new_left / peak + new_right = new_right / peak + + inverted_center_left = -common_signal_left + inverted_center_right = -common_signal_right + + mixed_left = left + inverted_center_left + mixed_right = right + inverted_center_right + + peak_mixed = np.max([np.abs(mixed_left).max(), np.abs(mixed_right).max()]) + if peak_mixed > 1.0: + mixed_left = mixed_left / peak_mixed + mixed_right = mixed_right / peak_mixed + + return common_signal_left, common_signal_right, new_left, new_right + +class AudioProcessor: + @staticmethod + def change_rms(sourceaudio, source_rate, targetaudio, target_rate, rate): + """ + Изменяет RMS (среднеквадратичное значение) аудио. + """ + rms1 = librosa.feature.rms( + y=sourceaudio, + frame_length=source_rate // 2 * 2, + hop_length=source_rate // 2, + ) + rms2 = librosa.feature.rms( + y=targetaudio, + frame_length=target_rate // 2 * 2, + hop_length=target_rate // 2, + ) + + rms1 = F.interpolate( + torch.from_numpy(rms1).float().unsqueeze(0), + size=targetaudio.shape[0], + mode="linear", + ).squeeze() + rms2 = F.interpolate( + torch.from_numpy(rms2).float().unsqueeze(0), + size=targetaudio.shape[0], + mode="linear", + ).squeeze() + rms2 = torch.maximum(rms2, torch.zeros_like(rms2) + 1e-6) + + adjustedaudio = ( + targetaudio * (torch.pow(rms1, 1 - rate) * torch.pow(rms2, rate - 1)).numpy() + ) + return adjustedaudio + + +# Класс для преобразования голоса +class VC: + def __init__(self, tgt_sr, config): + """ + Инициализация параметров для преобразования голоса. + """ + self.x_pad = config.x_pad + self.x_query = config.x_query + self.x_center = config.x_center + self.x_max = config.x_max + self.is_half = config.is_half + self.sample_rate = 16000 + self.window = 160 + self.t_pad = self.sample_rate * self.x_pad + self.t_pad_tgt = tgt_sr * self.x_pad + self.t_pad2 = self.t_pad * 2 + self.t_query = self.sample_rate * self.x_query + self.t_center = self.sample_rate * self.x_center + self.t_max = self.sample_rate * self.x_max + self.time_step = self.window / self.sample_rate * 1000 + self.device = config.device + + def get_f0_crepe(self, x, f0_min, f0_max, p_len, hop_length, model="full"): + """ + Получает F0 с использованием модели crepe. + """ + x = x.astype(np.float32) + x /= np.quantile(np.abs(x), 0.999) + audio = torch.from_numpy(x).to(self.device, copy=True).unsqueeze(0) + if audio.ndim == 2 and audio.shape[0] > 1: + audio = torch.mean(audio, dim=0, keepdim=True) + + pitch = torchcrepe.predict( + audio, + self.sample_rate, + hop_length, + f0_min, + f0_max, + model, + batch_size=hop_length * 2, + device=self.device, + pad=True, + ) + + p_len = p_len or x.shape[0] // hop_length + source = np.array(pitch.squeeze(0).cpu().float().numpy()) + source[source < 0.001] = np.nan + target = np.interp( + np.arange(0, len(source) * p_len, len(source)) / p_len, + np.arange(0, len(source)), + source, + ) + f0 = np.nan_to_num(target) + return f0 + + def get_f0_rmvpe(self, x, f0_min=1, f0_max=40000, *args, **kwargs): + """ + Получает F0 с использованием модели rmvpe. + """ + if not hasattr(self, "model_rmvpe"): + self.model_rmvpe = RMVPE0Predictor( + RMVPE_DIR, is_half=self.is_half, device=self.device + ) + f0 = self.model_rmvpe.infer_from_audio_with_pitch( + x, thred=0.03, f0_min=f0_min, f0_max=f0_max + ) + return f0 + + def get_f0( + self, + inputaudio_path, + x, + p_len, + pitch, + f0_method, + filter_radius, + hop_length, + inp_f0=None, + f0_min=50, + f0_max=1100, + ): + """ + Получает F0 с использованием выбранного метода. + """ + global inputaudio_path2wav + f0_mel_min = 1127 * np.log(1 + f0_min / 700) + f0_mel_max = 1127 * np.log(1 + f0_max / 700) + + if f0_method == "mangio-crepe": + f0 = self.get_f0_crepe(x, f0_min, f0_max, p_len, int(hop_length)) + + elif f0_method == "rmvpe+": + params = { + "x": x, + "p_len": p_len, + "pitch": pitch, + "f0_min": f0_min, + "f0_max": f0_max, + "time_step": self.time_step, + "filter_radius": filter_radius, + "crepe_hop_length": int(hop_length), + "model": "full", + } + f0 = self.get_f0_rmvpe(**params) + + elif f0_method == "fcpe": + self.model_fcpe = FCPEF0Predictor( + FCPE_DIR, + f0_min=int(f0_min), + f0_max=int(f0_max), + dtype=torch.float32, + device=self.device, + sample_rate=self.sample_rate, + threshold=0.03, + ) + f0 = self.model_fcpe.compute_f0(x, p_len=p_len) + del self.model_fcpe + gc.collect() + + f0 *= pow(2, pitch / 12) + tf0 = self.sample_rate // self.window + if inp_f0 is not None: + delta_t = np.round( + (inp_f0[:, 0].max() - inp_f0[:, 0].min()) * tf0 + 1 + ).astype("int16") + replace_f0 = np.interp(list(range(delta_t)), inp_f0[:, 0] * 100, inp_f0[:, 1]) + shape = f0[self.x_pad * tf0 : self.x_pad * tf0 + len(replace_f0)].shape[0] + f0[self.x_pad * tf0 : self.x_pad * tf0 + len(replace_f0)] = replace_f0[:shape] + + f0bak = f0.copy() + f0_mel = 1127 * np.log(1 + f0 / 700) + f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / ( + f0_mel_max - f0_mel_min + ) + 1 + f0_mel[f0_mel <= 1] = 1 + f0_mel[f0_mel > 255] = 255 + f0_coarse = np.rint(f0_mel).astype(int) + return f0_coarse, f0bak + + def vc( + self, + model, + net_g, + sid, + audio0, + pitch, + pitchf, + index, + big_npy, + index_rate, + version, + protect, + ): + """ + Преобразует аудио с использованием модели. + """ + feats = torch.from_numpy(audio0) + feats = feats.half() if self.is_half else feats.float() + if feats.dim() == 2: + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats.view(1, -1) + padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False) + + inputs = { + "source": feats.to(self.device), + "padding_mask": padding_mask, + "output_layer": 9 if version == "v1" else 12, + } + + with torch.no_grad(): + logits = model.extract_features(**inputs) + feats = model.final_proj(logits[0]) if version == "v1" else logits[0] + if protect < 0.5 and pitch is not None and pitchf is not None: + feats0 = feats.clone() + if index is not None and big_npy is not None and index_rate != 0: + npy = feats[0].cpu().numpy() + npy = npy.astype("float32") if self.is_half else npy + score, ix = index.search(npy, k=8) + weight = np.square(1 / score) + weight /= weight.sum(axis=1, keepdims=True) + npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1) + npy = npy.astype("float16") if self.is_half else npy + feats = ( + torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate + + (1 - index_rate) * feats + ) + + feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1) + if protect < 0.5 and pitch is not None and pitchf is not None: + feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute( + 0, 2, 1 + ) + p_len = audio0.shape[0] // self.window + if feats.shape[1] < p_len: + p_len = feats.shape[1] + if pitch is not None and pitchf is not None: + pitch = pitch[:, :p_len] + pitchf = pitchf[:, :p_len] + + if protect < 0.5 and pitch is not None and pitchf is not None: + pitchff = pitchf.clone() + pitchff[pitchf > 0] = 1 + pitchff[pitchf < 1] = protect + pitchff = pitchff.unsqueeze(-1) + feats = feats * pitchff + feats0 * (1 - pitchff) + feats = feats.to(feats0.dtype) + p_len = torch.tensor([p_len], device=self.device).long() + with torch.no_grad(): + if pitch is not None and pitchf is not None: + audio1 = ( + (net_g.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0]) + .data.cpu() + .float() + .numpy() + ) + else: + audio1 = ( + (net_g.infer(feats, p_len, sid)[0][0, 0]).data.cpu().float().numpy() + ) + del feats, p_len, padding_mask + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return audio1 + + def pipeline( + self, + model, + net_g, + sid, + audio, + inputaudio_path, + pitch, + f0_method, + file_index, + index_rate, + pitch_guidance, + filter_radius, + tgt_sr, + resample_sr, + volume_envelope, + version, + protect, + hop_length, + f0_file, + f0_min=50, + f0_max=1100, + ): + """ + Основной конвейер для преобразования аудио. + """ + if ( + file_index is not None + and file_index != "" + and os.path.exists(file_index) + and index_rate != 0 + ): + try: + index = faiss.read_index(file_index) + big_npy = index.reconstruct_n(0, index.ntotal) + except Exception as e: + print(f"Произошла ошибка при чтении индекса FAISS: {e}") + index = big_npy = None + else: + index = big_npy = None + audio = signal.filtfilt(bh, ah, audio) + audio_pad = np.pad(audio, (self.window // 2, self.window // 2), mode="reflect") + opt_ts = [] + if audio_pad.shape[0] > self.t_max: + audio_sum = np.zeros_like(audio) + for i in range(self.window): + audio_sum += audio_pad[i : i - self.window] + for t in range(self.t_center, audio.shape[0], self.t_center): + opt_ts.append( + t + - self.t_query + + np.where( + np.abs(audio_sum[t - self.t_query : t + self.t_query]) + == np.abs(audio_sum[t - self.t_query : t + self.t_query]).min() + )[0][0] + ) + s = 0 + audio_opt = [] + t = None + audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect") + p_len = audio_pad.shape[0] // self.window + inp_f0 = None + if f0_file and hasattr(f0_file, "name"): + try: + with open(f0_file.name, "r") as f: + lines = f.read().strip("\n").split("\n") + inp_f0 = np.array( + [[float(i) for i in line.split(",")] for line in lines], + dtype="float32", + ) + except Exception as e: + print(f"Произошла ошибка при чтении файла F0: {e}") + sid = torch.tensor(sid, device=self.device).unsqueeze(0).long() + if pitch_guidance: + pitch, pitchf = self.get_f0( + inputaudio_path, + audio_pad, + p_len, + pitch, + f0_method, + filter_radius, + hop_length, + inp_f0, + f0_min, + f0_max, + ) + pitch = pitch[:p_len] + pitchf = pitchf[:p_len] + if self.device == "mps": + pitchf = pitchf.astype(np.float32) + pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long() + pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float() + for t in opt_ts: + t = t // self.window * self.window + if pitch_guidance: + audio_opt.append( + self.vc( + model, + net_g, + sid, + audio_pad[s : t + self.t_pad2 + self.window], + pitch[:, s // self.window : (t + self.t_pad2) // self.window], + pitchf[:, s // self.window : (t + self.t_pad2) // self.window], + index, + big_npy, + index_rate, + version, + protect, + )[self.t_pad_tgt : -self.t_pad_tgt] + ) + else: + audio_opt.append( + self.vc( + model, + net_g, + sid, + audio_pad[s : t + self.t_pad2 + self.window], + None, + None, + index, + big_npy, + index_rate, + version, + protect, + )[self.t_pad_tgt : -self.t_pad_tgt] + ) + s = t + if pitch_guidance: + audio_opt.append( + self.vc( + model, + net_g, + sid, + audio_pad[t:], + pitch[:, t // self.window :] if t is not None else pitch, + pitchf[:, t // self.window :] if t is not None else pitchf, + index, + big_npy, + index_rate, + version, + protect, + )[self.t_pad_tgt : -self.t_pad_tgt] + ) + else: + audio_opt.append( + self.vc( + model, + net_g, + sid, + audio_pad[t:], + None, + None, + index, + big_npy, + index_rate, + version, + protect, + )[self.t_pad_tgt : -self.t_pad_tgt] + ) + + audio_opt = np.concatenate(audio_opt) + if volume_envelope != 1: + audio_opt = AudioProcessor.change_rms( + audio, self.sample_rate, audio_opt, tgt_sr, volume_envelope + ) + if resample_sr >= self.sample_rate and tgt_sr != resample_sr: + audio_opt = librosa.resample(audio_opt, orig_sr=tgt_sr, target_sr=resample_sr) + + audio_max = np.abs(audio_opt).max() / 0.99 + max_int16 = 32768 + if audio_max > 1: + max_int16 /= audio_max + audio_opt = (audio_opt * max_int16).astype(np.int16) + + del pitch, pitchf, sid + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return audio_opt + +def overlay_mono_on_stereo(monoaudio, stereoaudio, gain=0.5): + if monoaudio is None or stereoaudio is None: + raise ValueError("Input audio arrays cannot be None") + + # Ensure float32 for processing + monoaudio = monoaudio.astype(np.float32) + stereoaudio = stereoaudio.astype(np.float32) + + # Convert mono to stereo if needed + if monoaudio.ndim == 1: + monoaudio = np.vstack([monoaudio, monoaudio]) + elif monoaudio.shape[0] == 1: + monoaudio = np.vstack([monoaudio[0], monoaudio[0]]) + + if monoaudio.shape[0] != 2 or stereoaudio.shape[0] != 2: + raise ValueError("Shapes must be (2, N)") + + min_len = min(monoaudio.shape[1], stereoaudio.shape[1]) + if min_len == 0: + raise ValueError("Audio arrays cannot be empty") + + monoaudio = monoaudio[:, :min_len] + stereoaudio = stereoaudio[:, :min_len] + + result = stereoaudio + monoaudio * gain + + # Normalize to prevent clipping + max_amp = np.max(np.abs(result)) + if max_amp > 0: + result /= max_amp + + # Convert back to int16 for output (if needed) + result = (result * 32767).astype(np.int16) + + return result + +def loadaudio( + file_path: str, + target_sr: int, + stereo_mode: str +) -> np.ndarray: + """ + Загружает аудиофайл с помощью librosa, обрабатывает и возвращает аудиосигнал + + Параметры: + file_path: Путь к аудиофайлу + target_sr: Целевая частота дискретизации + mono: Преобразовать в моно (по умолчанию True) + normalize: Нормализовать аудио (по умолчанию False) + duration: Загрузить только указанную длительность (в секундах) + offset: Начальное смещение для загрузки (в секундах) + + Возвращает: + Аудиоданные в виде numpy array (моно: (samples,), стерео: (channels, samples)) + + Исключения: + RuntimeError: При ошибках загрузки или обработки аудио + """ + try: + mid, left, right = None, None, None + + if stereo_mode == "mono": + # Загрузка аудио с помощью librosa + midaudio, sr, _ = audio.read( + i=file_path, + sr=None, + mono=True + ) + midaudio = librosa.resample( + midaudio, # Исправлено: было audio + orig_sr=sr, + target_sr=target_sr + ) + mid = midaudio.flatten() + + elif stereo_mode == "left/right" or stereo_mode == "sim/dif": + # Загрузка аудио с помощью librosa + stereoaudio, sr, _ = audio.read( + i=file_path, + sr=None, + mono=False + ) + + if stereo_mode == "left/right": + leftaudio = stereoaudio[0] # Исправлено: было [:, 0] + rightaudio = stereoaudio[1] # Исправлено: было [:, 1] + leftaudio = librosa.resample( + leftaudio, + orig_sr=sr, + target_sr=target_sr + ) + rightaudio = librosa.resample( + rightaudio, + orig_sr=sr, + target_sr=target_sr + ) + + left = leftaudio.flatten() + right = rightaudio.flatten() + + elif stereo_mode == "sim/dif": + mid_left, mid_right, dif_left, dif_right = remove_center(input_array=stereoaudio, samplerate=sr) + midaudio = (mid_left + mid_right) * 0.5 + + midaudio = librosa.resample( + midaudio, + orig_sr=sr, + target_sr=target_sr + ) + dif_left = librosa.resample( + dif_left, + orig_sr=sr, + target_sr=target_sr + ) + dif_right = librosa.resample( + dif_right, + orig_sr=sr, + target_sr=target_sr + ) + + mid = midaudio.flatten() + left = dif_left.flatten() # Исправлено: было leftaudio + right = dif_right.flatten() # Исправлено: было rightaudio + + return mid, left, right + + except Exception as e: + raise RuntimeError(f"Ошибка загрузки аудио '{file_path}': {str(e)}") + +class Config: + def __init__(self): + self.device = self.get_device() + self.is_half = self.device == "cpu" + self.n_cpu = cpu_count() + self.gpu_name = None + self.gpu_mem = None + self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config() + + def get_device(self): + if torch.cuda.is_available(): + return "cuda" + elif torch.backends.mps.is_available(): + return "mps" + else: + return "cpu" + + def device_config(self): + if torch.cuda.is_available(): + print("Используется устройство CUDA") + self._configure_gpu() + elif torch.backends.mps.is_available(): + print("Используется устройство MPS") + self.device = "mps" + else: + print("Используется CPU") + self.device = "cpu" + self.is_half = True + + x_pad, x_query, x_center, x_max = ( + (3, 10, 60, 65) if self.is_half else (1, 6, 38, 41) + ) + if self.gpu_mem is not None and self.gpu_mem <= 4: + x_pad, x_query, x_center, x_max = (1, 5, 30, 32) + + return x_pad, x_query, x_center, x_max + + def _configure_gpu(self): + self.gpu_name = torch.cuda.get_device_name(self.device) + low_end_gpus = ["16", "P40", "P10", "1060", "1070", "1080"] + if ( + any(gpu in self.gpu_name for gpu in low_end_gpus) + and "V100" not in self.gpu_name.upper() + ): + self.is_half = False + self.gpu_mem = int( + torch.cuda.get_device_properties(self.device).total_memory + / 1024 + / 1024 + / 1024 + + 0.4 + ) + +# Загрузка модели Hubert +def load_hubert(device, is_half, model_path): + models, saved_cfg, task = load_model_ensemble_and_task( + [model_path], suffix="" + ) + hubert = models[0].to(device) + hubert = hubert.half() if is_half else hubert.float() + hubert.eval() + return hubert + +# Получение голосового преобразователя +def get_vc(device, is_half, config, model_path): + cpt = torch.load(model_path, map_location="cpu", weights_only=False) + if "config" not in cpt or "weight" not in cpt: + raise ValueError( + f"Некорректный формат для {model_path}. " + "Используйте голосовую модель, обученную с использованием RVC v2." + ) + + tgt_sr = cpt["config"][-1] + cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] + pitch_guidance = cpt.get("f0", 1) + version = cpt.get("version", "v1") + input_dim = 768 if version == "v2" else 256 + + net_g = Synthesizer( + *cpt["config"], + use_f0=pitch_guidance, + input_dim=input_dim, + is_half=is_half, + ) + + del net_g.enc_q + print(net_g.load_state_dict(cpt["weight"], strict=False)) + net_g.eval().to(device) + net_g = net_g.half() if is_half else net_g.float() + + vc = VC(tgt_sr, config) + return cpt, version, net_g, tgt_sr, vc + +def rvc_infer( + index_path, + index_rate, + input_path, + output_path, + pitch, + f0_method, + cpt, + version, + net_g, + filter_radius, + tgt_sr, + volume_envelope, + protect, + hop_length, + vc, + hubert_model, + f0_min=50, + f0_max=1100, + format_output="wav", + output_bitrate="320k", + stereo_mode="mono" +) -> str: + + mid, left, right = loadaudio(input_path, 16000, stereo_mode) + pitch_guidance = cpt.get("f0", 1) + + if stereo_mode == "mono": + if mid is None: + raise ValueError("Mono audio data is None") + audio_opt = vc.pipeline( + hubert_model, + net_g, + 0, + mid, + input_path, + pitch, + f0_method, + index_path, + index_rate, + pitch_guidance, + filter_radius, + tgt_sr, + 0, + volume_envelope, + version, + protect, + hop_length, + f0_file=None, + f0_min=f0_min, + f0_max=f0_max, + ) + + elif stereo_mode == "left/right": + if left is None or right is None: + raise ValueError("Left or right audio channel is None") + + leftaudio_opt = vc.pipeline( + hubert_model, + net_g, + 0, + left, + input_path, + pitch, + f0_method, + index_path, + index_rate, + pitch_guidance, + filter_radius, + tgt_sr, + 0, + volume_envelope, + version, + protect, + hop_length, + f0_file=None, + f0_min=f0_min, + f0_max=f0_max, + ) + rightaudio_opt = vc.pipeline( + hubert_model, + net_g, + 0, + right, + input_path, + pitch, + f0_method, + index_path, + index_rate, + pitch_guidance, + filter_radius, + tgt_sr, + 0, + volume_envelope, + version, + protect, + hop_length, + f0_file=None, + f0_min=f0_min, + f0_max=f0_max, + ) + + # Ensure both channels have the same length + min_len = min(len(leftaudio_opt), len(rightaudio_opt)) + if min_len == 0: + raise ValueError("Processed audio is empty") + + leftaudio_opt = leftaudio_opt[:min_len] + rightaudio_opt = rightaudio_opt[:min_len] + + audio_opt = np.stack((leftaudio_opt, rightaudio_opt), axis=0) + + elif stereo_mode == "sim/dif": + if mid is None or left is None or right is None: + raise ValueError("Mid, left or right audio channel is None") + + midaudio_opt = vc.pipeline( + hubert_model, + net_g, + 0, + mid, + input_path, + pitch, + f0_method, + index_path, + index_rate, + pitch_guidance, + filter_radius, + tgt_sr, + 0, + volume_envelope, + version, + protect, + hop_length, + f0_file=None, + f0_min=f0_min, + f0_max=f0_max, + ) + leftaudio_opt = vc.pipeline( + hubert_model, + net_g, + 0, + left, + input_path, + pitch, + f0_method, + index_path, + index_rate, + pitch_guidance, + filter_radius, + tgt_sr, + 0, + volume_envelope, + version, + protect, + hop_length, + f0_file=None, + f0_min=f0_min, + f0_max=f0_max, + ) + rightaudio_opt = vc.pipeline( + hubert_model, + net_g, + 0, + right, + input_path, + pitch, + f0_method, + index_path, + index_rate, + pitch_guidance, + filter_radius, + tgt_sr, + 0, + volume_envelope, + version, + protect, + hop_length, + f0_file=None, + f0_min=f0_min, + f0_max=f0_max, + ) + + # Ensure all channels have the same length + min_len = min(len(midaudio_opt), len(leftaudio_opt), len(rightaudio_opt)) + if min_len == 0: + raise ValueError("Processed audio is empty") + + midaudio_opt = midaudio_opt[:min_len] + leftaudio_opt = leftaudio_opt[:min_len] + rightaudio_opt = rightaudio_opt[:min_len] + + difaudio_opt = np.stack((leftaudio_opt, rightaudio_opt), axis=0) + + audio_opt = overlay_mono_on_stereo(midaudio_opt, difaudio_opt) + + output_path = audio.write(o=output_path, array=audio_opt, sr=tgt_sr, of=format_output, br=output_bitrate) + return output_path + + +def load_rvc_model(voice_model): + + if voice_model in model_manager.parse_voice_models(): + rvc_model_path, rvc_index_path = model_manager.parse_pth_and_index(voice_model) + + if not rvc_model_path: + raise ValueError( + f"Файла для модели {voice_model} не существует. " + "Возможно, вы неправильно её установили." + ) + + else: + raise ValueError( + f"Модели {voice_model} не существует. " + "Возможно, вы неправильно ввели имя." + ) + + return rvc_model_path, rvc_index_path + +def voice_conversion( + voice_model, + vocals_path, + output_path, + pitch, + f0_method, + index_rate, + filter_radius, + volume_envelope, + protect, + hop_length, + f0_min, + f0_max, + format_output, + output_bitrate, + stereo_mode, + hubert_path=None +): + rvc_model_path, rvc_index_path = load_rvc_model(voice_model) + + config = Config() + hubert_model = load_hubert(config.device, config.is_half, hubert_path if hubert_path else model_manager.hubert_path) + cpt, version, net_g, tgt_sr, vc = get_vc( + config.device, config.is_half, config, rvc_model_path + ) + + outputaudio = rvc_infer( + rvc_index_path, + index_rate, + vocals_path, + output_path, + pitch, + f0_method, + cpt, + version, + net_g, + filter_radius, + tgt_sr, + volume_envelope, + protect, + hop_length, + vc, + hubert_model, + f0_min, + f0_max, + format_output, + output_bitrate, + stereo_mode + ) + + del hubert_model, cpt, net_g, vc + gc.collect() + torch.cuda.empty_cache() + return outputaudio + +def vbach_inference( + input_file: str, + model_name: str, + output_dir: str, + output_name: str, + output_format: str, + output_bitrate: str | int, + pitch: int, + method_pitch: str, + format_name: bool = False, + add_params: dict = { + "index_rate": 0, + "filter_radius": 3, + "protect": 0.33, + "rms": 0.25, + "mangio_crepe_hop_length": 128, + "f0_min": 50, + "f0_max": 1100, + "stereo_mode": "mono" + } + ): + stereo_mode = add_params.get("stereo_mode", "mono") + index_rate = add_params.get("index_rate", 0) + filter_radius = add_params.get("filter_radius", 3) + protect = add_params.get("protect", 0.33) + rms = add_params.get("rms", 0.25) + mangio_crepe_hop_length = add_params.get("mangio_crepe_hop_length", 0) + f0_min = add_params.get("f0_min", 50) + f0_max = add_params.get("f0_max", 1100) + if not input_file: + raise ValueError("Входной файл не указан") + if not os.path.exists(input_file): + raise ValueError(f"Входной файл не найден: {input}") + if not audio.check(input_file): + raise ValueError("Входной файл не содержит аудио") + basename = os.path.splitext(os.path.basename(input_file))[0] + + final_output_name = None + + print("Инференс запущен") + + if format_name: + cleaned_output_name_template = namer.sanitize(namer.dedup_template(output_name, keys=["NAME", "MODEL", "F0METHOD", "PITCH"])) + short_basename = namer.short_input_name_template(cleaned_output_name_template, MODEL=model_name, F0METHOD=method_pitch, PITCH=pitch, NAME=basename) + final_output_name = namer.template(cleaned_output_name_template, MODEL=model_name, F0METHOD=method_pitch, PITCH=pitch, NAME=short_basename) + + else: + final_output_name = output_name + + final_output_path = os.path.join(output_dir, f"{final_output_name}.{output_format}") + output_converted_voice = voice_conversion(voice_model=model_name, vocals_path=input_file, output_path=final_output_path, pitch=pitch, f0_method=method_pitch, index_rate=index_rate, filter_radius=filter_radius, volume_envelope=rms, protect=protect, hop_length=mangio_crepe_hop_length, f0_min=f0_min, f0_max=f0_max, format_output=output_format, hubert_path=None, output_bitrate=output_bitrate, stereo_mode=stereo_mode) + print(f"Инференс завершен\nПуть к выходному файлу: \"{output_converted_voice}\"") + return output_converted_voice + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Vbach - форк Polgen-RVC 1.2.0" + ) + parser.add_argument("--input", type=str, help="Путь к входному файлу или папке") + parser.add_argument( + "--output_dir", type=str, required=True, help="Путь для сохранения результатов" + ) + parser.add_argument( + "--output_format", + type=str, + default="wav", + choices=audio.output_formats, + help="Формат выходных файлов", + ) + parser.add_argument( + "--output_bitrate", type=str, default="320k", help="Битрейт выходного файла" + ) + parser.add_argument( + "--format_name", + action="store_true", + help="Форматировать имя выходного файла", + ) + parser.add_argument( + "--output_name", + type=str, + default="NAME_STEM", + help="Имя выходного файла", + ) + parser.add_argument( + "--model_name", + type=str, + default="model", + help="Имя голосовой модели", + ) + + parser.add_argument( + '--index_rate', + type=float, + default=0, + help='Интенсивность использования индексного файла (от 0.0 до 1.0)', + metavar='[0.0-1.0]' + ) + parser.add_argument( + '--stereo_mode', + type=str, + default="mono", + choices=["mono", "left/right", "sim/dif"], + help='Режим каналов: моно или стерео' + ) + parser.add_argument( + '--method_pitch', + type=str, + default="rmvpe+", + help='Метод извлечения pitch (тона)' + ) + parser.add_argument( + '--pitch', + type=int, + default=0, + help='Корректировка тона в полутонах' + ) + parser.add_argument( + '--hop_length', + type=int, + default=128, + help='Длина hop (в семплах) для обработки' + ) + parser.add_argument( + '--filter_radius', + type=int, + default=3, + help='Радиус фильтра для сглаживания' + ) + parser.add_argument( + '--rms', + type=float, + default=0.25, + help='Масштабирование огибающей громкости (RMS)' + ) + parser.add_argument( + '--protect', + type=float, + default=0.33, + help='Защита для глухих согласных звуков' + ) + parser.add_argument( + '--f0_min', + type=int, + default=50, + help='Минимальная частота pitch (F0) в Hz' + ) + parser.add_argument( + '--f0_max', + type=int, + default=1100, + help='Максимальная частота pitch (F0) в Hz' + ) + + args = parser.parse_args() + + if args.input: + if os.path.exists(args.input) and os.path.isfile(args.input): + if audio.check(args.input): + vbach_inference(input_file=args.input, model_name=args.model_name, output_dir=args.output_dir, output_name=args.output_name, output_bitrate=args.output_bitrate, output_format=args.output_format, pitch=args.pitch, method_pitch=args.method_pitch, format_name=args.format_name, add_params={ "index_rate": args.index_rate,"filter_radius": args.filter_radius,"protect": args.protect,"rms": args.rms,"mangio_crepe_hop_length": args.hop_length,"f0_min": args.f0_min,"f0_max": args.f0_max,"stereo_mode": args.stereo_mode}) + elif os.path.exists(args.input) and os.path.isdir(args.input): + list_valid_files = [] + for file in os.listdir(args.input): + if os.path.isfile(os.path.join(args.input, file)): + if audio.check(os.path.join(args.input, file)): + list_valid_files.append(os.path.join(args.input, file)) + if list_valid_files: + for i, vocals_file in enumerate(list_valid_files, start=1): + print(f"Файл {i} из {len(list_valid_files)}: {vocals_file}") + vbach_inference(input_file=vocals_file, model_name=args.model_name, output_dir=args.output_dir, output_name=args.output_name, output_bitrate=args.output_bitrate, output_format=args.output_format, pitch=args.pitch, method_pitch=args.method_pitch, format_name=True if len(list_valid_files) > 1 else args.format_name, add_params={ "index_rate": args.index_rate,"filter_radius": args.filter_radius,"protect": args.protect,"rms": args.rms,"mangio_crepe_hop_length": args.hop_length,"f0_min": args.f0_min,"f0_max": args.f0_max,"stereo_mode": args.stereo_mode}) diff --git a/mvsepless/vbach_lib/algorithm/__init__.py b/mvsepless/vbach_lib/algorithm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b0dac65653c4ac827e44864913497651c8434874 --- /dev/null +++ b/mvsepless/vbach_lib/algorithm/__init__.py @@ -0,0 +1,2 @@ + + \ No newline at end of file diff --git a/mvsepless/vbach_lib/algorithm/attentions.py b/mvsepless/vbach_lib/algorithm/attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..2424e79fbba2a0b41f965f9bb4038ecda0fbc848 --- /dev/null +++ b/mvsepless/vbach_lib/algorithm/attentions.py @@ -0,0 +1,224 @@ + +import math +import torch +from torch import nn +from torch.nn import functional as F + +from .commons import convert_pad_shape + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + p_dropout=0.0, + window_size=None, + heads_share=True, + block_length=None, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + self.emb_rel_v = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: + assert t_s == t_t, "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys( + query / math.sqrt(self.k_channels), key_relative_embeddings + ) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to( + device=scores.device, dtype=scores.dtype + ) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + assert t_s == t_t, "Local attention is only available for self-attention." + block_mask = ( + torch.ones_like(scores) + .triu(-self.block_length) + .tril(self.block_length) + ) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values( + relative_weights, value_relative_embeddings + ) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), + ) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[ + :, slice_start_position:slice_end_position + ] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + batch, heads, length, _ = x.size() + + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) + + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])) + + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ + :, :, :length, length - 1 : + ] + return x_final + + def _absolute_position_to_relative_position(self, x): + batch, heads, length, _ = x.size() + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__( + self, + in_channels, + out_channels, + filter_channels, + kernel_size, + p_dropout=0.0, + activation=None, + causal=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, convert_pad_shape(padding)) + return x + + \ No newline at end of file diff --git a/mvsepless/vbach_lib/algorithm/commons.py b/mvsepless/vbach_lib/algorithm/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..7ed8615d1ac464ba2d58f97e78f1af0cbe33965e --- /dev/null +++ b/mvsepless/vbach_lib/algorithm/commons.py @@ -0,0 +1,114 @@ + +import math +import torch +from torch.nn import functional as F +from typing import List, Optional + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + kl = (logs_q - logs_p) - 0.5 + kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) + return kl + + +def slice_segments( + x: torch.Tensor, ids_str: torch.Tensor, segment_size: int = 4, dim: int = 2 +): + if dim == 2: + ret = torch.zeros_like(x[:, :segment_size]) + elif dim == 3: + ret = torch.zeros_like(x[:, :, :segment_size]) + + for i in range(x.size(0)): + idx_str = ids_str[i].item() + idx_end = idx_str + segment_size + if dim == 2: + ret[i] = x[i, idx_str:idx_end] + else: + ret[i] = x[i, :, idx_str:idx_end] + + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size, dim=3) + return ret, ids_str + + +def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( + num_timescales - 1 + ) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def sequence_mask(length: torch.Tensor, max_length: Optional[int] = None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def clip_grad_value(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = List(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm + + \ No newline at end of file diff --git a/mvsepless/vbach_lib/algorithm/discriminators.py b/mvsepless/vbach_lib/algorithm/discriminators.py new file mode 100644 index 0000000000000000000000000000000000000000..16714ef30199eb1b2bdadda63f8633e9145e84ce --- /dev/null +++ b/mvsepless/vbach_lib/algorithm/discriminators.py @@ -0,0 +1,128 @@ + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils.parametrizations import spectral_norm, weight_norm + +from .commons import get_padding +from .residuals import LRELU_SLOPE + + +PERIODS_V1 = [2, 3, 5, 7, 11, 17] +PERIODS_V2 = [2, 3, 5, 7, 11, 17, 23, 37] +IN_CHANNELS = [1, 32, 128, 512, 1024] +OUT_CHANNELS = [32, 128, 512, 1024, 1024] + + +class MultiPeriodDiscriminator(nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList( + [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + + [DiscriminatorP(p, use_spectral_norm=use_spectral_norm) for p in PERIODS_V1] + ) + + def forward(self, y, y_hat): + y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], [] + for d in self.discriminators: + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class MultiPeriodDiscriminatorV2(nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminatorV2, self).__init__() + self.discriminators = nn.ModuleList( + [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + + [DiscriminatorP(p, use_spectral_norm=use_spectral_norm) for p in PERIODS_V2] + ) + + def forward(self, y, y_hat): + y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], [] + for d in self.discriminators: + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = spectral_norm if use_spectral_norm else weight_norm + self.convs = nn.ModuleList( + [ + norm_f(nn.Conv1d(1, 16, 15, 1, padding=7)), + norm_f(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1)) + self.lrelu = nn.LeakyReLU(LRELU_SLOPE) + + def forward(self, x): + fmap = [] + for conv in self.convs: + x = self.lrelu(conv(x)) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + return x, fmap + + +class DiscriminatorP(nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = spectral_norm if use_spectral_norm else weight_norm + + self.convs = nn.ModuleList( + [ + norm_f( + nn.Conv2d( + in_ch, + out_ch, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ) + for in_ch, out_ch in zip(IN_CHANNELS, OUT_CHANNELS) + ] + ) + + self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + self.lrelu = nn.LeakyReLU(LRELU_SLOPE) + + def forward(self, x): + fmap = [] + b, c, t = x.shape + if t % self.period != 0: + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + x = x.view(b, c, -1, self.period) + + for conv in self.convs: + x = self.lrelu(conv(x)) + fmap.append(x) + + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + return x, fmap + + \ No newline at end of file diff --git a/mvsepless/vbach_lib/algorithm/encoders.py b/mvsepless/vbach_lib/algorithm/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..61e273a6000aa15adf5890a6f2bdec2fd27fdfd5 --- /dev/null +++ b/mvsepless/vbach_lib/algorithm/encoders.py @@ -0,0 +1,183 @@ + +import math +import torch +from torch import nn +from torch.nn.utils.weight_norm import remove_weight_norm +from typing import Optional + +from .attentions import FFN, MultiHeadAttention +from .commons import sequence_mask +from .modules import WaveNet +from .normalization import LayerNorm + + +class Encoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + window_size=10, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + window_size=window_size, + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class TextEncoder(nn.Module): + def __init__( + self, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + embedding_dim, + f0=True, + ): + super(TextEncoder, self).__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = float(p_dropout) + self.emb_phone = nn.Linear(embedding_dim, hidden_channels) + self.lrelu = nn.LeakyReLU(0.1, inplace=True) + if f0: + self.emb_pitch = nn.Embedding(256, hidden_channels) + self.encoder = Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + float(p_dropout), + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward( + self, phone: torch.Tensor, pitch: Optional[torch.Tensor], lengths: torch.Tensor + ): + if pitch is None: + x = self.emb_phone(phone) + else: + x = self.emb_phone(phone) + self.emb_pitch(pitch) + x = x * math.sqrt(self.hidden_channels) + x = self.lrelu(x) + x = torch.transpose(x, 1, -1) + x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype) + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return m, logs, x_mask + + +class PosteriorEncoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + ): + super(PosteriorEncoder, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = WaveNet( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward( + self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None + ): + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + def remove_weight_norm(self): + self.enc.remove_weight_norm() + + def __prepare_scriptable__(self): + for hook in self.enc._forward_pre_hooks.values(): + if ( + hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" + and hook.__class__.__name__ == "_WeightNorm" + ): + remove_weight_norm(self.enc) + return self + + \ No newline at end of file diff --git a/mvsepless/vbach_lib/algorithm/generators.py b/mvsepless/vbach_lib/algorithm/generators.py new file mode 100644 index 0000000000000000000000000000000000000000..dfedba9a1a9e5efa0fa85adab8e5899464333d84 --- /dev/null +++ b/mvsepless/vbach_lib/algorithm/generators.py @@ -0,0 +1,159 @@ + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils.weight_norm import remove_weight_norm +from torch.nn.utils.parametrizations import weight_norm +from typing import Optional + +from .commons import init_weights +from .residuals import LRELU_SLOPE, ResBlock1, ResBlock2 + + +class Generator(nn.Module): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=0, + ): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = nn.Conv1d( + initial_channel, upsample_initial_channel, 7, 1, padding=3 + ) + resblock = ResBlock1 if resblock == "1" else ResBlock2 + + self.ups_and_resblocks = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups_and_resblocks.append( + weight_norm( + nn.ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.ups_and_resblocks.append(resblock(ch, k, d)) + + self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups_and_resblocks.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + resblock_idx = 0 + for _ in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups_and_resblocks[resblock_idx](x) + resblock_idx += 1 + xs = 0 + for _ in range(self.num_kernels): + xs += self.ups_and_resblocks[resblock_idx](x) + resblock_idx += 1 + x = xs / self.num_kernels + + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def __prepare_scriptable__(self): + for l in self.ups_and_resblocks: + for hook in l._forward_pre_hooks.values(): + if ( + hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" + and hook.__class__.__name__ == "_WeightNorm" + ): + remove_weight_norm(l) + return self + + def remove_weight_norm(self): + for l in self.ups_and_resblocks: + remove_weight_norm(l) + + +class SineGen(nn.Module): + def __init__( + self, + samp_rate, + harmonic_num=0, + sine_amp=0.1, + noise_std=0.003, + voiced_threshold=0, + flag_for_pulse=False, + ): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sample_rate = samp_rate + self.voiced_threshold = voiced_threshold + + def _f02uv(self, f0): + uv = torch.ones_like(f0) + uv = uv * (f0 > self.voiced_threshold) + return uv + + def forward(self, f0: torch.Tensor, upp: int): + with torch.no_grad(): + f0 = f0[:, None].transpose(1, 2) + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) + f0_buf[:, :, 0] = f0[:, :, 0] + f0_buf[:, :, 1:] = ( + f0_buf[:, :, 0:1] + * torch.arange(2, self.harmonic_num + 2, device=f0.device)[None, None, :] + ) + rad_values = (f0_buf / float(self.sample_rate)) % 1 + rand_ini = torch.rand(f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + tmp_over_one = torch.cumsum(rad_values, 1) + tmp_over_one *= upp + tmp_over_one = F.interpolate( + tmp_over_one.transpose(2, 1), + scale_factor=float(upp), + mode="linear", + align_corners=True, + ).transpose(2, 1) + rad_values = F.interpolate( + rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest" + ).transpose(2, 1) + tmp_over_one %= 1 + tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + sine_waves = torch.sin( + torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi + ) + sine_waves = sine_waves * self.sine_amp + uv = self._f02uv(f0) + uv = F.interpolate( + uv.transpose(2, 1), scale_factor=float(upp), mode="nearest" + ).transpose(2, 1) + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + \ No newline at end of file diff --git a/mvsepless/vbach_lib/algorithm/modules.py b/mvsepless/vbach_lib/algorithm/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..4717fa4c0a97179cfcf8d4235811bed1650386eb --- /dev/null +++ b/mvsepless/vbach_lib/algorithm/modules.py @@ -0,0 +1,95 @@ + +import torch +from torch import nn +from torch.nn.utils.weight_norm import remove_weight_norm +from torch.nn.utils.parametrizations import weight_norm + +from .commons import fused_add_tanh_sigmoid_multiply + + +class WaveNet(nn.Module): + def __init__( + self, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + p_dropout=0, + ): + super(WaveNet, self).__init__() + assert kernel_size % 2 == 1 + self.hidden_channels = hidden_channels + self.kernel_size = (kernel_size,) + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = nn.ModuleList() + self.res_skip_layers = nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + if gin_channels != 0: + cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) + self.cond_layer = weight_norm(cond_layer, name="weight") + + dilations = [dilation_rate**i for i in range(n_layers)] + paddings = [(kernel_size * d - d) // 2 for d in dilations] + + for i in range(n_layers): + in_layer = nn.Conv1d( + hidden_channels, + 2 * hidden_channels, + kernel_size, + dilation=dilations[i], + padding=paddings[i], + ) + in_layer = weight_norm(in_layer, name="weight") + self.in_layers.append(in_layer) + + res_skip_channels = ( + hidden_channels if i == n_layers - 1 else 2 * hidden_channels + ) + + res_skip_layer = nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = self.cond_layer(g) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + + acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, : self.hidden_channels, :] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:, self.hidden_channels :, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + remove_weight_norm(self.cond_layer) + for l in self.in_layers: + remove_weight_norm(l) + for l in self.res_skip_layers: + remove_weight_norm(l) + + \ No newline at end of file diff --git a/mvsepless/vbach_lib/algorithm/normalization.py b/mvsepless/vbach_lib/algorithm/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..0b22833a5a5d0b7b06b620c52f5ebf428d6560aa --- /dev/null +++ b/mvsepless/vbach_lib/algorithm/normalization.py @@ -0,0 +1,18 @@ + +import torch +from torch import nn +from torch.nn import functional as F + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.eps = eps + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (x.size(-1),), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + \ No newline at end of file diff --git a/mvsepless/vbach_lib/algorithm/nsf.py b/mvsepless/vbach_lib/algorithm/nsf.py new file mode 100644 index 0000000000000000000000000000000000000000..7a0bbdaff53e486465d494d5d42b932442f644e0 --- /dev/null +++ b/mvsepless/vbach_lib/algorithm/nsf.py @@ -0,0 +1,170 @@ + +import math +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils.weight_norm import remove_weight_norm +from torch.nn.utils.parametrizations import weight_norm +from typing import Optional + +from .commons import init_weights +from .generators import SineGen +from .residuals import LRELU_SLOPE, ResBlock1, ResBlock2 + + +class SourceModuleHnNSF(nn.Module): + def __init__( + self, + sample_rate, + harmonic_num=0, + sine_amp=0.1, + add_noise_std=0.003, + voiced_threshod=0, + is_half=True, + ): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + self.is_half = is_half + + self.l_sin_gen = SineGen( + sample_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod + ) + self.l_linear = nn.Linear(harmonic_num + 1, 1) + self.l_tanh = nn.Tanh() + + def forward(self, x: torch.Tensor, upsample_factor: int = 1): + sine_wavs, uv, _ = self.l_sin_gen(x, upsample_factor) + sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + return sine_merge, None, None + + +class GeneratorNSF(nn.Module): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels, + sr, + is_half=False, + ): + super(GeneratorNSF, self).__init__() + + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.f0_upsamp = nn.Upsample(scale_factor=math.prod(upsample_rates)) + self.m_source = SourceModuleHnNSF(sample_rate=sr, harmonic_num=0, is_half=is_half) + + self.conv_pre = nn.Conv1d( + initial_channel, upsample_initial_channel, 7, 1, padding=3 + ) + resblock_cls = ResBlock1 if resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + self.noise_convs = nn.ModuleList() + + channels = [ + upsample_initial_channel // (2 ** (i + 1)) for i in range(len(upsample_rates)) + ] + stride_f0s = [ + math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1 + for i in range(len(upsample_rates)) + ] + + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + nn.ConvTranspose1d( + upsample_initial_channel // (2**i), + channels[i], + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.noise_convs.append( + nn.Conv1d( + 1, + channels[i], + kernel_size=(stride_f0s[i] * 2 if stride_f0s[i] > 1 else 1), + stride=stride_f0s[i], + padding=(stride_f0s[i] // 2 if stride_f0s[i] > 1 else 0), + ) + ) + + self.resblocks = nn.ModuleList( + [ + resblock_cls(channels[i], k, d) + for i in range(len(self.ups)) + for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes) + ] + ) + + self.conv_post = nn.Conv1d(channels[-1], 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + self.upp = math.prod(upsample_rates) + self.lrelu_slope = LRELU_SLOPE + + def forward(self, x, f0, g: Optional[torch.Tensor] = None): + har_source, _, _ = self.m_source(f0, self.upp) + har_source = har_source.transpose(1, 2) + x = self.conv_pre(x) + + if g is not None: + x = x + self.cond(g) + + for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)): + x = F.leaky_relu(x, self.lrelu_slope) + x = ups(x) + x = x + noise_convs(har_source) + + xs = sum( + [ + resblock(x) + for j, resblock in enumerate(self.resblocks) + if j in range(i * self.num_kernels, (i + 1) * self.num_kernels) + ] + ) + x = xs / self.num_kernels + + x = F.leaky_relu(x) + x = torch.tanh(self.conv_post(x)) + return x + + def remove_weight_norm(self): + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + + def __prepare_scriptable__(self): + for l in self.ups: + for hook in l._forward_pre_hooks.values(): + if ( + hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" + and hook.__class__.__name__ == "_WeightNorm" + ): + remove_weight_norm(l) + for l in self.resblocks: + for hook in l._forward_pre_hooks.values(): + if ( + hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" + and hook.__class__.__name__ == "_WeightNorm" + ): + remove_weight_norm(l) + return self + + \ No newline at end of file diff --git a/mvsepless/vbach_lib/algorithm/residuals.py b/mvsepless/vbach_lib/algorithm/residuals.py new file mode 100644 index 0000000000000000000000000000000000000000..271a8c7b101fd49cd4bfd4092c647e50c8065e04 --- /dev/null +++ b/mvsepless/vbach_lib/algorithm/residuals.py @@ -0,0 +1,235 @@ + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils.weight_norm import remove_weight_norm +from torch.nn.utils.parametrizations import weight_norm +from typing import Optional + +from .commons import get_padding, init_weights +from .modules import WaveNet + + +LRELU_SLOPE = 0.1 + + +def create_conv1d_layer(channels, kernel_size, dilation): + return weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation, + padding=get_padding(kernel_size, dilation), + ) + ) + + +def apply_mask(tensor, mask): + return tensor * mask if mask is not None else tensor + + +class ResBlockBase(nn.Module): + def __init__(self, channels, kernel_size, dilations): + super(ResBlockBase, self).__init__() + self.convs1 = nn.ModuleList( + [create_conv1d_layer(channels, kernel_size, d) for d in dilations] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [create_conv1d_layer(channels, kernel_size, 1) for _ in dilations] + ) + self.convs2.apply(init_weights) + + def forward(self, x, x_mask=None): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = apply_mask(xt, x_mask) + xt = F.leaky_relu(c1(xt), LRELU_SLOPE) + xt = apply_mask(xt, x_mask) + xt = c2(xt) + x = xt + x + return apply_mask(x, x_mask) + + def remove_weight_norm(self): + for conv in self.convs1 + self.convs2: + remove_weight_norm(conv) + + +class ResBlock1(ResBlockBase): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__(channels, kernel_size, dilation) + + +class ResBlock2(ResBlockBase): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__(channels, kernel_size, dilation) + + +class Log(nn.Module): + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + +class Flip(nn.Module): + def forward(self, x, *args, reverse=False, **kwargs): + x = torch.flip(x, [1]) + if not reverse: + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x + + +class ElementwiseAffine(nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = channels + self.m = nn.Parameter(torch.zeros(channels, 1)) + self.logs = nn.Parameter(torch.zeros(channels, 1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1, 2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class ResidualCouplingBlock(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0, + ): + super(ResidualCouplingBlock, self).__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append( + ResidualCouplingLayer( + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + mean_only=True, + ) + ) + self.flows.append(Flip()) + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + reverse: bool = False, + ): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow.forward(x, x_mask, g=g, reverse=reverse) + return x + + def remove_weight_norm(self): + for i in range(self.n_flows): + self.flows[i * 2].remove_weight_norm() + + def __prepare_scriptable__(self): + for i in range(self.n_flows): + for hook in self.flows[i * 2]._forward_pre_hooks.values(): + if ( + hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" + and hook.__class__.__name__ == "_WeightNorm" + ): + remove_weight_norm(self.flows[i * 2]) + + return self + + +class ResidualCouplingLayer(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=0, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WaveNet( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=p_dropout, + gin_channels=gin_channels, + ) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x + + def remove_weight_norm(self): + self.enc.remove_weight_norm() + + \ No newline at end of file diff --git a/mvsepless/vbach_lib/algorithm/synthesizers.py b/mvsepless/vbach_lib/algorithm/synthesizers.py new file mode 100644 index 0000000000000000000000000000000000000000..7981ed59cba25e14166e44800ba6f10804c82613 --- /dev/null +++ b/mvsepless/vbach_lib/algorithm/synthesizers.py @@ -0,0 +1,191 @@ + +import torch +from torch import nn +from torch.nn.utils.weight_norm import remove_weight_norm +from typing import Optional + +from .commons import slice_segments, rand_slice_segments +from .encoders import TextEncoder, PosteriorEncoder +from .generators import Generator +from .nsf import GeneratorNSF +from .residuals import ResidualCouplingBlock + + +class Synthesizer(nn.Module): + def __init__( + self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + spk_embed_dim, + gin_channels, + sr, + use_f0, + input_dim=768, + **kwargs + ): + super(Synthesizer, self).__init__() + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = float(p_dropout) + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.gin_channels = gin_channels + self.spk_embed_dim = spk_embed_dim + self.use_f0 = use_f0 + + self.enc_p = TextEncoder( + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + float(p_dropout), + input_dim, + f0=use_f0, + ) + + if use_f0: + self.dec = GeneratorNSF( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + sr=sr, + is_half=kwargs["is_half"], + ) + else: + self.dec = Generator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + ) + + self.enc_q = PosteriorEncoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + self.flow = ResidualCouplingBlock( + inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels + ) + self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels) + + def remove_weight_norm(self): + self.dec.remove_weight_norm() + self.flow.remove_weight_norm() + self.enc_q.remove_weight_norm() + + def __prepare_scriptable__(self): + for hook in self.dec._forward_pre_hooks.values(): + if ( + hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" + and hook.__class__.__name__ == "_WeightNorm" + ): + remove_weight_norm(self.dec) + for hook in self.flow._forward_pre_hooks.values(): + if ( + hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" + and hook.__class__.__name__ == "_WeightNorm" + ): + remove_weight_norm(self.flow) + if hasattr(self, "enc_q"): + for hook in self.enc_q._forward_pre_hooks.values(): + if ( + hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" + and hook.__class__.__name__ == "_WeightNorm" + ): + remove_weight_norm(self.enc_q) + return self + + @torch.jit.ignore + def forward( + self, + phone: torch.Tensor, + phone_lengths: torch.Tensor, + pitch: Optional[torch.Tensor] = None, + pitchf: Optional[torch.Tensor] = None, + y: torch.Tensor = None, + y_lengths: torch.Tensor = None, + ds: Optional[torch.Tensor] = None, + ): + g = self.emb_g(ds).unsqueeze(-1) + m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) + if y is not None: + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) + z_p = self.flow(z, y_mask, g=g) + z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size) + if self.use_f0: + pitchf = slice_segments(pitchf, ids_slice, self.segment_size, 2) + o = self.dec(z_slice, pitchf, g=g) + else: + o = self.dec(z_slice, g=g) + return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) + else: + return None, None, x_mask, None, (None, None, m_p, logs_p, None, None) + + @torch.jit.export + def infer( + self, + phone: torch.Tensor, + phone_lengths: torch.Tensor, + pitch: Optional[torch.Tensor] = None, + nsff0: Optional[torch.Tensor] = None, + sid: torch.Tensor = None, + rate: Optional[torch.Tensor] = None, + ): + g = self.emb_g(sid).unsqueeze(-1) + m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) + z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask + if rate is not None: + assert isinstance(rate, torch.Tensor) + head = int(z_p.shape[2] * (1.0 - rate.item())) + z_p = z_p[:, :, head:] + x_mask = x_mask[:, :, head:] + if self.use_f0: + nsff0 = nsff0[:, head:] + if self.use_f0: + z = self.flow(z_p, x_mask, g=g, reverse=True) + o = self.dec(z * x_mask, nsff0, g=g) + else: + z = self.flow(z_p, x_mask, g=g, reverse=True) + o = self.dec(z * x_mask, g=g) + return o, x_mask, (z, z_p, m_p, logs_p) + + \ No newline at end of file diff --git a/mvsepless/vbach_lib/fairseq.py b/mvsepless/vbach_lib/fairseq.py new file mode 100644 index 0000000000000000000000000000000000000000..ada43e3318cb7c183bf4e9fb187cde42e22e5ab6 --- /dev/null +++ b/mvsepless/vbach_lib/fairseq.py @@ -0,0 +1,2030 @@ +import contextlib +import math +import re +import sys +import types +import uuid + +import numpy as np +import torch +import torch.nn.functional as F +from omegaconf import DictConfig, open_dict +from torch import nn + + +class Dictionary: + def __init__(self, *args, **kwargs): + pass + + +fairseq = types.ModuleType("fairseq") +fairseq_data = types.ModuleType("fairseq.data") +fairseq_data_dictionary = types.ModuleType("fairseq.data.dictionary") +fairseq_data_dictionary.Dictionary = Dictionary +fairseq.data = fairseq_data +fairseq_data.dictionary = fairseq_data_dictionary +sys.modules["fairseq"] = fairseq +sys.modules["fairseq.data"] = fairseq_data +sys.modules["fairseq.data.dictionary"] = fairseq_data_dictionary + + +def load_model(filename): + state = torch.load(filename, map_location="cpu", weights_only=False) + + model = HubertModel(HubertConfig(**state["cfg"]["model"]), num_classes=int(state["model"]["label_embs_concat"].shape[0])) + model.load_state_dict(state["model"], strict=False) + + return model + + +def softmax(x, dim, onnx_trace=False): + return F.softmax(x.float(), dim=dim) if onnx_trace else F.softmax(x, dim=dim, dtype=torch.float32) + + +def log_softmax(x, dim, onnx_trace=False): + return F.log_softmax(x.float(), dim=dim) if onnx_trace else F.log_softmax(x, dim=dim, dtype=torch.float32) + + +def eval_str_dict(x, type=dict): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + return x + + +def with_incremental_state(cls): + cls.__bases__ = (FairseqIncrementalState,) + tuple(b for b in cls.__bases__ if b != FairseqIncrementalState) + return cls + + +def quant_noise(module, p, block_size): + if p <= 0: + return module + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + is_conv = module.weight.ndim == 4 + if not is_conv: + assert module.weight.size(1) % block_size == 0 + elif module.kernel_size == (1, 1): + assert module.in_channels % block_size == 0 + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0 + + def _forward_pre_hook(mod, input): + if mod.training: + if not is_conv: + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + mask = torch.zeros(in_features // block_size * out_features, device=weight.device) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + else: + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + if mod.kernel_size == (1, 1): + mask = torch.zeros(int(in_channels // block_size * out_channels), device=weight.device) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device) + mask.bernoulli_(p) + mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + + mask = mask.to(torch.bool) + s = 1 / (1 - p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + + +class FairseqDropout(nn.Module): + def __init__(self, p, module_name=None): + super().__init__() + self.p = p + self.module_name = module_name + self.apply_during_inference = False + + def forward(self, x, inplace=False): + return ( + F.dropout(x, p=self.p, training=True, inplace=inplace) if self.p > 0 and (self.training or self.apply_during_inference) else x + ) + + def make_generation_fast_(self, name, retain_dropout=False, retain_dropout_modules=None, **kwargs): + if retain_dropout: + if retain_dropout_modules is None or self.module_name in retain_dropout_modules: + self.apply_during_inference = True + + +class FairseqIncrementalState: + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.init_incremental_state() + + def init_incremental_state(self): + self._incremental_state_id = str(uuid.uuid4()) + + def _get_full_incremental_state_key(self, key): + return f"{self._incremental_state_id}.{key}" + + def get_incremental_state(self, incremental_state, key): + full_key = self._get_full_incremental_state_key(key) + if incremental_state is None or full_key not in incremental_state: + return None + return incremental_state[full_key] + + def set_incremental_state(self, incremental_state, key, value): + if incremental_state is not None: + incremental_state[self._get_full_incremental_state_key(key)] = value + return incremental_state + + +class FairseqDecoder(nn.Module): + def __init__(self, dictionary): + super().__init__() + self.dictionary = dictionary + self.onnx_trace = False + self.adaptive_softmax = None + + def forward(self, prev_output_tokens, encoder_out=None, **kwargs): + x, extra = self.extract_features(prev_output_tokens, encoder_out=encoder_out, **kwargs) + return self.output_layer(x), extra + + def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs): + pass + + def output_layer(self, features, **kwargs): + pass + + def get_normalized_probs(self, net_output, log_probs, sample=None): + return self.get_normalized_probs_scriptable(net_output, log_probs, sample) + + def get_normalized_probs_scriptable(self, net_output, log_probs, sample=None): + if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None: + if sample is not None: + assert "target" in sample + target = sample["target"] + else: + target = None + out = self.adaptive_softmax.get_log_prob(net_output[0], target=target) + return out.exp_() if not log_probs else out + + logits = net_output[0] + return log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace) if log_probs else softmax(logits, dim=-1, onnx_trace=self.onnx_trace) + + def max_positions(self): + return 1e6 + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + +@with_incremental_state +class FairseqIncrementalDecoder(FairseqDecoder): + def __init__(self, dictionary): + super().__init__(dictionary) + + def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs): + pass + + def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs): + pass + + def reorder_incremental_state(self, incremental_state, new_order): + pass + + def reorder_incremental_state_scripting(self, incremental_state, new_order): + for module in self.modules(): + if hasattr(module, "reorder_incremental_state"): + result = module.reorder_incremental_state(incremental_state, new_order) + if result is not None: + incremental_state = result + + def set_beam_size(self, beam_size): + if getattr(self, "_beam_size", -1) != beam_size: + seen = set() + + def apply_set_beam_size(module): + if module != self and hasattr(module, "set_beam_size") and module not in seen: + seen.add(module) + module.set_beam_size(beam_size) + + self.apply(apply_set_beam_size) + self._beam_size = beam_size + + +class MultiheadAttention(FairseqIncrementalDecoder): + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + dictionary=None, + q_noise=0.0, + qn_block_size=8, + xformers_att_config=None, + xformers_blocksparse_layout=None, + xformers_blocksparse_blocksize=16, + ): + super().__init__(dictionary) + xformers_att_config = eval_str_dict(xformers_att_config) + self.use_xformers = xformers_att_config is not None + if self.use_xformers: + raise ImportError + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + self.num_heads = num_heads + self.dropout_module = FairseqDropout(dropout, module_name=self.__class__.__name__) + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim + self.scaling = self.head_dim**-0.5 + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + assert not self.self_attention or self.qkv_same_dim + self.k_proj = quant_noise(nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size) + self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size) + self.q_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size) + self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size) + if add_bias_kv: + self.bias_k, self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim)), nn.Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + self.add_zero_attn = add_zero_attn + self.beam_size = 1 + self.reset_parameters() + self.onnx_trace = False + self.skip_embed_dim_check = False + self.init_incremental_state() + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def reset_parameters(self): + if self.qkv_same_dim: + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + + def _get_reserve_head_index(self, num_heads_to_keep: int): + k_proj_heads_norm, q_proj_heads_norm, v_proj_heads_norm = [], [], [] + for i in range(self.num_heads): + start_idx = i * self.head_dim + end_idx = (i + 1) * self.head_dim + k_proj_heads_norm.append( + torch.sum(torch.abs(self.k_proj.weight[start_idx:end_idx])).tolist() + + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist(), + ) + q_proj_heads_norm.append( + torch.sum(torch.abs(self.q_proj.weight[start_idx:end_idx])).tolist() + + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist(), + ) + v_proj_heads_norm.append( + torch.sum(torch.abs(self.v_proj.weight[start_idx:end_idx])).tolist() + + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist(), + ) + + heads_norm = [] + for i in range(self.num_heads): + heads_norm.append(k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i]) + + sorted_head_index = sorted(range(self.num_heads), key=lambda k: heads_norm[k], reverse=True) + reserve_head_index = [] + for i in range(num_heads_to_keep): + reserve_head_index.append((sorted_head_index[i] * self.head_dim, (sorted_head_index[i] + 1) * self.head_dim)) + return reserve_head_index + + def _adaptive_prune_heads(self, reserve_head_index): + new_q_weight, new_q_bias, new_k_weight, new_k_bias, new_v_weight, new_v_bias, new_out_proj_weight = [], [], [], [], [], [], [] + for ele in reserve_head_index: + start_idx, end_idx = ele + new_q_weight.append(self.q_proj.weight[start_idx:end_idx]) + new_q_bias.append(self.q_proj.bias[start_idx:end_idx]) + new_k_weight.append(self.k_proj.weight[start_idx:end_idx]) + new_k_bias.append(self.k_proj.bias[start_idx:end_idx]) + new_v_weight.append(self.v_proj.weight[start_idx:end_idx]) + new_v_bias.append(self.v_proj.bias[start_idx:end_idx]) + new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx]) + new_q_weight = torch.cat(new_q_weight).detach() + new_k_weight = torch.cat(new_k_weight).detach() + new_v_weight = torch.cat(new_v_weight).detach() + new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach() + new_q_weight.requires_grad = True + new_k_weight.requires_grad = True + new_v_weight.requires_grad = True + new_out_proj_weight.requires_grad = True + new_q_bias = torch.cat(new_q_bias).detach() + new_q_bias.requires_grad = True + new_k_bias = torch.cat(new_k_bias).detach() + new_k_bias.requires_grad = True + new_v_bias = torch.cat(new_v_bias).detach() + new_v_bias.requires_grad = True + self.q_proj.weight = nn.Parameter(new_q_weight) + self.q_proj.bias = nn.Parameter(new_q_bias) + self.k_proj.weight = nn.Parameter(new_k_weight) + self.k_proj.bias = nn.Parameter(new_k_bias) + self.v_proj.weight = nn.Parameter(new_v_weight) + self.v_proj.bias = nn.Parameter(new_v_bias) + self.out_proj.weight = nn.Parameter(new_out_proj_weight) + self.num_heads = len(reserve_head_index) + self.embed_dim = self.head_dim * self.num_heads + self.q_proj.out_features = self.embed_dim + self.k_proj.out_features = self.embed_dim + self.v_proj.out_features = self.embed_dim + + def _set_skip_embed_dim_check(self): + self.skip_embed_dim_check = True + + def _pad_masks(self, key_padding_mask, attn_mask): + if attn_mask is not None: + shape = attn_mask.size()[:-1] + torch.Size([1]) + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1) + + if key_padding_mask is not None: + shape = key_padding_mask.size()[:-1] + torch.Size([1]) + key_padding_mask = torch.cat([key_padding_mask, key_padding_mask.new_zeros(shape)], dim=-1) + + return key_padding_mask, attn_mask + + def _add_bias(self, k, v, key_padding_mask, attn_mask, bsz): + assert self.bias_k is not None or self.bias_v is not None + key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask) + return torch.cat([k, self.bias_k.repeat(1, bsz, 1)]), torch.cat([v, self.bias_v.repeat(1, bsz, 1)]), key_padding_mask, attn_mask + + def _append_zero_attn(self, k, v, key_padding_mask, attn_mask): + zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:] + key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask) + return ( + torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=-2), + torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2), + key_padding_mask, + attn_mask, + ) + + def forward( + self, + query, + key, + value, + key_padding_mask=None, + incremental_state=None, + need_weights=True, + static_kv=False, + attn_mask=None, + before_softmax=False, + need_head_weights=False, + ): + if need_head_weights: + need_weights = True + is_tpu = query.device.type == "xla" + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + if not self.skip_embed_dim_check: + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert value is not None + assert src_len, key_bsz == value.shape[:2] + + if ( + not self.onnx_trace + and not is_tpu + and incremental_state is None + and not static_kv + and not torch.jit.is_scripting() + and not self.skip_embed_dim_check + ): + assert key is not None and value is not None + return F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training or self.dropout_module.apply_during_inference, + key_padding_mask.bool() if key_padding_mask is not None else None, + need_weights, + attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + if self.beam_size > 1 and bsz == key.size(1): + key = key.view(key.size(0), -1, self.beam_size, key.size(2))[:, :, 0, :] + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.view(-1, self.beam_size, key_padding_mask.size(1))[:, 0, :] + k = self.k_proj(key) + v = self.v_proj(key) + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + + q *= self.scaling + if self.bias_k is not None: + assert self.bias_v is not None + k, v, attn_mask, key_padding_mask = self._add_bias(k, v, attn_mask, key_padding_mask, bsz) + + q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + kv_bsz = bsz + if k is not None: + kv_bsz = k.size(1) + k = k.contiguous().view(-1, kv_bsz * self.num_heads, self.head_dim).transpose(0, 1) + + if v is not None: + v = v.contiguous().view(-1, kv_bsz * self.num_heads, self.head_dim).transpose(0, 1) + if saved_state is not None: + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + + kv_bsz = _prev_key.size(0) + prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim) + + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None or kv_bsz == _prev_value.size(0) + prev_value = _prev_value.view(kv_bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + + prev_key_padding_mask = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=kv_bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(kv_bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + + assert k is not None + assert k.size(1) == src_len + + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == kv_bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k, v, key_padding_mask, attn_mask = self._append_zero_attn(k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + + if self.encoder_decoder_attention and bsz != kv_bsz: + attn_weights = torch.einsum( + "bxhtd,bhsd->bxhts", + q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]), + k.view((kv_bsz, self.num_heads) + k.size()[1:]), + ) + attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:]) + else: + attn_weights = torch.bmm(q, k.transpose(1, 2)) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + if self.onnx_trace: + attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) + attn_weights += attn_mask + + if key_padding_mask is not None: + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = ( + attn_weights.view(kv_bsz, -1, self.num_heads, tgt_len, src_len).masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(torch.bool), + float("-inf"), + ) + if not is_tpu + else attn_weights.transpose(0, 2).masked_fill(key_padding_mask, float("-inf")).transpose(0, 2) + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v + attn_weights_float = softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + assert v is not None + attn = None + + if self.encoder_decoder_attention and bsz != kv_bsz: + attn = torch.einsum( + "bxhts,bhsd->bxhtd", + attn_probs.view((kv_bsz, -1, self.num_heads) + attn_probs.size()[1:]), + v.view((kv_bsz, self.num_heads) + v.size()[1:]), + ) + attn = attn.reshape((-1,) + attn.size()[-2:]) + else: + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + + attn = ( + attn.contiguous().view(tgt_len, bsz, self.embed_dim) + if self.onnx_trace and attn.size(1) == 1 + else attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim) + ) + attn = self.out_proj(attn) + attn_weights = None + + if need_weights: + attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0) + if not need_head_weights: + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights + + @staticmethod + def _append_prev_key_padding_mask(key_padding_mask, prev_key_padding_mask, batch_size, src_len, static_kv): + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1) + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros((batch_size, src_len - prev_key_padding_mask.size(1)), device=prev_key_padding_mask.device) + new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros((batch_size, src_len - key_padding_mask.size(1)), device=key_padding_mask.device) + new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + @torch.jit.export + def reorder_incremental_state(self, incremental_state, new_order): + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + for k in input_buffer.keys(): + input_buffer_k = input_buffer[k] + if input_buffer_k is not None: + if self.encoder_decoder_attention: + if input_buffer_k.size(0) * self.beam_size == new_order.size(0): + return incremental_state + if self.beam_size > 1: + input_buffer[k] = input_buffer_k.index_select(0, new_order.reshape(-1, self.beam_size)[:, 0] // self.beam_size) + else: + input_buffer[k] = input_buffer_k.index_select(0, new_order) + else: + input_buffer[k] = input_buffer_k.index_select(0, new_order) + incremental_state = self._set_input_buffer(incremental_state, input_buffer) + return incremental_state + + def set_beam_size(self, beam_size): + self.beam_size = beam_size + + def _get_input_buffer(self, incremental_state): + result = self.get_incremental_state(incremental_state, "attn_state") + return result if result is not None else {} + + def _set_input_buffer(self, incremental_state, buffer): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def upgrade_state_dict_named(self, state_dict, name): + prefix = name + "." if name != "" else "" + items_to_add, keys_to_remove = {}, [] + for k in state_dict.keys(): + if k.endswith(prefix + "in_proj_weight"): + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim] + items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim] + items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :] + keys_to_remove.append(k) + k_bias = prefix + "in_proj_bias" + if k_bias in state_dict.keys(): + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim] + items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim] + items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :] + keys_to_remove.append(prefix + "in_proj_bias") + + for k in keys_to_remove: + del state_dict[k] + + for key, value in items_to_add.items(): + state_dict[key] = value + + +def init_bert_params(module): + def normal_(data): + data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) + + +def make_conv_pos(e, k, g): + pos_conv = nn.Conv1d(e, e, kernel_size=k, padding=k // 2, groups=g) + dropout = 0 + nn.init.normal_(pos_conv.weight, mean=0, std=math.sqrt((4 * (1.0 - dropout)) / (k * e))) + nn.init.constant_(pos_conv.bias, 0) + return nn.Sequential(nn.utils.parametrizations.weight_norm(pos_conv, name="weight", dim=2), SamePad(k), nn.GELU()) + + +def is_xla_tensor(tensor): + return torch.is_tensor(tensor) and tensor.device.type == "xla" + + +def index_put(tensor, indices, value): + if is_xla_tensor(tensor): + for _ in range(indices.dim(), tensor.dim()): + indices = indices.unsqueeze(-1) + + if indices.size(-1) < tensor.size(-1): + indices = indices.expand_as(tensor) + tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices) + else: + tensor[indices] = value + + return tensor + + +def pad_to_multiple(x, multiple, dim=-1, value=0): + if x is None: + return None, 0 + tsz = x.size(dim) + m = tsz / multiple + remainder = math.ceil(m) * multiple - tsz + if m.is_integer(): + return x, 0 + return F.pad(x, (*((0,) * (-1 - dim) * 2), 0, remainder), value=value), remainder + + +def compute_mask_indices( + shape, + padding_mask, + mask_prob, + mask_length, + mask_type="static", + mask_other=0.0, + min_masks=0, + no_overlap=False, + min_space=0, + require_same_masks=True, + mask_dropout=0.0, + add_masks=False, + seed=None, + epoch=None, + indices=None, + idc_select_ver=1, + num_mask_ver=2, +): + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + if num_mask_ver == 1: + all_num_mask = max(min_masks, int(mask_prob * all_sz / float(mask_length) + np.random.rand())) + mask_idcs = [] + + for i in range(bsz): + seed_i = ( + int(hash((seed, epoch, indices[i].item())) % 1e6) if seed is not None and epoch is not None and indices is not None else None + ) + rng = np.random.default_rng(seed_i) + + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + assert sz >= 0, sz + else: + sz = all_sz + + if num_mask_ver == 1: + num_mask = ( + max(min_masks, int(mask_prob * sz / float(mask_length) + np.random.rand())) if padding_mask is not None else all_num_mask + ) + elif num_mask_ver == 2: + num_mask = max(min_masks, int(mask_prob * sz / float(mask_length) + rng.random())) + else: + raise ValueError + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = [max(1, int(round(x))) for x in rng.normal(mask_length, mask_other, size=num_mask)] + elif mask_type == "poisson": + lengths = [int(round(x)) for x in rng.poisson(mask_length, size=num_mask)] + else: + raise Exception + + if sum(lengths) == 0: + if mask_type == "static": + raise ValueError + lengths = [min(mask_length, sz - 1)] + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = rng.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter((e - s if e - s >= length + min_space else 0 for s, e in parts), np.int32) + l_sum = np.sum(lens) + if l_sum == 0: + break + s, e = parts.pop(rng.choice(len(parts), p=lens / np.sum(lens))) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + if idc_select_ver == 1: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + mask_idc = rng.choice(sz - min_len, num_mask, replace=False) + elif idc_select_ver == 2: + mask_idc = rng.choice(sz, num_mask, replace=False) + else: + raise ValueError + + mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]) + + mask_idc = np.unique(mask_idc[mask_idc < sz]) + if len(mask_idc) >= sz: + raise ValueError + mask_idcs.append(mask_idc) + + target_len = None + if require_same_masks: + target_len = max([len(m) for m in mask_idcs]) if add_masks else min([len(m) for m in mask_idcs]) + + for i, mask_idc in enumerate(mask_idcs): + if target_len is not None and len(mask_idc) > target_len: + mask_idc = rng.choice(mask_idc, target_len, replace=False) + mask[i, mask_idc] = True + + if target_len is not None and len(mask_idc) < target_len: + to_mask = rng.choice(np.flatnonzero(~mask[i]), target_len - len(mask_idc), replace=False) + mask[i, to_mask] = True + + if mask_dropout > 0: + masked = np.flatnonzero(mask[i]) + mask[i, rng.choice(masked, np.rint(len(masked) * mask_dropout).astype(int), replace=False)] = False + + return mask + + +def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True): + return nn.LayerNorm(normalized_shape, eps, elementwise_affine) + + +def prune_state_dict(state_dict, model_cfg): + arch = None + if model_cfg is not None: + arch = model_cfg._name if isinstance(model_cfg, DictConfig) else getattr(model_cfg, "arch", None) + if not model_cfg or arch is None or arch == "ptt_transformer": + return state_dict + encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None) + decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None) + if not encoder_layers_to_keep and not decoder_layers_to_keep: + return state_dict + + def create_pruning_pass(layers_to_keep, layer_name): + keep_layers = sorted(int(layer_string) for layer_string in layers_to_keep.split(",")) + mapping_dict = {} + for i in range(len(keep_layers)): + mapping_dict[str(keep_layers[i])] = str(i) + + return {"substitution_regex": re.compile(rf"^{layer_name}.*\.layers\.(\d+)"), "mapping_dict": mapping_dict} + + pruning_passes, new_state_dict = [], {} + if encoder_layers_to_keep: + pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder")) + if decoder_layers_to_keep: + pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder")) + + for layer_name in state_dict.keys(): + match = re.search(r"\.layers\.(\d+)\.", layer_name) + if not match: + new_state_dict[layer_name] = state_dict[layer_name] + continue + + original_layer_number = match.group(1) + for pruning_pass in pruning_passes: + if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass["substitution_regex"].search(layer_name): + substitution_match = pruning_pass["substitution_regex"].search(layer_name) + new_state_dict[ + ( + layer_name[: substitution_match.start(1)] + + pruning_pass["mapping_dict"][original_layer_number] + + layer_name[substitution_match.end(1) :] + ) + ] = state_dict[layer_name] + + with open_dict(model_cfg) if isinstance(model_cfg, DictConfig) else contextlib.ExitStack(): + if hasattr(model_cfg, "encoder_layers_to_keep"): + model_cfg.encoder_layers_to_keep = None + if hasattr(model_cfg, "decoder_layers_to_keep"): + model_cfg.decoder_layers_to_keep = None + + return new_state_dict + + +def relu_squared(x): + return F.relu(x).pow(2) + + +def get_activation_fn(activation): + def gelu(x): + return nn.functional.gelu(x.float()).type_as(x) + + def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + + if activation == "relu": + return F.relu + if activation == "relu_squared": + return relu_squared + if activation == "gelu": + return gelu + if activation == "gelu_fast" or activation == "gelu_accurate": + return gelu_accurate + if activation == "tanh": + return torch.tanh + if activation == "linear": + return lambda x: x + if activation == "swish": + return nn.SiLU + raise RuntimeError + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class TransformerSentenceEncoderLayer(nn.Module): + def __init__( + self, + embedding_dim=768, + ffn_embedding_dim=3072, + num_attention_heads=8, + dropout=0.1, + attention_dropout=0.1, + activation_dropout=0.1, + activation_fn="relu", + layer_norm_first=False, + ): + super().__init__() + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention(self.embedding_dim, num_attention_heads, dropout=attention_dropout, self_attention=True) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + self.layer_norm_first = layer_norm_first + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None): + residual = x + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + attn_mask=self_attn_mask, + need_weights=False, + ) + x = residual + self.dropout1(x) + residual = x + x = self.fc2(self.dropout2(self.activation_fn(self.fc1(self.final_layer_norm(x))))) + layer_result = x + x = residual + self.dropout3(x) + else: + x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, need_weights=False) + x = self.self_attn_layer_norm(residual + self.dropout1(x)) + residual = x + x = self.fc2(self.dropout2(self.activation_fn(self.fc1(x)))) + layer_result = x + x = self.final_layer_norm(residual + self.dropout3(x)) + + return x, (attn, layer_result) + + +class AdapterFast(nn.Module): + def __init__(self, adapter_num, input_dim, hidden_dim, act_fn): + super().__init__() + self.adapter_num = adapter_num + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.W_a = nn.Parameter(torch.empty(adapter_num, hidden_dim, input_dim)) + self.W_b = nn.Parameter(torch.empty(adapter_num, input_dim, hidden_dim)) + self.b_a = nn.Parameter(torch.empty(adapter_num, hidden_dim)) + self.b_b = nn.Parameter(torch.empty(adapter_num, input_dim)) + self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim)) + self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim)) + self.act_fn = nn.Identity() + if act_fn == "relu": + self.act_fn = nn.ReLU() + elif act_fn == "gelu": + self.act_fn = nn.GELU() + elif act_fn == "selu": + self.act_fn = nn.SELU() + else: + raise ValueError + self.input_dim = input_dim + self.reset_parameters() + + def reset_parameters(self): + for ii in range(self.adapter_num): + nn.init.kaiming_uniform_(self.W_a[ii], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.W_b[ii], a=math.sqrt(5)) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_a[ii]) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.b_a[ii], -bound, bound) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[ii]) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.b_b[ii], -bound, bound) + + nn.init.ones_(self.ln_W) + nn.init.zeros_(self.ln_b) + + def forward(self, x, adapter_id): + ii = adapter_id + return F.linear( + self.act_fn(F.linear(F.layer_norm(x, (self.input_dim,), self.ln_W[ii], self.ln_b[ii]), self.W_a[ii], self.b_a[ii])), + self.W_b[ii], + self.b_b[ii], + ) + + def extra_repr(self): + return f"adapter={self.adapter_num}, input_dim={self.input_dim}, hidden_dim={self.hidden_dim}" + + +class FeedForwardModule(nn.Module): + def __init__(self, input_feat, hidden_units, dropout1, dropout2, activation_fn="swish", bias=True): + super(FeedForwardModule, self).__init__() + self.layer_norm = LayerNorm(input_feat) + self.w_1 = nn.Linear(input_feat, hidden_units, bias=bias) + self.w_2 = nn.Linear(hidden_units, input_feat, bias=bias) + self.dropout1 = nn.Dropout(dropout1) + self.dropout2 = nn.Dropout(dropout2) + self.activation = get_activation_fn(activation_fn)(hidden_units) + + def forward(self, x): + return self.dropout2(self.w_2(self.dropout1(self.activation(self.w_1(self.layer_norm(x)))))) + + +class ConvolutionModule(nn.Module): + def __init__(self, embed_dim, channels, depthwise_kernel_size, dropout, activation_fn="swish", bias=False, export=False): + super(ConvolutionModule, self).__init__() + assert (depthwise_kernel_size - 1) % 2 == 0 + self.layer_norm = LayerNorm(embed_dim, export=export) + self.pointwise_conv1 = nn.Conv1d(embed_dim, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias) + self.glu = nn.GLU(dim=1) + self.depthwise_conv = nn.Conv1d( + channels, + channels, + depthwise_kernel_size, + stride=1, + padding=(depthwise_kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + self.batch_norm = nn.BatchNorm1d(channels) + self.activation = get_activation_fn(activation_fn)(channels) + self.pointwise_conv2 = nn.Conv1d(channels, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + return self.dropout( + self.pointwise_conv2( + self.activation(self.batch_norm(self.depthwise_conv(self.glu(self.pointwise_conv1(self.layer_norm(x).transpose(1, 2)))))), + ), + ).transpose(1, 2) + + +def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=x1.ndim - 1) + + +def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): + cos, sin = (cos[offset : q.shape[0] + offset, ...], sin[offset : q.shape[0] + offset, ...]) + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + + +class RotaryPositionalEmbedding(nn.Module): + def __init__(self, dim, base=10000, precision=torch.half): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self.seq_len_cached = 0 + self.cos_cached = torch.empty(self.seq_len_cached, 1, 1, dim) + self.sin_cached = torch.empty(self.seq_len_cached, 1, 1, dim) + self.precision = precision + + def forward(self, x, seq_len=0): + if seq_len > self.seq_len_cached: + self.seq_len_cached = seq_len + freqs = torch.einsum("i,j->ij", torch.arange(seq_len, device=x.device).type_as(self.inv_freq), self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.cos_cached = emb.cos().view(emb.size(0), 1, 1, emb.size(1)) + self.sin_cached = emb.sin().view(emb.size(0), 1, 1, emb.size(1)) + return self.cos_cached, self.sin_cached + + +class ESPNETMultiHeadedAttention(nn.Module): + def __init__(self, n_feat, n_head, dropout): + super(ESPNETMultiHeadedAttention, self).__init__() + assert n_feat % n_head == 0 + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.attn = None + self.dropout = nn.Dropout(p=dropout) + + def forward_qkv(self, query, key, value, **kwargs): + n_batch = query.size(0) + return ( + self.linear_q(query).view(n_batch, -1, self.h, self.d_k).transpose(1, 2), + self.linear_k(key).view(n_batch, -1, self.h, self.d_k).transpose(1, 2), + self.linear_v(value).view(n_batch, -1, self.h, self.d_k).transpose(1, 2), + ) + + def forward_attention(self, value, scores, mask): + n_batch = value.size(0) + if mask is not None: + scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(2).to(bool), float("-inf")) + self.attn = torch.softmax(scores, dim=-1) + else: + self.attn = torch.softmax(scores, dim=-1) + + return self.linear_out( + torch.matmul(self.dropout(self.attn), value).transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k), + ) + + def forward(self, query, key, value, key_padding_mask=None, **kwargs): + q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1)) + return self.forward_attention(v, torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None + + +class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention): + def __init__(self, n_feat, n_head, dropout, zero_triu=False): + super().__init__(n_feat, n_head, dropout) + self.zero_triu = zero_triu + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + self.pos_bias_u = nn.Parameter(torch.zeros(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.zeros(self.h, self.d_k)) + nn.init.xavier_uniform_(self.pos_bias_u) + nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x): + x = ( + torch.cat([torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype), x], dim=-1) + .view(*x.size()[:2], x.size(3) + 1, x.size(2))[:, :, 1:] + .view_as(x)[:, :, :, : x.size(-1) // 2 + 1] + ) + if self.zero_triu: + x = x * torch.tril(torch.ones((x.size(2), x.size(3)), device=x.device), x.size(3) - x.size(2))[None, None, :, :] + return x + + def forward(self, query, key, value, pos_emb, key_padding_mask=None, **kwargs): + pos_emb = pos_emb.transpose(0, 1) + q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1)) + q = q.transpose(1, 2) + + return ( + self.forward_attention( + v, + ( + torch.matmul((q + self.pos_bias_u).transpose(1, 2), k.transpose(-2, -1)) + + self.rel_shift( + torch.matmul( + (q + self.pos_bias_v).transpose(1, 2), + self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.h, self.d_k).transpose(1, 2).transpose(-2, -1), + ), + ) + ) + / math.sqrt(self.d_k), + key_padding_mask, + ).transpose(0, 1), + None, + ) + + +class RotaryPositionMultiHeadedAttention(ESPNETMultiHeadedAttention): + def __init__(self, n_feat, n_head, dropout, precision, rotary_emd_base=10000): + super().__init__(n_feat, n_head, dropout) + precision = torch.float + self.rotary_ndims = self.d_k + if precision == "fp16": + precision = torch.half + self.rotary_emb = RotaryPositionalEmbedding(self.rotary_ndims, base=rotary_emd_base, precision=precision) + + def forward(self, query, key, value, key_padding_mask=None, **kwargs): + T, B, C = value.size() + query = query.view(T, B, self.h, self.d_k) + key = key.view(T, B, self.h, self.d_k) + value = value.view(T, B, self.h, self.d_k) + cos, sin = self.rotary_emb(value, seq_len=T) + query, key = apply_rotary_pos_emb(query, key, cos, sin, offset=0) + query = query.view(T, B, self.h * self.d_k) + key = key.view(T, B, self.h * self.d_k) + value = value.view(T, B, self.h * self.d_k) + q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1)) + return self.forward_attention(v, torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None + + +class ConformerEncoderLayer(nn.Module): + def __init__( + self, + embed_dim, + ffn_embed_dim, + attention_heads, + dropout, + use_fp16, + depthwise_conv_kernel_size=31, + activation_fn="swish", + attn_type=None, + pos_enc_type="abs", + ): + self.pos_enc_type = pos_enc_type + super(ConformerEncoderLayer, self).__init__() + self.ffn1 = FeedForwardModule(embed_dim, ffn_embed_dim, dropout, dropout) + self.self_attn_layer_norm = LayerNorm(embed_dim, export=False) + self.self_attn_dropout = nn.Dropout(dropout) + if attn_type == "espnet": + if self.pos_enc_type == "rel_pos": + self.self_attn = RelPositionMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout) + elif self.pos_enc_type == "rope": + self.self_attn = RotaryPositionMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout, precision=use_fp16) + elif self.pos_enc_type == "abs": + self.self_attn = ESPNETMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout) + else: + raise Exception + else: + self.self_attn = MultiheadAttention(embed_dim, attention_heads, dropout=dropout) + self.conv_module = ConvolutionModule( + embed_dim=embed_dim, + channels=embed_dim, + depthwise_kernel_size=depthwise_conv_kernel_size, + dropout=dropout, + activation_fn=activation_fn, + ) + self.ffn2 = FeedForwardModule(embed_dim, ffn_embed_dim, dropout, dropout, activation_fn=activation_fn) + self.final_layer_norm = LayerNorm(embed_dim, export=False) + + def forward(self, x, encoder_padding_mask, position_emb=None): + residual = x + x = self.ffn1(x) * 0.5 + residual + residual = x + x = self.self_attn_layer_norm(x) + if self.pos_enc_type == "rel_pos": + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=encoder_padding_mask, + pos_emb=position_emb, + need_weights=False, + ) + else: + x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=False) + x = self.self_attn_dropout(x) + x = x + residual + residual = x + x = residual + self.conv_module(x.transpose(0, 1)).transpose(0, 1) + residual = x + x = self.ffn2(x) + layer_result = x + x = self.final_layer_norm(x * 0.5 + residual) + return x, (attn, layer_result) + + +class ConformerWav2Vec2EncoderLayer(ConformerEncoderLayer): + def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None, position_emb=None): + return super().forward(x, self_attn_padding_mask, position_emb) + + +class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer): + def __init__( + self, + embedding_dim=768, + ffn_embedding_dim=3072, + num_attention_heads=8, + dropout=0.1, + attention_dropout=0.1, + activation_dropout=0.1, + activation_fn="relu", + layer_norm_first=False, + adapter_num=201, + adapter_dim=64, + adapter_act_fn="relu", + ): + super().__init__( + embedding_dim=embedding_dim, + ffn_embedding_dim=ffn_embedding_dim, + num_attention_heads=num_attention_heads, + dropout=dropout, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + layer_norm_first=layer_norm_first, + ) + self.adapter_num = adapter_num + self.adapter_dim = adapter_dim + self.adapter_layer = AdapterFast(adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn) + + def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None, corpus_key=None): + x, (attn, layer_result) = super().forward( + x=x, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + att_args=att_args, + ) + assert corpus_key is not None + assert len(set(corpus_key)) == 1 + return x + self.adapter_layer(x, corpus_key[0]), (attn, layer_result) + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None, tranpose_dim=-2): + super().__init__() + self.deconstruct_idx = deconstruct_idx + self.tranpose_dim = tranpose_dim + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(self.tranpose_dim, -1) + + +class TransformerEncoder(nn.Module): + def build_encoder_layer(self, args, **kwargs): + if args.layer_type == "transformer": + layer = TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + ) + elif args.layer_type == "conformer": + layer = ConformerWav2Vec2EncoderLayer( + embed_dim=self.embedding_dim, + ffn_embed_dim=args.encoder_ffn_embed_dim, + attention_heads=args.encoder_attention_heads, + dropout=args.dropout, + depthwise_conv_kernel_size=args.depthwise_conv_kernel_size, + activation_fn="swish", + attn_type=args.attn_type, + use_fp16=args.fp16, + pos_enc_type="abs", + ) + elif args.layer_type == "trf_adp": + use_adp = False + if args.adp_trf_idx == "all" or kwargs.get("layer_idx") in list(range(*[int(g) for g in args.adp_trf_idx.split(":")])): + use_adp = True + + layer = ( + TransformerSentenceEncoderWithAdapterLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + adapter_num=args.adp_num, + adapter_dim=args.adp_dim, + adapter_act_fn=args.adp_act_fn, + ) + if use_adp + else TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + ) + ) + + return layer + + def __init__(self, args): + super().__init__() + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + self.required_seq_len_multiple = args.required_seq_len_multiple + pos_conv_depth = getattr(args, "pos_conv_depth", 1) + if pos_conv_depth > 1: + num_layers = args.pos_conv_depth + k = max(3, args.conv_pos // num_layers) + + def make_conv_block(e, k, g, l): + return nn.Sequential( + *[ + nn.Sequential( + nn.Conv1d(e, e, kernel_size=k, padding=k // 2, groups=g), + SamePad(k), + TransposeLast(), + LayerNorm(e, elementwise_affine=False), + TransposeLast(), + nn.GELU(), + ) + for _ in range(l) + ], + ) + + self.pos_conv = make_conv_block(self.embedding_dim, k, args.conv_pos_groups, num_layers) + else: + self.pos_conv = make_conv_pos(self.embedding_dim, args.conv_pos, args.conv_pos_groups) + + self.layers = nn.ModuleList([self.build_encoder_layer(args, layer_idx=ii) for ii in range(args.encoder_layers)]) + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None, layer=None, corpus_key=None): + x, layer_results = self.extract_features(x, padding_mask, layer, corpus_key=corpus_key) + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + return x, layer_results + + def extract_features(self, x, padding_mask=None, tgt_layer=None, min_layer=0, corpus_key=None): + if padding_mask is not None: + x = index_put(x, padding_mask, 0) + x = x + self.pos_conv(x.transpose(1, 2)).transpose(1, 2) + if not self.layer_norm_first: + x = self.layer_norm(x) + x, pad_length = pad_to_multiple(x, self.required_seq_len_multiple, dim=-2, value=0) + if pad_length > 0 and padding_mask is None: + padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool) + padding_mask[:, -pad_length:] = True + else: + padding_mask, _ = pad_to_multiple(padding_mask, self.required_seq_len_multiple, dim=-1, value=True) + x = F.dropout(x, p=self.dropout, training=self.training).transpose(0, 1) + layer_results = [] + r = None + + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() if self.layerdrop > 0 else 1 + if not self.training or (dropout_probability > self.layerdrop): + layer_check = layer + if (corpus_key is None) or (not isinstance(layer_check, (TransformerSentenceEncoderWithAdapterLayer))): + x, (z, lr) = layer(x, self_attn_padding_mask=padding_mask, need_weights=False) + else: + x, (z, lr) = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, corpus_key=corpus_key) + if i >= min_layer: + layer_results.append((x, z, lr)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + x = x.transpose(0, 1) + + if pad_length > 0: + x = x[:, :-pad_length] + + def undo_pad(a, b, c): + return (a[:-pad_length], b[:-pad_length] if b is not None else b, c[:-pad_length]) + + layer_results = [undo_pad(*u) for u in layer_results] + + return x, layer_results + + def max_positions(self): + return self.args.max_positions + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class ConvFeatureExtractionModel(nn.Module): + def __init__(self, conv_layers, dropout=0.0, mode="default", conv_bias=False): + super().__init__() + assert mode in {"default", "layer_norm"} + + def block(n_in, n_out, k, stride, is_layer_norm=False, is_group_norm=False, conv_bias=False): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert (is_layer_norm and is_group_norm) == False + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential(TransposeLast(), Fp32LayerNorm(dim, elementwise_affine=True), TransposeLast()), + nn.GELU(), + ) + if is_group_norm: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), Fp32GroupNorm(dim, dim, affine=True), nn.GELU()) + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ), + ) + in_d = dim + + def forward(self, x): + x = x.unsqueeze(1) + for conv in self.conv_layers: + x = conv(x) + + return x + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class BaseFairseqModel(nn.Module): + def __init__(self): + super().__init__() + self._is_generation_fast = False + + def get_targets(self, sample, net_output): + return sample["target"] + + def extract_features(self, *args, **kwargs): + return self(*args, **kwargs) + + def load_state_dict(self, state_dict, strict=True, model_cfg=None, args=None): + self.upgrade_state_dict(state_dict) + new_state_dict = prune_state_dict(state_dict, model_cfg) + return super().load_state_dict(new_state_dict, strict) + + def upgrade_state_dict(self, state_dict): + self.upgrade_state_dict_named(state_dict, "") + + def upgrade_state_dict_named(self, state_dict, name): + assert state_dict is not None + + def do_upgrade(m, prefix): + if len(prefix) > 0: + prefix += "." + for n, c in m.named_children(): + name = prefix + n + if hasattr(c, "upgrade_state_dict_named"): + c.upgrade_state_dict_named(state_dict, name) + elif hasattr(c, "upgrade_state_dict"): + c.upgrade_state_dict(state_dict) + do_upgrade(c, name) + + do_upgrade(self, name) + + def make_generation_fast_(self, **kwargs): + if self._is_generation_fast: + return + self._is_generation_fast = True + + def apply_remove_weight_norm(module): + try: + nn.utils.remove_weight_norm(module) + except (AttributeError, ValueError): + return + + self.apply(apply_remove_weight_norm) + + def apply_make_generation_fast_(module, prefix): + if len(prefix) > 0: + prefix += "." + + base_func = BaseFairseqModel.make_generation_fast_ + for n, m in module.named_modules(): + if m != self and hasattr(m, "make_generation_fast_") and m.make_generation_fast_.__func__ is not base_func: + m.make_generation_fast_(name=prefix + n, **kwargs) + + apply_make_generation_fast_(self, "") + self.eval() + + +class HubertConfig: + def __init__( + self, + _name=None, + label_rate=50, + encoder_layers_1=3, + logit_temp_ctr=0.1, + num_negatives=100, + cross_sample_negatives=0, + ctr_layers=[-6], + crop_seq_to_multiple=1, + extractor_mode="default", + encoder_layers=12, + encoder_embed_dim=768, + encoder_ffn_embed_dim=3072, + encoder_attention_heads=12, + activation_fn="gelu", + layer_type="transformer", + dropout=0.1, + attention_dropout=0.1, + activation_dropout=0.0, + encoder_layerdrop=0.0, + dropout_input=0.0, + dropout_features=0.0, + final_dim=0, + untie_final_proj=False, + layer_norm_first=False, + conv_feature_layers="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", + conv_bias=False, + logit_temp=0.1, + target_glu=False, + feature_grad_mult=1.0, + mask_length=10, + mask_prob=0.65, + mask_selection="static", + mask_other=0.0, + no_mask_overlap=False, + mask_min_space=1, + mask_channel_length=10, + mask_channel_prob=0.0, + mask_channel_selection="static", + mask_channel_other=0.0, + no_mask_channel_overlap=False, + mask_channel_min_space=1, + conv_pos=128, + conv_pos_groups=16, + conv_pos_batch_norm=False, + latent_temp=(2, 0.5, 0.999995), + skip_masked=False, + skip_nomask=False, + checkpoint_activations=False, + required_seq_len_multiple=2, + depthwise_conv_kernel_size=31, + attn_type="", + pos_enc_type="abs", + fp16=False, + ): + self._name = _name + self.label_rate = label_rate + self.encoder_layers_1 = encoder_layers_1 + self.logit_temp_ctr = logit_temp_ctr + self.num_negatives = num_negatives + self.cross_sample_negatives = cross_sample_negatives + self.ctr_layers = ctr_layers + self.crop_seq_to_multiple = crop_seq_to_multiple + self.extractor_mode = extractor_mode + self.encoder_layers = encoder_layers + self.encoder_embed_dim = encoder_embed_dim + self.encoder_ffn_embed_dim = encoder_ffn_embed_dim + self.encoder_attention_heads = encoder_attention_heads + self.activation_fn = activation_fn + self.layer_type = layer_type + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.encoder_layerdrop = encoder_layerdrop + self.dropout_input = dropout_input + self.dropout_features = dropout_features + self.final_dim = final_dim + self.untie_final_proj = untie_final_proj + self.layer_norm_first = layer_norm_first + self.conv_feature_layers = conv_feature_layers + self.conv_bias = conv_bias + self.logit_temp = logit_temp + self.target_glu = target_glu + self.feature_grad_mult = feature_grad_mult + self.mask_length = mask_length + self.mask_prob = mask_prob + self.mask_selection = mask_selection + self.mask_other = mask_other + self.no_mask_overlap = no_mask_overlap + self.mask_min_space = mask_min_space + self.mask_channel_length = mask_channel_length + self.mask_channel_prob = mask_channel_prob + self.mask_channel_selection = mask_channel_selection + self.mask_channel_other = mask_channel_other + self.no_mask_channel_overlap = no_mask_channel_overlap + self.mask_channel_min_space = mask_channel_min_space + self.conv_pos = conv_pos + self.conv_pos_groups = conv_pos_groups + self.conv_pos_batch_norm = conv_pos_batch_norm + self.latent_temp = latent_temp + self.skip_masked = skip_masked + self.skip_nomask = skip_nomask + self.checkpoint_activations = checkpoint_activations + self.required_seq_len_multiple = required_seq_len_multiple + self.depthwise_conv_kernel_size = depthwise_conv_kernel_size + self.attn_type = attn_type + self.pos_enc_type = pos_enc_type + self.fp16 = fp16 + + +class HubertModel(BaseFairseqModel): + def __init__(self, cfg, num_classes): + super().__init__() + feature_enc_layers = eval(cfg.conv_feature_layers) + self.embed = feature_enc_layers[-1][0] + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers]) + self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / 16000 + self.post_extract_proj = nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + self.feature_grad_mult = cfg.feature_grad_mult + self.logit_temp = cfg.logit_temp + self.skip_masked = cfg.skip_masked + self.skip_nomask = cfg.skip_nomask + final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim + self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_()) + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + self.target_glu = None + if cfg.target_glu: + self.target_glu = nn.Sequential(nn.Linear(final_dim, final_dim * 2), nn.GLU()) + self.untie_final_proj = cfg.untie_final_proj + self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim) + self.num_classes = [num_classes] + self.label_embs_concat = nn.Parameter(torch.FloatTensor(sum(self.num_classes), final_dim)) + nn.init.uniform_(self.label_embs_concat) + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + def apply_mask(self, x, padding_mask, target_list): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = torch.from_numpy( + compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ), + ).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + x[ + ( + torch.from_numpy( + compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ), + ) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + ] = 0 + return x, mask_indices + + def compute_nce(self, x, pos, negs): + neg_is_pos = (pos == negs).all(-1) + logits = torch.cosine_similarity(x.float(), torch.cat([pos.unsqueeze(0), negs], dim=0).float(), dim=-1).type_as(x) + logits /= self.logit_temp + if neg_is_pos.any(): + logits[1:][neg_is_pos] = float("-inf") + return logits.transpose(0, 1) + + def forward_features(self, source): + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + return features + + def forward_targets(self, features, target_list): + feat_tsz = features.size(2) + targ_tsz = min([t.size(1) for t in target_list]) + if self.feat2tar_ratio * feat_tsz > targ_tsz: + feat_tsz = int(targ_tsz / self.feat2tar_ratio) + features = features[..., :feat_tsz] + + return features, [t[:, (torch.arange(feat_tsz).float() * self.feat2tar_ratio).long()] for t in target_list] + + def forward_padding_mask(self, features, padding_mask): + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + return padding_mask.view(padding_mask.size(0), features.size(1), -1).all(-1) + + def forward(self, source, target_list=None, padding_mask=None, mask=True, features_only=False, output_layer=None): + features = self.forward_features(source) + if target_list is not None: + features, target_list = self.forward_targets(features, target_list) + features_pen = features.float().pow(2).mean() + features = self.layer_norm(features.transpose(1, 2)) + unmasked_features = features.clone() + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + features = self.dropout_input(features) + unmasked_features = self.dropout_features(unmasked_features) + if mask: + x, mask_indices = self.apply_mask(features, padding_mask, target_list) + else: + x, mask_indices = features, None + x, _ = self.encoder(x, padding_mask=padding_mask, layer=None if output_layer is None else output_layer - 1) + if features_only: + return {"x": x, "padding_mask": padding_mask, "features": features} + + def compute_pred(proj_x, target, label_embs): + y = torch.index_select(label_embs, 0, target.long()) + negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1) + if self.target_glu: + y = self.target_glu(y) + negs = self.target_glu(negs) + + return self.compute_nce(proj_x, y, negs) + + label_embs_list = self.label_embs_concat.split(self.num_classes, 0) + if not self.skip_masked: + masked_indices = torch.logical_and(~padding_mask, mask_indices) + proj_x_m = self.final_proj(x[masked_indices]) + logit_m_list = [ + compute_pred(proj_x_m, t[masked_indices], label_embs_list[i]) + for i, (proj_x_m, t) in enumerate( + zip( + proj_x_m.chunk(len(target_list), dim=-1) if self.untie_final_proj else [proj_x_m for _ in range(len(target_list))], + target_list, + strict=False, + ), + ) + ] + else: + logit_m_list = [None for _ in target_list] + + if not self.skip_nomask: + nomask_indices = torch.logical_and(~padding_mask, ~mask_indices) + proj_x_u = self.final_proj(x[nomask_indices]) + logit_u_list = [ + compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i]) + for i, (proj_x_u, t) in enumerate( + zip( + proj_x_u.chunk(len(target_list), dim=-1) if self.untie_final_proj else [proj_x_u for _ in range(len(target_list))], + target_list, + strict=False, + ), + ) + ] + else: + logit_u_list = [None for _ in target_list] + + return {"logit_m_list": logit_m_list, "logit_u_list": logit_u_list, "padding_mask": padding_mask, "features_pen": features_pen} + + def extract_features(self, source, padding_mask=None, mask=False, ret_conv=False, output_layer=None): + res = self.forward(source, padding_mask=padding_mask, mask=mask, features_only=True, output_layer=output_layer) + return res["features"] if ret_conv else res["x"], res["padding_mask"] + + def get_logits(self, net_output, is_masked=True): + return [x.float() for x in (net_output["logit_m_list"] if is_masked else net_output["logit_u_list"]) if x is not None] + + def get_targets(self, net_output, is_masked=True): + return [x.new_zeros(x.size(0), dtype=torch.long) for x in self.get_logits(net_output, is_masked)] + + def get_extra_losses(self, net_output): + extra_losses, names = [], [] + if "features_pen" in net_output: + extra_losses.append(net_output["features_pen"]) + names.append("features_pen") + + return extra_losses, names + + def remove_pretraining_modules(self): + self.target_glu = None + self.final_proj = None + + +# Добавьте в конец fairseq.py + +def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False): + """Упрощенная версия загрузки checkpoint для CPU""" + state = torch.load(path, map_location=torch.device("cpu"), weights_only=False) + return state + +def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1, state=None): + """ + Упрощенная версия загрузки модели, совместимая с infer.py + """ + if isinstance(filenames, str): + filenames = [filenames] + + ensemble = [] + for filename in filenames: + # Используем существующую функцию load_model + model = load_model(filename) + ensemble.append(model) + + # Возвращаем в формате, ожидаемом infer.py + return ensemble, None, None + +# Создаем алиас для обратной совместимости +load_model_ensemble = load_model_ensemble_and_task \ No newline at end of file diff --git a/mvsepless/vbach_lib/predictors/FCPE.py b/mvsepless/vbach_lib/predictors/FCPE.py new file mode 100644 index 0000000000000000000000000000000000000000..6ae0d956476f3c1d9d7967b05bca84a280d6bfe2 --- /dev/null +++ b/mvsepless/vbach_lib/predictors/FCPE.py @@ -0,0 +1,892 @@ + +from typing import Union + +import torch.nn.functional as F +import numpy as np +import torch +import torch.nn as nn +from torch.nn.utils.parametrizations import weight_norm +from torchaudio.transforms import Resample +import os +import librosa +import soundfile as sf +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +import math +from functools import partial + +from einops import rearrange, repeat +from local_attention import LocalAttention + +os.environ["LRU_CACHE_CAPACITY"] = "3" + + +def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False): + try: + data, sample_rate = sf.read(full_path, always_2d=True) + except Exception as error: + print(f"An error occurred loading {full_path}: {error}") + if return_empty_on_exception: + return [], sample_rate or target_sr or 48000 + else: + raise + + data = data[:, 0] if len(data.shape) > 1 else data + assert len(data) > 2 + + max_mag = ( + -np.iinfo(data.dtype).min + if np.issubdtype(data.dtype, np.integer) + else max(np.amax(data), -np.amin(data)) + ) + max_mag = ( + (2**31) + 1 if max_mag > (2**15) else ((2**15) + 1 if max_mag > 1.01 else 1.0) + ) + data = torch.FloatTensor(data.astype(np.float32)) / max_mag + + if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception: + return [], sample_rate or target_sr or 48000 + if target_sr is not None and sample_rate != target_sr: + data = torch.from_numpy( + librosa.core.resample(data.numpy(), orig_sr=sample_rate, target_sr=target_sr) + ) + sample_rate = target_sr + + return data, sample_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +class STFT: + def __init__( + self, + sr=22050, + n_mels=80, + n_fft=1024, + win_size=1024, + hop_length=256, + fmin=20, + fmax=11025, + clip_val=1e-5, + ): + self.target_sr = sr + self.n_mels = n_mels + self.n_fft = n_fft + self.win_size = win_size + self.hop_length = hop_length + self.fmin = fmin + self.fmax = fmax + self.clip_val = clip_val + self.mel_basis = {} + self.hann_window = {} + + def get_mel(self, y, keyshift=0, speed=1, center=False, train=False): + sample_rate = self.target_sr + n_mels = self.n_mels + n_fft = self.n_fft + win_size = self.win_size + hop_length = self.hop_length + fmin = self.fmin + fmax = self.fmax + clip_val = self.clip_val + + factor = 2 ** (keyshift / 12) + n_fft_new = int(np.round(n_fft * factor)) + win_size_new = int(np.round(win_size * factor)) + hop_length_new = int(np.round(hop_length * speed)) + + mel_basis = self.mel_basis if not train else {} + hann_window = self.hann_window if not train else {} + + mel_basis_key = str(fmax) + "_" + str(y.device) + if mel_basis_key not in mel_basis: + mel = librosa_mel_fn( + sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax + ) + mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device) + + keyshift_key = str(keyshift) + "_" + str(y.device) + if keyshift_key not in hann_window: + hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device) + + pad_left = (win_size_new - hop_length_new) // 2 + pad_right = max( + (win_size_new - hop_length_new + 1) // 2, + win_size_new - y.size(-1) - pad_left, + ) + mode = "reflect" if pad_right < y.size(-1) else "constant" + y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft_new, + hop_length=hop_length_new, + win_length=win_size_new, + window=hann_window[keyshift_key], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9)) + + if keyshift != 0: + size = n_fft // 2 + 1 + resize = spec.size(1) + spec = ( + F.pad(spec, (0, 0, 0, size - resize)) + if resize < size + else spec[:, :size, :] + ) + spec = spec * win_size / win_size_new + spec = torch.matmul(mel_basis[mel_basis_key], spec) + spec = dynamic_range_compression_torch(spec, clip_val=clip_val) + return spec + + def __call__(self, audiopath): + audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr) + spect = self.get_mel(audio.unsqueeze(0)).squeeze(0) + return spect + + +stft = STFT() + + +def softmax_kernel( + data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device=None +): + b, h, *_ = data.shape + + data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.0 + + ratio = projection_matrix.shape[0] ** -0.5 + projection = repeat(projection_matrix, "j d -> b h j d", b=b, h=h) + projection = projection.type_as(data) + data_dash = torch.einsum("...id,...jd->...ij", (data_normalizer * data), projection) + + diag_data = data**2 + diag_data = torch.sum(diag_data, dim=-1) + diag_data = (diag_data / 2.0) * (data_normalizer**2) + diag_data = diag_data.unsqueeze(dim=-1) + + if is_query: + data_dash = ratio * ( + torch.exp( + data_dash - diag_data - torch.max(data_dash, dim=-1, keepdim=True).values + ) + + eps + ) + else: + data_dash = ratio * (torch.exp(data_dash - diag_data + eps)) + + return data_dash.type_as(data) + + +def orthogonal_matrix_chunk(cols, qr_uniform_q=False, device=None): + unstructured_block = torch.randn((cols, cols), device=device) + q, r = torch.linalg.qr(unstructured_block.cpu(), mode="reduced") + q, r = map(lambda t: t.to(device), (q, r)) + + if qr_uniform_q: + d = torch.diag(r, 0) + q *= d.sign() + return q.t() + + +def exists(val): + return val is not None + + +def empty(tensor): + return tensor.numel() == 0 + + +def default(val, d): + return val if exists(val) else d + + +def cast_tuple(val): + return (val,) if not isinstance(val, tuple) else val + + +class PCmer(nn.Module): + def __init__( + self, + num_layers, + num_heads, + dim_model, + dim_keys, + dim_values, + residual_dropout, + attention_dropout, + ): + super().__init__() + self.num_layers = num_layers + self.num_heads = num_heads + self.dim_model = dim_model + self.dim_values = dim_values + self.dim_keys = dim_keys + self.residual_dropout = residual_dropout + self.attention_dropout = attention_dropout + + self._layers = nn.ModuleList([_EncoderLayer(self) for _ in range(num_layers)]) + + def forward(self, phone, mask=None): + for layer in self._layers: + phone = layer(phone, mask) + return phone + + +class _EncoderLayer(nn.Module): + def __init__(self, parent: PCmer): + super().__init__() + self.conformer = ConformerConvModule(parent.dim_model) + self.norm = nn.LayerNorm(parent.dim_model) + self.dropout = nn.Dropout(parent.residual_dropout) + self.attn = SelfAttention( + dim=parent.dim_model, heads=parent.num_heads, causal=False + ) + + def forward(self, phone, mask=None): + phone = phone + (self.attn(self.norm(phone), mask=mask)) + phone = phone + (self.conformer(phone)) + return phone + + +def calc_same_padding(kernel_size): + pad = kernel_size // 2 + return (pad, pad - (kernel_size + 1) % 2) + + +class Swish(nn.Module): + def forward(self, x): + return x * x.sigmoid() + + +class Transpose(nn.Module): + def __init__(self, dims): + super().__init__() + assert len(dims) == 2, "dims must be a tuple of two dimensions" + self.dims = dims + + def forward(self, x): + return x.transpose(*self.dims) + + +class GLU(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + out, gate = x.chunk(2, dim=self.dim) + return out * gate.sigmoid() + + +class DepthWiseConv1d(nn.Module): + def __init__(self, chan_in, chan_out, kernel_size, padding): + super().__init__() + self.padding = padding + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in) + + def forward(self, x): + x = F.pad(x, self.padding) + return self.conv(x) + + +class ConformerConvModule(nn.Module): + def __init__( + self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0 + ): + super().__init__() + + inner_dim = dim * expansion_factor + padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0) + + self.net = nn.Sequential( + nn.LayerNorm(dim), + Transpose((1, 2)), + nn.Conv1d(dim, inner_dim * 2, 1), + GLU(dim=1), + DepthWiseConv1d( + inner_dim, inner_dim, kernel_size=kernel_size, padding=padding + ), + Swish(), + nn.Conv1d(inner_dim, dim, 1), + Transpose((1, 2)), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +def linear_attention(q, k, v): + if v is None: + out = torch.einsum("...ed,...nd->...ne", k, q) + return out + else: + k_cumsum = k.sum(dim=-2) + D_inv = 1.0 / (torch.einsum("...nd,...d->...n", q, k_cumsum.type_as(q)) + 1e-8) + context = torch.einsum("...nd,...ne->...de", k, v) + out = torch.einsum("...de,...nd,...n->...ne", context, q, D_inv) + return out + + +def gaussian_orthogonal_random_matrix( + nb_rows, nb_columns, scaling=0, qr_uniform_q=False, device=None +): + nb_full_blocks = int(nb_rows / nb_columns) + block_list = [] + + for _ in range(nb_full_blocks): + q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device) + block_list.append(q) + + remaining_rows = nb_rows - nb_full_blocks * nb_columns + if remaining_rows > 0: + q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device) + block_list.append(q[:remaining_rows]) + + final_matrix = torch.cat(block_list) + + if scaling == 0: + multiplier = torch.randn((nb_rows, nb_columns), device=device).norm(dim=1) + elif scaling == 1: + multiplier = math.sqrt((float(nb_columns))) * torch.ones( + (nb_rows,), device=device + ) + else: + raise ValueError(f"Invalid scaling {scaling}") + + return torch.diag(multiplier) @ final_matrix + + +class FastAttention(nn.Module): + def __init__( + self, + dim_heads, + nb_features=None, + ortho_scaling=0, + causal=False, + generalized_attention=False, + kernel_fn=nn.ReLU(), + qr_uniform_q=False, + no_projection=False, + ): + super().__init__() + nb_features = default(nb_features, int(dim_heads * math.log(dim_heads))) + + self.dim_heads = dim_heads + self.nb_features = nb_features + self.ortho_scaling = ortho_scaling + + self.create_projection = partial( + gaussian_orthogonal_random_matrix, + nb_rows=self.nb_features, + nb_columns=dim_heads, + scaling=ortho_scaling, + qr_uniform_q=qr_uniform_q, + ) + projection_matrix = self.create_projection() + self.register_buffer("projection_matrix", projection_matrix) + + self.generalized_attention = generalized_attention + self.kernel_fn = kernel_fn + self.no_projection = no_projection + self.causal = causal + + @torch.no_grad() + def redraw_projection_matrix(self): + projections = self.create_projection() + self.projection_matrix.copy_(projections) + del projections + + def forward(self, q, k, v): + device = q.device + + if self.no_projection: + q = q.softmax(dim=-1) + k = torch.exp(k) if self.causal else k.softmax(dim=-2) + else: + create_kernel = partial( + softmax_kernel, projection_matrix=self.projection_matrix, device=device + ) + q = create_kernel(q, is_query=True) + k = create_kernel(k, is_query=False) + + attn_fn = linear_attention if not self.causal else self.causal_linear_fn + + if v is None: + out = attn_fn(q, k, None) + return out + else: + out = attn_fn(q, k, v) + return out + + +class SelfAttention(nn.Module): + def __init__( + self, + dim, + causal=False, + heads=8, + dim_head=64, + local_heads=0, + local_window_size=256, + nb_features=None, + feature_redraw_interval=1000, + generalized_attention=False, + kernel_fn=nn.ReLU(), + qr_uniform_q=False, + dropout=0.0, + no_projection=False, + ): + super().__init__() + assert dim % heads == 0, "dimension must be divisible by number of heads" + dim_head = default(dim_head, dim // heads) + inner_dim = dim_head * heads + self.fast_attention = FastAttention( + dim_head, + nb_features, + causal=causal, + generalized_attention=generalized_attention, + kernel_fn=kernel_fn, + qr_uniform_q=qr_uniform_q, + no_projection=no_projection, + ) + + self.heads = heads + self.global_heads = heads - local_heads + self.local_attn = ( + LocalAttention( + window_size=local_window_size, + causal=causal, + autopad=True, + dropout=dropout, + look_forward=int(not causal), + rel_pos_emb_config=(dim_head, local_heads), + ) + if local_heads > 0 + else None + ) + + self.to_q = nn.Linear(dim, inner_dim) + self.to_k = nn.Linear(dim, inner_dim) + self.to_v = nn.Linear(dim, inner_dim) + self.to_out = nn.Linear(inner_dim, dim) + self.dropout = nn.Dropout(dropout) + + @torch.no_grad() + def redraw_projection_matrix(self): + self.fast_attention.redraw_projection_matrix() + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + name=None, + inference=False, + **kwargs, + ): + _, _, _, h, gh = *x.shape, self.heads, self.global_heads + + cross_attend = exists(context) + context = default(context, x) + context_mask = default(context_mask, mask) if not cross_attend else context_mask + q, k, v = self.to_q(x), self.to_k(context), self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v)) + + attn_outs = [] + if not empty(q): + if exists(context_mask): + global_mask = context_mask[:, None, :, None] + v.masked_fill_(~global_mask, 0.0) + if cross_attend: + pass + else: + out = self.fast_attention(q, k, v) + attn_outs.append(out) + + if not empty(lq): + assert ( + not cross_attend + ), "local attention is not compatible with cross attention" + out = self.local_attn(lq, lk, lv, input_mask=mask) + attn_outs.append(out) + + out = torch.cat(attn_outs, dim=1) + out = rearrange(out, "b h n d -> b n (h d)") + out = self.to_out(out) + return self.dropout(out) + + +def l2_regularization(model, l2_alpha): + l2_loss = [] + for module in model.modules(): + if type(module) is nn.Conv2d: + l2_loss.append((module.weight**2).sum() / 2.0) + return l2_alpha * sum(l2_loss) + + +class FCPE(nn.Module): + def __init__( + self, + input_channel=128, + out_dims=360, + n_layers=12, + n_chans=512, + use_siren=False, + use_full=False, + loss_mse_scale=10, + loss_l2_regularization=False, + loss_l2_regularization_scale=1, + loss_grad1_mse=False, + loss_grad1_mse_scale=1, + f0_max=1975.5, + f0_min=32.70, + confidence=False, + threshold=0.05, + use_input_conv=True, + ): + super().__init__() + if use_siren is True: + raise ValueError("Siren is not supported yet.") + if use_full is True: + raise ValueError("Full model is not supported yet.") + + self.loss_mse_scale = loss_mse_scale if (loss_mse_scale is not None) else 10 + self.loss_l2_regularization = ( + loss_l2_regularization if (loss_l2_regularization is not None) else False + ) + self.loss_l2_regularization_scale = ( + loss_l2_regularization_scale + if (loss_l2_regularization_scale is not None) + else 1 + ) + self.loss_grad1_mse = loss_grad1_mse if (loss_grad1_mse is not None) else False + self.loss_grad1_mse_scale = ( + loss_grad1_mse_scale if (loss_grad1_mse_scale is not None) else 1 + ) + self.f0_max = f0_max if (f0_max is not None) else 1975.5 + self.f0_min = f0_min if (f0_min is not None) else 32.70 + self.confidence = confidence if (confidence is not None) else False + self.threshold = threshold if (threshold is not None) else 0.05 + self.use_input_conv = use_input_conv if (use_input_conv is not None) else True + + self.cent_table_b = torch.Tensor( + np.linspace( + self.f0_to_cent(torch.Tensor([f0_min]))[0], + self.f0_to_cent(torch.Tensor([f0_max]))[0], + out_dims, + ) + ) + self.register_buffer("cent_table", self.cent_table_b) + + _leaky = nn.LeakyReLU() + self.stack = nn.Sequential( + nn.Conv1d(input_channel, n_chans, 3, 1, 1), + nn.GroupNorm(4, n_chans), + _leaky, + nn.Conv1d(n_chans, n_chans, 3, 1, 1), + ) + + self.decoder = PCmer( + num_layers=n_layers, + num_heads=8, + dim_model=n_chans, + dim_keys=n_chans, + dim_values=n_chans, + residual_dropout=0.1, + attention_dropout=0.1, + ) + self.norm = nn.LayerNorm(n_chans) + + self.n_out = out_dims + self.dense_out = weight_norm(nn.Linear(n_chans, self.n_out)) + + def forward( + self, mel, infer=True, gt_f0=None, return_hz_f0=False, cdecoder="local_argmax" + ): + if cdecoder == "argmax": + self.cdecoder = self.cents_decoder + elif cdecoder == "local_argmax": + self.cdecoder = self.cents_local_decoder + + x = ( + self.stack(mel.transpose(1, 2)).transpose(1, 2) + if self.use_input_conv + else mel + ) + x = self.decoder(x) + x = self.norm(x) + x = self.dense_out(x) + x = torch.sigmoid(x) + + if not infer: + gt_cent_f0 = self.f0_to_cent(gt_f0) + gt_cent_f0 = self.gaussian_blurred_cent(gt_cent_f0) + loss_all = self.loss_mse_scale * F.binary_cross_entropy(x, gt_cent_f0) + if self.loss_l2_regularization: + loss_all = loss_all + l2_regularization( + model=self, l2_alpha=self.loss_l2_regularization_scale + ) + x = loss_all + if infer: + x = self.cdecoder(x) + x = self.cent_to_f0(x) + x = (1 + x / 700).log() if not return_hz_f0 else x + + return x + + def cents_decoder(self, y, mask=True): + B, N, _ = y.size() + ci = self.cent_table[None, None, :].expand(B, N, -1) + rtn = torch.sum(ci * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True) + if mask: + confident = torch.max(y, dim=-1, keepdim=True)[0] + confident_mask = torch.ones_like(confident) + confident_mask[confident <= self.threshold] = float("-INF") + rtn = rtn * confident_mask + return (rtn, confident) if self.confidence else rtn + + def cents_local_decoder(self, y, mask=True): + B, N, _ = y.size() + ci = self.cent_table[None, None, :].expand(B, N, -1) + confident, max_index = torch.max(y, dim=-1, keepdim=True) + local_argmax_index = torch.arange(0, 9).to(max_index.device) + (max_index - 4) + local_argmax_index = torch.clamp(local_argmax_index, 0, self.n_out - 1) + ci_l = torch.gather(ci, -1, local_argmax_index) + y_l = torch.gather(y, -1, local_argmax_index) + rtn = torch.sum(ci_l * y_l, dim=-1, keepdim=True) / torch.sum( + y_l, dim=-1, keepdim=True + ) + if mask: + confident_mask = torch.ones_like(confident) + confident_mask[confident <= self.threshold] = float("-INF") + rtn = rtn * confident_mask + return (rtn, confident) if self.confidence else rtn + + def cent_to_f0(self, cent): + return 10.0 * 2 ** (cent / 1200.0) + + def f0_to_cent(self, f0): + return 1200.0 * torch.log2(f0 / 10.0) + + def gaussian_blurred_cent(self, cents): + mask = (cents > 0.1) & (cents < (1200.0 * np.log2(self.f0_max / 10.0))) + B, N, _ = cents.size() + ci = self.cent_table[None, None, :].expand(B, N, -1) + return torch.exp(-torch.square(ci - cents) / 1250) * mask.float() + + +class FCPEInfer: + def __init__(self, model_path, device=None, dtype=torch.float32): + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = device + ckpt = torch.load(model_path, map_location=torch.device(self.device)) + self.args = DotDict(ckpt["config"]) + self.dtype = dtype + model = FCPE( + input_channel=self.args.model.input_channel, + out_dims=self.args.model.out_dims, + n_layers=self.args.model.n_layers, + n_chans=self.args.model.n_chans, + use_siren=self.args.model.use_siren, + use_full=self.args.model.use_full, + loss_mse_scale=self.args.loss.loss_mse_scale, + loss_l2_regularization=self.args.loss.loss_l2_regularization, + loss_l2_regularization_scale=self.args.loss.loss_l2_regularization_scale, + loss_grad1_mse=self.args.loss.loss_grad1_mse, + loss_grad1_mse_scale=self.args.loss.loss_grad1_mse_scale, + f0_max=self.args.model.f0_max, + f0_min=self.args.model.f0_min, + confidence=self.args.model.confidence, + ) + model.to(self.device).to(self.dtype) + model.load_state_dict(ckpt["model"]) + model.eval() + self.model = model + self.wav2mel = Wav2Mel(self.args, dtype=self.dtype, device=self.device) + + @torch.no_grad() + def __call__(self, audio, sr, threshold=0.05): + self.model.threshold = threshold + audio = audio[None, :] + mel = self.wav2mel(audio=audio, sample_rate=sr).to(self.dtype) + f0 = self.model(mel=mel, infer=True, return_hz_f0=True) + return f0 + + +class Wav2Mel: + def __init__(self, args, device=None, dtype=torch.float32): + self.sample_rate = args.mel.sampling_rate + self.hop_size = args.mel.hop_size + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = device + self.dtype = dtype + self.stft = STFT( + args.mel.sampling_rate, + args.mel.num_mels, + args.mel.n_fft, + args.mel.win_size, + args.mel.hop_size, + args.mel.fmin, + args.mel.fmax, + ) + self.resample_kernel = {} + + def extract_nvstft(self, audio, keyshift=0, train=False): + mel = self.stft.get_mel(audio, keyshift=keyshift, train=train).transpose(1, 2) + return mel + + def extract_mel(self, audio, sample_rate, keyshift=0, train=False): + audio = audio.to(self.dtype).to(self.device) + if sample_rate == self.sample_rate: + audio_res = audio + else: + key_str = str(sample_rate) + if key_str not in self.resample_kernel: + self.resample_kernel[key_str] = Resample( + sample_rate, self.sample_rate, lowpass_filter_width=128 + ) + self.resample_kernel[key_str] = ( + self.resample_kernel[key_str].to(self.dtype).to(self.device) + ) + audio_res = self.resample_kernel[key_str](audio) + + mel = self.extract_nvstft(audio_res, keyshift=keyshift, train=train) + n_frames = int(audio.shape[1] // self.hop_size) + 1 + mel = torch.cat((mel, mel[:, -1:, :]), 1) if n_frames > int(mel.shape[1]) else mel + mel = mel[:, :n_frames, :] if n_frames < int(mel.shape[1]) else mel + return mel + + def __call__(self, audio, sample_rate, keyshift=0, train=False): + return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train) + + +class DotDict(dict): + def __getattr__(*args): + val = dict.get(*args) + return DotDict(val) if type(val) is dict else val + + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + +class F0Predictor(object): + def compute_f0(self, wav, p_len): + pass + + def compute_f0_uv(self, wav, p_len): + pass + + +class FCPEF0Predictor(F0Predictor): + def __init__( + self, + model_path, + hop_length=512, + f0_min=50, + f0_max=1100, + dtype=torch.float32, + device=None, + sample_rate=44100, + threshold=0.05, + ): + self.fcpe = FCPEInfer(model_path, device=device, dtype=dtype) + self.hop_length = hop_length + self.f0_min = f0_min + self.f0_max = f0_max + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + self.threshold = threshold + self.sample_rate = sample_rate + self.dtype = dtype + self.name = "fcpe" + + def repeat_expand( + self, + content: Union[torch.Tensor, np.ndarray], + target_len: int, + mode: str = "nearest", + ): + ndim = content.ndim + content = ( + content[None, None] if ndim == 1 else content[None] if ndim == 2 else content + ) + assert content.ndim == 3 + is_np = isinstance(content, np.ndarray) + content = torch.from_numpy(content) if is_np else content + results = torch.nn.functional.interpolate(content, size=target_len, mode=mode) + results = results.numpy() if is_np else results + return results[0, 0] if ndim == 1 else results[0] if ndim == 2 else results + + def post_process(self, x, sample_rate, f0, pad_to): + f0 = ( + torch.from_numpy(f0).float().to(x.device) + if isinstance(f0, np.ndarray) + else f0 + ) + f0 = self.repeat_expand(f0, pad_to) if pad_to is not None else f0 + + vuv_vector = torch.zeros_like(f0) + vuv_vector[f0 > 0.0] = 1.0 + vuv_vector[f0 <= 0.0] = 0.0 + + nzindex = torch.nonzero(f0).squeeze() + f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy() + time_org = self.hop_length / sample_rate * nzindex.cpu().numpy() + time_frame = np.arange(pad_to) * self.hop_length / sample_rate + + vuv_vector = F.interpolate(vuv_vector[None, None, :], size=pad_to)[0][0] + + if f0.shape[0] <= 0: + return np.zeros(pad_to), vuv_vector.cpu().numpy() + if f0.shape[0] == 1: + return np.ones(pad_to) * f0[0], vuv_vector.cpu().numpy() + + f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1]) + return f0, vuv_vector.cpu().numpy() + + def compute_f0(self, wav, p_len=None): + x = torch.FloatTensor(wav).to(self.dtype).to(self.device) + p_len = x.shape[0] // self.hop_length if p_len is None else p_len + f0 = self.fcpe(x, sr=self.sample_rate, threshold=self.threshold)[0, :, 0] + if torch.all(f0 == 0): + return f0.cpu().numpy() if p_len is None else np.zeros(p_len), ( + f0.cpu().numpy() if p_len is None else np.zeros(p_len) + ) + return self.post_process(x, self.sample_rate, f0, p_len)[0] + + def compute_f0_uv(self, wav, p_len=None): + x = torch.FloatTensor(wav).to(self.dtype).to(self.device) + p_len = x.shape[0] // self.hop_length if p_len is None else p_len + f0 = self.fcpe(x, sr=self.sample_rate, threshold=self.threshold)[0, :, 0] + if torch.all(f0 == 0): + return f0.cpu().numpy() if p_len is None else np.zeros(p_len), ( + f0.cpu().numpy() if p_len is None else np.zeros(p_len) + ) + return self.post_process(x, self.sample_rate, f0, p_len) + diff --git a/mvsepless/vbach_lib/predictors/RMVPE.py b/mvsepless/vbach_lib/predictors/RMVPE.py new file mode 100644 index 0000000000000000000000000000000000000000..88716914ad4754dda0a99bab6ac66bb0cb49ff92 --- /dev/null +++ b/mvsepless/vbach_lib/predictors/RMVPE.py @@ -0,0 +1,518 @@ + +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from librosa.filters import mel +from scipy.signal import get_window +from librosa.util import pad_center, tiny, normalize + + +def window_sumsquare( + window, + n_frames, + hop_length=200, + win_length=800, + n_fft=800, + dtype=np.float32, + norm=None, +): + if win_length is None: + win_length = n_fft + + n = n_fft + hop_length * (n_frames - 1) + x = np.zeros(n, dtype=dtype) + + win_sq = get_window(window, win_length, fftbins=True) + win_sq = normalize(win_sq, norm=norm) ** 2 + win_sq = pad_center(win_sq, n_fft) + + for i in range(n_frames): + sample = i * hop_length + x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] + return x + + +class STFT(nn.Module): + def __init__( + self, filter_length=1024, hop_length=512, win_length=None, window="hann" + ): + super(STFT, self).__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length if win_length else filter_length + self.window = window + self.pad_amount = int(self.filter_length / 2) + scale = self.filter_length / self.hop_length + fourier_basis = np.fft.fft(np.eye(self.filter_length)) + + cutoff = int((self.filter_length / 2 + 1)) + fourier_basis = np.vstack( + [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] + ) + forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) + inverse_basis = torch.FloatTensor( + np.linalg.pinv(scale * fourier_basis).T[:, None, :] + ) + + assert filter_length >= self.win_length + fft_window = get_window(window, self.win_length, fftbins=True) + fft_window = pad_center(fft_window, size=filter_length) + fft_window = torch.from_numpy(fft_window).float() + + forward_basis *= fft_window + inverse_basis *= fft_window + + self.register_buffer("forward_basis", forward_basis.float()) + self.register_buffer("inverse_basis", inverse_basis.float()) + + def transform(self, input_data): + num_batches = input_data.shape[0] + num_samples = input_data.shape[-1] + + input_data = input_data.view(num_batches, 1, num_samples) + input_data = F.pad( + input_data.unsqueeze(1), + (self.pad_amount, self.pad_amount, 0, 0, 0, 0), + mode="reflect", + ).squeeze(1) + forward_transform = F.conv1d( + input_data, self.forward_basis, stride=self.hop_length, padding=0 + ) + + cutoff = int((self.filter_length / 2) + 1) + real_part = forward_transform[:, :cutoff, :] + imag_part = forward_transform[:, cutoff:, :] + return torch.sqrt(real_part**2 + imag_part**2) + + def inverse(self, magnitude, phase): + recombine_magnitude_phase = torch.cat( + [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 + ) + inverse_transform = F.conv_transpose1d( + recombine_magnitude_phase, + self.inverse_basis, + stride=self.hop_length, + padding=0, + ) + + if self.window is not None: + window_sum = window_sumsquare( + self.window, + magnitude.size(-1), + hop_length=self.hop_length, + win_length=self.win_length, + n_fft=self.filter_length, + dtype=np.float32, + ) + approx_nonzero_indices = torch.from_numpy( + np.where(window_sum > tiny(window_sum))[0] + ) + window_sum = torch.from_numpy(window_sum).to(inverse_transform.device) + inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ + approx_nonzero_indices + ] + inverse_transform *= float(self.filter_length) / self.hop_length + + inverse_transform = inverse_transform[..., self.pad_amount :] + inverse_transform = inverse_transform[..., : self.num_samples] + return inverse_transform.squeeze(1) + + def forward(self, input_data): + self.magnitude, self.phase = self.transform(input_data) + return self.inverse(self.magnitude, self.phase) + + +class BiGRU(nn.Module): + def __init__(self, input_features, hidden_features, num_layers): + super(BiGRU, self).__init__() + self.gru = nn.GRU( + input_features, + hidden_features, + num_layers=num_layers, + batch_first=True, + bidirectional=True, + ) + + def forward(self, x): + return self.gru(x)[0] + + +class ConvBlockRes(nn.Module): + def __init__(self, in_channels, out_channels, momentum=0.01): + super(ConvBlockRes, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False, + ), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False, + ), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + ) + self.shortcut = ( + nn.Conv2d(in_channels, out_channels, (1, 1)) + if in_channels != out_channels + else None + ) + + def forward(self, x): + out = self.conv(x) + if self.shortcut is not None: + x = self.shortcut(x) + return out + x + + +class ResEncoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01): + super(ResEncoderBlock, self).__init__() + self.conv = nn.ModuleList( + [ + ConvBlockRes( + in_channels if i == 0 else out_channels, out_channels, momentum + ) + for i in range(n_blocks) + ] + ) + self.pool = ( + nn.AvgPool2d(kernel_size=kernel_size) if kernel_size is not None else None + ) + + def forward(self, x): + for conv in self.conv: + x = conv(x) + pooled = self.pool(x) if self.pool is not None else x + return pooled, x + + +class Encoder(nn.Module): + def __init__( + self, + in_channels, + in_size, + n_encoders, + kernel_size, + n_blocks, + out_channels=16, + momentum=0.01, + ): + super(Encoder, self).__init__() + self.bn = nn.BatchNorm2d(in_channels, momentum=momentum) + self.layers = nn.ModuleList() + self.latent_channels = [] + for _ in range(n_encoders): + self.layers.append( + ResEncoderBlock( + in_channels, out_channels, kernel_size, n_blocks, momentum=momentum + ) + ) + self.latent_channels.append([out_channels, in_size]) + in_channels = out_channels + out_channels *= 2 + in_size //= 2 + self.out_size = in_size + self.out_channel = out_channels + + def forward(self, x): + concat_tensors = [] + x = self.bn(x) + for layer in self.layers: + x, pooled = layer(x) + concat_tensors.append(pooled) + return x, concat_tensors + + +class Intermediate(nn.Module): + def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01): + super(Intermediate, self).__init__() + self.layers = nn.ModuleList( + [ + ResEncoderBlock( + in_channels if i == 0 else out_channels, + out_channels, + None, + n_blocks, + momentum, + ) + for i in range(n_inters) + ] + ) + + def forward(self, x): + for layer in self.layers: + _, x = layer(x) + return x + + +class ResDecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01): + super(ResDecoderBlock, self).__init__() + out_padding = (0, 1) if stride == (1, 2) else (1, 1) + self.conv1 = nn.Sequential( + nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=stride, + padding=(1, 1), + output_padding=out_padding, + bias=False, + ), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + ) + self.conv2 = nn.ModuleList( + [ + ConvBlockRes( + out_channels * 2 if i == 0 else out_channels, out_channels, momentum + ) + for i in range(n_blocks) + ] + ) + + def forward(self, x, concat_tensor): + x = self.conv1(x) + x = torch.cat((x, concat_tensor), dim=1) + for conv in self.conv2: + x = conv(x) + return x + + +class Decoder(nn.Module): + def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01): + super(Decoder, self).__init__() + self.layers = nn.ModuleList() + for _ in range(n_decoders): + out_channels = in_channels // 2 + self.layers.append( + ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum) + ) + in_channels = out_channels + + def forward(self, x, concat_tensors): + for layer, concat_tensor in zip(self.layers, reversed(concat_tensors)): + x = layer(x, concat_tensor) + return x + + +class DeepUnet(nn.Module): + def __init__( + self, + kernel_size, + n_blocks, + en_de_layers=5, + inter_layers=4, + in_channels=1, + en_out_channels=16, + ): + super(DeepUnet, self).__init__() + self.encoder = Encoder( + in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels + ) + self.intermediate = Intermediate( + self.encoder.out_channel // 2, + self.encoder.out_channel, + inter_layers, + n_blocks, + ) + self.decoder = Decoder( + self.encoder.out_channel, en_de_layers, kernel_size, n_blocks + ) + + def forward(self, x): + x, concat_tensors = self.encoder(x) + x = self.intermediate(x) + return self.decoder(x, concat_tensors) + + +class E2E(nn.Module): + def __init__( + self, + n_blocks, + n_gru, + kernel_size, + en_de_layers=5, + inter_layers=4, + in_channels=1, + en_out_channels=16, + ): + super(E2E, self).__init__() + self.unet = DeepUnet( + kernel_size, + n_blocks, + en_de_layers, + inter_layers, + in_channels, + en_out_channels, + ) + self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) + if n_gru: + self.fc = nn.Sequential( + BiGRU(3 * 128, 256, n_gru), + nn.Linear(512, 360), + nn.Dropout(0.25), + nn.Sigmoid(), + ) + else: + self.fc = nn.Sequential( + nn.Linear(3 * nn.N_MELS, nn.N_CLASS), nn.Dropout(0.25), nn.Sigmoid() + ) + + def forward(self, mel): + mel = mel.transpose(-1, -2).unsqueeze(1) + x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) + return self.fc(x) + + +class MelSpectrogram(nn.Module): + def __init__( + self, + is_half, + n_mel_channels, + sample_rate, + win_length, + hop_length, + n_fft=None, + mel_fmin=0, + mel_fmax=None, + clamp=1e-5, + ): + super(MelSpectrogram, self).__init__() + n_fft = win_length if n_fft is None else n_fft + self.hann_window = {} + mel_basis = mel( + sr=sample_rate, + n_fft=n_fft, + n_mels=n_mel_channels, + fmin=mel_fmin, + fmax=mel_fmax, + htk=True, + ) + self.register_buffer("mel_basis", torch.from_numpy(mel_basis).float()) + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.sample_rate = sample_rate + self.n_mel_channels = n_mel_channels + self.clamp = clamp + self.is_half = is_half + + def forward(self, audio, keyshift=0, speed=1, center=True): + factor = 2 ** (keyshift / 12) + n_fft_new = int(np.round(self.n_fft * factor)) + win_length_new = int(np.round(self.win_length * factor)) + hop_length_new = int(np.round(self.hop_length * speed)) + keyshift_key = f"{keyshift}_{audio.device}" + if keyshift_key not in self.hann_window: + self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to( + audio.device + ) + if not hasattr(self, "stft"): + self.stft = STFT( + filter_length=n_fft_new, + hop_length=hop_length_new, + win_length=win_length_new, + window="hann", + ).to(audio.device) + magnitude = self.stft.transform(audio) + if keyshift != 0: + size = self.n_fft // 2 + 1 + resize = magnitude.size(1) + if resize < size: + magnitude = F.pad(magnitude, (0, 0, 0, size - resize)) + magnitude = magnitude[:, :size, :] * self.win_length / win_length_new + mel_output = torch.matmul(self.mel_basis, magnitude) + if self.is_half: + mel_output = mel_output.half() + return torch.log(torch.clamp(mel_output, min=self.clamp)) + + +class RMVPE0Predictor: + def __init__(self, model_path, is_half, device=None): + self.resample_kernel = {} + self.is_half = is_half + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = device + self.mel_extractor = MelSpectrogram( + is_half, 128, 16000, 1024, 160, None, 30, 8000 + ).to(device) + model = E2E(4, 1, (2, 2)) + ckpt = torch.load(model_path, map_location="cpu", weights_only=True) + model.load_state_dict(ckpt) + model.eval() + if is_half: + model = model.half() + self.model = model.to(device) + self.cents_mapping = np.pad(20 * np.arange(360) + 1997.3794084376191, (4, 4)) + + def mel2hidden(self, mel): + with torch.no_grad(): + n_frames = mel.shape[-1] + mel = mel.float() + padding = min(32 * ((n_frames - 1) // 32 + 1) - n_frames, n_frames) + mel = F.pad(mel, (0, padding), mode="reflect") + if self.is_half: + mel = mel.half() + hidden = self.model(mel) + return hidden[:, :n_frames] + + def decode(self, hidden, thred=0.03): + cents_pred = self.to_local_average_cents(hidden, thred=thred) + f0 = 10 * (2 ** (cents_pred / 1200)) + f0[f0 == 10] = 0 + return f0 + + def infer_from_audio(self, audio, thred=0.03): + audio = torch.from_numpy(audio).float().to(self.device).unsqueeze(0) + mel = self.mel_extractor(audio, center=True) + hidden = self.mel2hidden(mel) + hidden = hidden.squeeze(0).cpu().numpy() + if self.is_half: + hidden = hidden.astype("float32") + return self.decode(hidden, thred=thred) + + def infer_from_audio_with_pitch(self, audio, thred=0.03, f0_min=50, f0_max=1100): + audio = torch.from_numpy(audio).float().to(self.device).unsqueeze(0) + mel = self.mel_extractor(audio, center=True) + hidden = self.mel2hidden(mel) + hidden = hidden.squeeze(0).cpu().numpy() + if self.is_half: + hidden = hidden.astype("float32") + f0 = self.decode(hidden, thred=thred) + f0[(f0 < f0_min) | (f0 > f0_max)] = 0 + return f0 + + def to_local_average_cents(self, salience, thred=0.05): + center = np.argmax(salience, axis=1) + salience = np.pad(salience, ((0, 0), (4, 4))) + center += 4 + todo_salience = [] + todo_cents_mapping = [] + starts = center - 4 + ends = center + 5 + for idx in range(salience.shape[0]): + todo_salience.append(salience[:, starts[idx] : ends[idx]][idx]) + todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]]) + todo_salience = np.array(todo_salience) + todo_cents_mapping = np.array(todo_cents_mapping) + product_sum = np.sum(todo_salience * todo_cents_mapping, 1) + weight_sum = np.sum(todo_salience, 1) + divided = product_sum / weight_sum + maxx = np.max(salience, axis=1) + divided[maxx <= thred] = 0 + return divided +