| import sys
|
| import os
|
| import shutil
|
| import yaml
|
| import time
|
| import traceback
|
| import gc
|
| gc.enable()
|
| import librosa
|
| from pathlib import Path
|
| BASE_DIR = Path(__file__).resolve().parent
|
| sys.path.append(str(BASE_DIR))
|
|
|
| from extra_utils import hf_spaces_gpu, dw_file, extra_clear_torch_cache, nuclear_clear_model, emergency_ram_clear
|
| import torch
|
| import rich
|
| import rich.console
|
| import rich.table
|
| nn = torch.nn
|
| import json
|
| from tqdm import tqdm
|
| import numpy as np
|
| from typing import Literal, Optional, List, Tuple, Any, Dict
|
| from ml_collections import ConfigDict
|
| from omegaconf import OmegaConf
|
| import gradio as gr
|
| from audio import read, write, output_formats, subtractor, check, easy_resampler, ensemble_types, ensemble, multiread, get_audio_files_from_list, stereo_to_mono
|
| from args_parser import parse_separator_args, tobool
|
| from namer import Namer
|
| from i18n import _i18n
|
| import contextlib
|
|
|
| class PathNotExist(Exception): pass
|
| class PathsNotExist(Exception): pass
|
| class PathNotSpecified(Exception): pass
|
| class PathsNotSpecified(Exception): pass
|
| class FileIsNotAudio(Exception): pass
|
| class FilesIsNotAudio(Exception): pass
|
| class MixNotFound(Exception): pass
|
| class MixIsEmpty(Exception): pass
|
| class UnknownModelType(Exception): pass
|
| class DemixError(Exception): pass
|
| class ConfigNotLoaded(Exception): pass
|
| class ModelNotLoaded(Exception): pass
|
| class ModelStateDictError(Exception): pass
|
|
|
| HAS_OLD_AMP = hasattr(torch, "cuda") and hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
|
| HAS_NEW_AMP = hasattr(torch, "amp") and hasattr(torch.amp, "autocast")
|
|
|
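| # Return the autocast context manager that matches the installed torch:
|
| # torch.amp (current API), torch.cuda.amp (legacy API), or a no-op context
|
| # when neither is available.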
| def get_autocast_context(device_type="cuda", enabled=True):
|
| if HAS_NEW_AMP:
|
| return torch.amp.autocast(device_type=device_type, enabled=enabled)
|
| elif HAS_OLD_AMP:
|
| return torch.cuda.amp.autocast(enabled=enabled)
|
| else:
|
|
|
| return contextlib.nullcontext()
|
|
|
| def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor:
|
| """
|
| Создать массив окна для плавного склеивания
|
|
|
| Args:
|
| window_size: Размер окна
|
| fade_size: Размер зоны затухания
|
|
|
| Returns:
|
| Массив окна
|
| """
|
| fadein = torch.linspace(0, 1, fade_size)
|
| fadeout = torch.linspace(1, 0, fade_size)
|
|
|
| window = torch.ones(window_size)
|
| window[-fade_size:] = fadeout
|
| window[:fade_size] = fadein
|
| return window
|
|
|
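| # Reusable parameter specs (value type, UI component, range, default and i18n
|
| # info key) shared by the per-architecture entries in add_params below.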
| base_params = {
|
| "sec": {
|
| "type": "float",
|
| "component": "number",
|
| "minimum": 1,
|
| "maximum": 30,
|
| "step": 0.1,
|
| "default": 7,
|
| "info": "separation_segment_size_info"
|
| },
|
| "size": {
|
| "type": "int",
|
| "component": "slider",
|
| "minimum": 128,
|
| "maximum": 1024,
|
| "step": 128,
|
| "default": 256,
|
| "info": "separation_segment_size_info"
|
| },
|
| "wsize": {
|
| "type": "int",
|
| "component": "slider",
|
| "minimum": 320,
|
| "maximum": 1024,
|
| "step": 64,
|
| "default": 512,
|
| "info": "separation_window_size_info"
|
| },
|
| "hop": {
|
| "type": "int",
|
| "component": "slider",
|
| "minimum": 512,
|
| "maximum": 2048,
|
| "step": 512,
|
| "default": 1024,
|
| "info": "separation_hop_info"
|
| },
|
| "overlap": {
|
| "type": "int",
|
| "component": "slider",
|
| "minimum": 1,
|
| "maximum": 16,
|
| "step": 1,
|
| "default": 2,
|
| "info": "separation_overlap_info"
|
| },
|
| "batch": {
|
| "type": "int",
|
| "component": "slider",
|
| "minimum": 1,
|
| "maximum": 16,
|
| "step": 1,
|
| "default": 1,
|
| "info": "separation_batch_size_info"
|
| },
|
| "threshold" : {
|
| "type": "float",
|
| "component": "slider",
|
| "minimum": 0.1,
|
| "maximum": 0.3,
|
| "step": 0.1,
|
| "default": 0.2,
|
| },
|
| "aggression": {
|
| "type": "int",
|
| "component": "slider",
|
| "minimum": 0,
|
| "maximum": 100,
|
| "step": 1,
|
| "default": 5,
|
| "info": "separation_aggresion_info"
|
| },
|
| "enable": {
|
| "type": "bool",
|
| "component": "checkbox",
|
| "default": True
|
| },
|
| "disable": {
|
| "type": "bool",
|
| "component": "checkbox",
|
| "default": False
|
| },
|
| }
|
|
|
| add_params = {
|
| "mdxc": {
|
| "mdxc_segment_size": base_params["size"],
|
| "mdxc_batch_size": base_params["batch"],
|
| "mdxc_overlap": base_params["overlap"],
|
| "mdxc_denoise": base_params["disable"],
|
| "mdxc_override_segment": base_params["disable"]
|
| },
|
| "demucs": {
|
| "demucs_segment": base_params["sec"],
|
| "demucs_batch_size": base_params["batch"],
|
| "demucs_overlap": base_params["overlap"],
|
| "demucs_denoise": base_params["disable"],
|
| "demucs_override_segment": base_params["disable"]
|
| },
|
| "mdx": {
|
| "mdx_hop_length": base_params["hop"],
|
| "mdx_segment_size": base_params["size"],
|
| "mdx_batch_size": base_params["batch"],
|
| "mdx_overlap": base_params["overlap"],
|
| "mdx_denoise": base_params["disable"],
|
| "mdx_override_segment": base_params["disable"]
|
| },
|
| "vr": {
|
| "vr_window_size": base_params["wsize"],
|
| "vr_batch_size": base_params["batch"],
|
| "vr_aggression": base_params["aggression"],
|
| "vr_post_process": base_params["disable"],
|
| "vr_post_process_threshold": base_params["threshold"],
|
| "vr_high_end_process": {**base_params["disable"], "info": "separation_hi-end_process_info"}
|
| },
|
| "mvox": {
|
| "mvox_segment": base_params["sec"],
|
| "mvox_overlap": base_params["overlap"],
|
| "mvox_override_segment": base_params["disable"]
|
| }
|
| }
|
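| # Flatten add_params into the group (tab) names, component names and argparse
|
| # defaults (name -> {"default", "type"}) consumed by parse_separator_args.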
| add_params_list = []
|
| add_params_group = []
|
| add_params_args = {}
|
| for t_tab, t_components in add_params.items():
|
| add_params_group.append(t_tab)
|
| for t_component, t_settings in t_components.items():
|
| add_params_list.append(t_component)
|
| add_params_args[t_component] = {"default": t_settings["default"], "type": t_settings["type"]}
|
|
|
| def get_add_params(args):
|
| """Безопасно получает add_params из args"""
|
| if hasattr(args, 'add_params') and args.add_params is not None:
|
| return vars(args.add_params)
|
| return {}
|
|
|
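| # MSSI holds one loaded separation model at a time: it loads a config and
|
| # checkpoint, runs chunked inference per architecture family in demix(), and
|
| # post-processes and writes the resulting stems.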
| class MSSI:
|
| def __init__(self,
|
| output_dir=".",
|
| output_format=output_formats[0],
|
| use_spec_invert=False,
|
| device="cuda" if torch.cuda.is_available() else "cpu",
|
| ):
|
| self.output_dir = Path(output_dir)
|
| self.output_dir.mkdir(parents=True, exist_ok=True)
|
| self.model_types = (
|
| "mel_band_roformer",
|
| "bs_roformer",
|
| "mdx23c",
|
| "scnet",
|
| "scnet_masked",
|
| "scnet_tran",
|
| "htdemucs",
|
| "bandit",
|
| "bandit_v2",
|
| "mdxnet",
|
| "vr",
|
| "medley_vox"
|
| )
|
| self.custom_model_types = self.model_types[:9]
|
|
|
| self.output_format = output_format
|
| self.device = torch.device(device)
|
| self.use_spec_invert = use_spec_invert
|
| self.model = None
|
| self.model_module = None
|
| self.state_dict = {}
|
| self.model_loaded = False
|
| self.model_type = None
|
| self.ckpt_path = None
|
| self.conf_path = None
|
| self.config = None
|
| self.target_instrument = None
|
| self.instruments = []
|
| self.input_mix = None
|
| self.input_file_name = None
|
| self.sample_rate = None
|
| self.selected_instruments = []
|
| self.output_files_list = []
|
| self.add_params = {}
|
| self.output_arrays: dict[str, np.ndarray] = {}
|
|
|
| def settings(self,
|
| output_dir=".",
|
| output_format=output_formats[0],
|
| use_spec_invert=False,
|
| ):
|
| self.output_dir = Path(output_dir)
|
| self.output_dir.mkdir(parents=True, exist_ok=True)
|
| self.output_format = output_format
|
| self.use_spec_invert = use_spec_invert
|
|
|
| def set_add_params(self, **kwargs):
|
| self.add_params = kwargs
|
|
|
| def load_config(self, model_type: str, conf: str | Path):
|
| if not conf:
|
| raise PathNotSpecified(_i18n("path_not_specified"))
|
| self.conf_path = Path(conf)
|
| if not self.conf_path.exists():
|
| self.conf_path = None
|
| raise PathNotExist(_i18n("path_not_exist"))
|
| if model_type not in self.model_types:
|
| raise UnknownModelType(_i18n("unknown_model_type", model_type=model_type))
|
| self.model_type = model_type
|
| try:
|
| if self.model_type == "htdemucs":
|
| self.config = OmegaConf.load(self.conf_path)
|
| self.sample_rate = self.config.training.samplerate
|
| else:
|
| with self.conf_path.open("r", encoding="utf-8") as f:
|
| self.config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
|
| self.sample_rate = self.config.audio.sample_rate
|
| self.target_instrument = self.config.training.target_instrument
|
| self.instruments = self.config.training.instruments
|
| print(_i18n("config_loaded")+": "+self.conf_path.name)
|
| except FileNotFoundError as e:
|
| self.config = None
|
| self.conf_path = None
|
| self.model_type = None
|
| self.target_instrument = None
|
| self.instruments = []
|
| self.sample_rate = None
|
| raise FileNotFoundError(_i18n("config_not_found", path=conf)) from e
|
| except Exception as e:
|
| self.config = None
|
| self.conf_path = None
|
| self.model_type = None
|
| self.target_instrument = None
|
| self.instruments = []
|
| self.sample_rate = None
|
| raise ValueError(_i18n("config_load_error", error=str(e))) from e
|
|
|
| def prefer_target_instrument(self):
|
| if self.target_instrument:
|
| return [self.target_instrument]
|
| else:
|
| return self.instruments
|
|
|
| def print_instruments(self):
|
| print(_i18n("stems")+": "+",".join(self.instruments))
|
| print(_i18n("target_instrument")+": "+(self.target_instrument if self.target_instrument else _i18n("no")))
|
|
|
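| # Each architecture is imported lazily inside its branch so only the code for
|
| # the selected model type is loaded; the module reference is dropped right
|
| # after the model instance is built.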
| def load_model_instance(self):
|
| if self.config is None or self.model_type is None:
|
| raise ConfigNotLoaded(_i18n("config_is_not_loaded"))
|
| if self.model_type == "mdx23c":
|
| from models import mdx23c_tfc_tdf_v3 as module
|
| self.model_module = module.TFC_TDF_net
|
| self.model = self.model_module(self.config)
|
| del module
|
|
|
| elif self.model_type == "mdxnet":
|
| from models import mdx_net as module
|
| self.model_module = module.MDXNet
|
| self.model = self.model_module(**dict(self.config.model))
|
| del module
|
|
|
| elif self.model_type == "vr":
|
| from models import vr_arch as module
|
| self.model_module = module.get_model
|
| self.model = self.model_module(self.config)
|
| del module
|
|
|
| elif self.model_type == "htdemucs":
|
| models_path = BASE_DIR / 'models'
|
| sys.path.append(str(models_path))
|
| from demucs import get_model as module
|
| self.model_module = module
|
| self.model = self.model_module(self.config)
|
| del module
|
|
|
| elif self.model_type == "mel_band_roformer":
|
| if hasattr(self.config, "windowed"):
|
| from models.windowed_roformer import model as module
|
| self.model_module = module.MelBandRoformerWSA
|
| self.model = self.model_module(**dict(self.config.model))
|
| del module
|
|
|
| elif hasattr(self.config, "conformer"):
|
| from models import bs_roformer as module
|
| self.model_module = module.MelBandConformer
|
| self.model = self.model_module(**dict(self.config.model))
|
| del module
|
|
|
| else:
|
| from models import bs_roformer as module
|
| self.model_module = module.MelBandRoformer
|
| self.model = self.model_module(**dict(self.config.model))
|
| del module
|
|
|
| elif self.model_type == "bs_roformer":
|
| if hasattr(self.config, "sw"):
|
| from models import bs_roformer as module
|
| self.model_module = module.BSRoformer_SW
|
| self.model = self.model_module(**dict(self.config.model))
|
| del module
|
|
|
| elif hasattr(self.config, "fno"):
|
| from models import bs_roformer as module
|
| self.model_module = module.BSRoformer_FNO
|
| self.model = self.model_module(**dict(self.config.model))
|
| del module
|
|
|
| elif hasattr(self.config, "hyperace"):
|
| from models import bs_roformer as module
|
| self.model_module = module.BSRoformerHyperACE
|
| self.model = self.model_module(**dict(self.config.model))
|
| del module
|
|
|
| elif hasattr(self.config, "hyperace2"):
|
| from models import bs_roformer as module
|
| self.model_module = module.BSRoformerHyperACE_2
|
| self.model = self.model_module(**dict(self.config.model))
|
| del module
|
|
|
| elif hasattr(self.config, "conformer"):
|
| from models import bs_roformer as module
|
| self.model_module = module.BSConformer
|
| self.model = self.model_module(**dict(self.config.model))
|
| del module
|
|
|
| elif hasattr(self.config, "conditional"):
|
| from models import bs_roformer as module
|
| self.model_module = module.BSRoformer_Conditional
|
| self.model = self.model_module(**dict(self.config.model))
|
| del module
|
|
|
| elif hasattr(self.config, "unwa_inst_large_2"):
|
| from models import bs_roformer as module
|
| self.model_module = module.BSRoformer_2
|
| self.model = self.model_module(**dict(self.config.model))
|
| del module
|
|
|
| elif hasattr(self.config, "siamese"):
|
| from models import bs_roformer as module
|
| self.model_module = module.BSSiameseRoformer
|
| self.model = self.model_module(**dict(self.config.model))
|
| del module
|
|
|
| else:
|
| from models import bs_roformer as module
|
| self.model_module = module.BSRoformer
|
| self.model = self.model_module(**dict(self.config.model))
|
| del module
|
|
|
| elif self.model_type == "bandit":
|
| from models.bandit.core import model as module
|
| self.model_module = module.MultiMaskMultiSourceBandSplitRNNSimple
|
| self.model = self.model_module(**self.config.model)
|
| del module
|
|
|
| elif self.model_type == "bandit_v2":
|
| from models.bandit_v2 import bandit as module
|
| self.model_module = module.Bandit
|
| self.model = self.model_module(**self.config.kwargs)
|
| del module
|
|
|
| elif self.model_type == "scnet_unofficial":
|
| from models import scnet_unofficial as module
|
| self.model_module = module.SCNet
|
| self.model = self.model_module(**self.config.model)
|
| del module
|
|
|
| elif self.model_type == "scnet":
|
| from models import scnet as module
|
| self.model_module = module.SCNet
|
| self.model = self.model_module(**self.config.model)
|
| del module
|
|
|
| elif self.model_type == 'scnet_masked':
|
| from models.scnet import scnet_masked as module
|
| self.model_module = module.SCNet
|
| self.model = self.model_module(**self.config.model)
|
| del module
|
|
|
| elif self.model_type == 'scnet_tran':
|
| from models.scnet import scnet_tran as module
|
| self.model_module = module.SCNet_Tran
|
| self.model = self.model_module(**self.config.model)
|
| del module
|
|
|
| elif self.model_type == 'medley_vox':
|
| from models import medley_vox as module
|
| self.model_module = module.load_model_with_args
|
| self.model = self.model_module(self.config.model)
|
| del module
|
|
|
| else:
|
| raise UnknownModelType(_i18n("unknown_model_type", model_type=self.model_type))
|
|
|
| def clear_model(self):
|
| """Стандартная очистка (сохранена для совместимости)"""
|
|
|
| if self.model is not None:
|
| if hasattr(self.model, "cpu"):
|
| self.model = self.model.cpu()
|
| del self.model
|
|
|
| self.model = None
|
| del self.model_module
|
| self.model_module = None
|
| self.model_loaded = False
|
|
|
| if hasattr(self.state_dict, "clear"):
|
| self.state_dict.clear()
|
| self.state_dict = {}
|
|
|
| self.config = None
|
| self.target_instrument = None
|
| self.instruments = []
|
| self.output_arrays.clear()
|
|
|
| self.ckpt_path = None
|
| self.conf_path = None
|
| self.model_type = None
|
|
|
| gc.collect()
|
| gc.collect()
|
|
|
| extra_clear_torch_cache()
|
| nuclear_clear_model()
|
| emergency_ram_clear()
|
| self.clear_gpu_cache()
|
|
|
| def clear_gpu_cache(self):
|
| gc.collect()
|
| torch.clear_autocast_cache()
|
| if self.device.type == "mps":
|
| torch.mps.empty_cache()
|
| if self.device.type == "cuda":
|
| torch.cuda.synchronize()
|
| torch.cuda.ipc_collect()
|
| torch.cuda.empty_cache()
|
|
|
| def load_checkpoint(self, ckpt: str | Path):
|
| if not ckpt:
|
| raise PathNotSpecified(_i18n("path_not_specified"))
|
| self.ckpt_path = Path(ckpt)
|
| if not self.ckpt_path.exists():
|
| self.ckpt_path = None
|
| raise PathNotExist(_i18n("path_not_exist"))
|
| if not self.model:
|
| self.ckpt_path = None
|
| raise ModelNotLoaded(_i18n("model_not_loaded"))
|
|
|
| if self.model_type == "mdxnet":
|
| try:
|
| self.model.init_onnx_session(self.ckpt_path, self.device, 0)
|
| self.model_loaded = True
|
| print(_i18n("checkpoint_loaded") + ": " + self.ckpt_path.name)
|
| except Exception as e:
|
| self.model_loaded = False
|
| self.ckpt_path = None
|
| self.clear_model()
|
| print(_i18n("load_checkpoint_error", error=e))
|
| return
|
| else:
|
| try:
|
| try:
|
| self.state_dict = torch.load(
|
| self.ckpt_path, map_location=self.device, weights_only=True
|
| )
|
| except torch.serialization.pickle.UnpicklingError:
|
| self.state_dict = torch.load(
|
| self.ckpt_path, map_location=self.device, weights_only=False
|
| )
|
|
|
| except Exception as e:
|
| self.model_loaded = False
|
| self.ckpt_path = None
|
| self.clear_model()
|
| print(_i18n("load_checkpoint_error", error=e))
|
| return
|
|
|
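| # Unwrap common checkpoint containers: some files nest the weights under a
|
| # "state", "state_dict" or "model_state_dict" key.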
| if "state" in self.state_dict:
|
| self.state_dict = self.state_dict["state"]
|
| if "state_dict" in self.state_dict:
|
| self.state_dict = self.state_dict["state_dict"]
|
| if "model_state_dict" in self.state_dict:
|
| self.state_dict = self.state_dict["model_state_dict"]
|
|
|
| if self.model_type == "medley_vox":
|
| has_ema_keys = any(k.startswith("ema_model") for k in self.state_dict.keys())
|
|
|
| if has_ema_keys:
|
| self.state_dict = {k: v for k, v in self.state_dict.items() if k.startswith("ema_model")}
|
|
|
| new_state_dict = {}
|
| for k, v in self.state_dict.items():
|
| if k.startswith("ema_model.module."): new_key = k.replace("ema_model.module.", "")
|
| elif k.startswith("ema_model."): new_key = k.replace("ema_model.", "")
|
| elif k.startswith("online_model.module."): new_key = k.replace("online_model.module.", "")
|
| elif k.startswith("online_model."): new_key = k.replace("online_model.", "")
|
| elif k.startswith("module."): new_key = k.replace("module.", "")
|
| else: new_key = k
|
|
|
| if new_key not in ["initted", "step"]:
|
| new_state_dict[new_key] = v
|
|
|
| self.state_dict = new_state_dict
|
| del new_state_dict
|
|
|
| try:
|
| self.model.load_state_dict(self.state_dict)
|
| self.state_dict = {}
|
| self.model_loaded = True
|
| self.model.to(self.device)
|
| self.model.eval()
|
| print(_i18n("checkpoint_loaded") + ": " + self.ckpt_path.name)
|
|
|
| except RuntimeError as e:
|
| try:
|
| self.model.load_state_dict(self.state_dict, strict=False)
|
| self.state_dict = {}
|
| self.model_loaded = True
|
| self.model.to(self.device)
|
| self.model.eval()
|
| print(_i18n("load_state_dict_error", error=e))
|
| print(_i18n("checkpoint_loaded") + ": " + self.ckpt_path.name)
|
| except RuntimeError as e_2:
|
| self.state_dict = {}
|
| self.model_loaded = False
|
| self.ckpt_path = None
|
| self.clear_model()
|
| print(_i18n("load_state_dict_error", error=e_2))
|
| return
|
|
|
| def load_mix(self, path: str):
|
| self.input_file_name = None
|
| self.input_mix = None
|
| if self.config is None:
|
| raise ConfigNotLoaded(_i18n("config_is_not_loaded"))
|
| mono_bool = False
|
| if hasattr(self.config, "model"):
|
| if hasattr(self.config.model, "stereo"):
|
| mono_bool = not self.config.model.stereo
|
| if not path:
|
| raise PathNotSpecified(_i18n("path_not_specified"))
|
| input_file = Path(path)
|
| if not input_file.exists():
|
| raise PathNotExist(_i18n("path_not_exist"))
|
| if check(path):
|
| self.input_file_name = input_file.stem
|
| self.input_mix, _ = read(path=input_file, sr=self.sample_rate, mono=mono_bool)
|
| self.input_mix = self.input_mix.copy()
|
| print(_i18n("loaded_mix")+": "+input_file.name)
|
| print(_i18n("array_shape")+": "+str(self.input_mix.shape))
|
| else:
|
| raise FileIsNotAudio(_i18n("file_is_not_audio", path=path))
|
|
|
| def load_array(self, array: np.ndarray, orig_sr: int):
|
| self.input_file_name = "temp_array"
|
| if self.config is None:
|
| raise ConfigNotLoaded(_i18n("config_is_not_loaded"))
|
| mono_bool = False
|
| if hasattr(self.config, "model"):
|
| if hasattr(self.config.model, "stereo"):
|
| mono_bool = not self.config.model.stereo
|
| self.input_mix = easy_resampler(array.copy(), orig_sr, self.sample_rate) if orig_sr != self.sample_rate else array.copy()
|
| if mono_bool:
|
| self.input_mix = stereo_to_mono(self.input_mix)
|
|
|
| print(_i18n("loaded_mix")+": "+_i18n("from_array"))
|
| print(_i18n("array_shape")+": "+str(self.input_mix.shape))
|
|
|
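| # demix() dispatches on model_type: MDX-Net (ONNX, STFT chunks), VR (banded
|
| # spectrogram masking), HTDemucs (waveform chunks), Medley Vox (per-channel
|
| # chunks with loudness normalisation), and a default MSST-style branch for the
|
| # roformer/mdx23c/scnet/bandit families. The waveform branches stitch chunks
|
| # back together with the fade window from _getWindowingArray.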
| def demix(self, add_text: str = ""):
|
| if self.input_mix is None:
|
| raise MixNotFound(_i18n("mix_not_found"))
|
| if self.input_mix.size == 0:
|
| raise MixIsEmpty(_i18n("mix_is_empty"))
|
| if not self.model_loaded:
|
| raise ModelNotLoaded(_i18n("model_not_loaded"))
|
| if self.model_type == "mdxnet":
|
| mix_tensor = torch.tensor(self.input_mix, dtype=torch.float32).to(self.device)
|
| batch_size = 1
|
| dim_t = 256
|
| hop_length: int = self.add_params.get("mdx_hop_length", 1024)
|
| batch_size: int = self.add_params.get("mdx_batch_size", 1)
|
| num_overlap: int = self.add_params.get("mdx_overlap", 2)
|
| denoise: bool = self.add_params.get("mdx_denoise", False)
|
| if self.add_params.get("mdx_override_segment", False):
|
| segment_size: int = self.add_params.get("mdx_segment_size", dim_t)
|
| else:
|
| segment_size: int = dim_t
|
| segment_size = round(segment_size / 128) * 128
|
| stem_name = self.target_instrument
|
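| # Chunk geometry: each window covers chunk_size samples and advances by
|
| # step = chunk_size // num_overlap; the reflect border hides edge artefacts
|
| # and the fade window blends overlapping chunks during overlap-add.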
| chunk_size = hop_length * (segment_size - 1)
|
| fade_size = chunk_size // 10
|
| step = chunk_size // num_overlap
|
| border = chunk_size - step
|
| self.model.post_init(segment_size, self.device)
|
| length_init = mix_tensor.shape[-1]
|
|
|
| if length_init > 2 * border and border > 0:
|
| wave = nn.functional.pad(mix_tensor, (border, border), mode="reflect")
|
| else:
|
| # Short inputs are processed without reflective padding.
| wave = mix_tensor
|
|
|
| window = _getWindowingArray(chunk_size, fade_size).to(self.device)
|
|
|
| with torch.no_grad():
|
| result = torch.zeros_like(wave, device=self.device)
|
| counter = torch.zeros_like(wave, device=self.device)
|
|
|
| i = 0
|
| batch_data = []
|
| batch_locations = []
|
| denoise_str = " "+_i18n("denoise") if denoise else ""
|
| with tqdm(total=wave.shape[1], desc=_i18n("processing") + denoise_str + str(add_text), unit=_i18n("samples")) as progress_bar:
|
| while i < wave.shape[1]:
|
| part = wave[:, i : i + chunk_size]
|
| chunk_len = part.shape[-1]
|
|
|
| if chunk_len < chunk_size:
|
| pad_mode = "reflect" if chunk_len > chunk_size // 2 else "constant"
|
| part = nn.functional.pad(
|
| part, (0, chunk_size - chunk_len), mode=pad_mode, value=0
|
| )
|
|
|
| batch_data.append(part)
|
| batch_locations.append((i, chunk_len))
|
| i += step
|
|
|
| if len(batch_data) >= batch_size or i >= wave.shape[1]:
|
| arr = torch.stack(batch_data, dim=0)
|
|
|
| for j, (start, seg_len) in enumerate(batch_locations):
|
| if denoise:
|
| processed_spec1 = self.model.forward(self.model.stft(arr[j : j + 1], chunk_size, hop_length, segment_size))
|
| processed_spec2 = self.model.forward(self.model.stft(-(arr[j : j + 1]), chunk_size, hop_length, segment_size))
|
| processed_wav = (self.model.istft(processed_spec1, chunk_size, hop_length, segment_size) + -self.model.istft(processed_spec2, chunk_size, hop_length, segment_size)) * 0.5
|
| else:
|
| processed_spec = self.model.forward(self.model.stft(arr[j : j + 1], chunk_size, hop_length, segment_size))
|
| processed_wav = self.model.istft(processed_spec, chunk_size, hop_length, segment_size)
|
|
|
| window_segment = window[..., :seg_len]
|
| result[:, start : start + seg_len] += (
|
| processed_wav[0, :, :seg_len] * window_segment
|
| )
|
| counter[:, start : start + seg_len] += window_segment
|
|
|
| batch_data.clear()
|
| batch_locations.clear()
|
|
|
| progress_bar.update(step)
|
|
|
| estimated_sources = result / counter
|
|
|
| if length_init > 2 * border and border > 0:
|
| estimated_sources = estimated_sources[..., border:-border]
|
| result_separation = estimated_sources.detach().cpu().numpy()
|
| result_separation = np.nan_to_num(
|
| result_separation, nan=0.0, posinf=0.0, neginf=0.0
|
| )
|
| self.output_arrays = {stem_name: result_separation}
|
| del mix_tensor, window, result, counter, batch_data, batch_locations
|
| del estimated_sources, result_separation
|
| if denoise:
|
| del processed_spec1, processed_spec2, processed_wav
|
| else:
|
| del processed_spec, processed_wav
|
| elif self.model_type == "vr":
|
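| # VR arch: build the multi-band spectrogram described by the model's band
|
| # params, predict a magnitude mask in fixed-size windows (_execute), adjust it
|
| # for aggressiveness and optional post-processing, then reconstruct the
|
| # primary and complementary stems and resample them back to the input rate.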
| from models.vr_arch import spec_utils, NON_ACCOM_STEMS
|
| aggression: int = self.add_params.get("vr_aggression", 5)
|
| enable_post_process: bool = self.add_params.get("vr_post_process", False)
|
| high_end_process: bool = self.add_params.get("vr_high_end_process", False)
|
| post_process_threshold: float = self.add_params.get("vr_post_process_threshold", 0.2)
|
| batch_size: int = self.add_params.get("vr_batch_size", 1)
|
| window_size: int = self.add_params.get("vr_window_size", 512)
|
| sr = self.sample_rate
|
| model_sample_rate = self.model.model_params.param["sr"]
|
| primary_stem, secondary_stem = self.instruments[0], self.instruments[1]
|
| aggr = float(int(aggression) / 100)
|
| aggressiveness = {
|
| "value": aggr,
|
| "split_bin": self.model.model_params.param["band"][1]["crop_stop"],
|
| "aggr_correction": self.model.model_params.param.get("aggr_correction"),
|
| }
|
|
|
| input_high_end_h = None
|
| input_high_end = None
|
| X_wave, X_spec_s = {}, {}
|
|
|
| bands_n = len(self.model.model_params.param["band"])
|
|
|
| for d in tqdm(range(bands_n, 0, -1), desc=_i18n("processing") + str(add_text), unit=_i18n("bands")):
|
| bp = self.model.model_params.param["band"][d]
|
|
|
| wav_resolution = bp["res_type"]
|
|
|
| if self.device.type == "mps":
|
| wav_resolution = "polyphase"
|
|
|
| if d == bands_n:
|
| X_wave[d] = librosa.resample(
|
| y=self.input_mix,
|
| orig_sr=self.sample_rate,
|
| target_sr=bp["sr"],
|
| res_type=wav_resolution,
|
| )
|
| X_spec_s[d] = spec_utils.wave_to_spectrogram(
|
| X_wave[d],
|
| bp["hl"],
|
| bp["n_fft"],
|
| self.model.model_params,
|
| band=d,
|
| is_v51_model=self.config.model.is_vr5,
|
| )
|
|
|
| if X_wave[d].ndim == 1:
|
| X_wave[d] = np.asarray([X_wave[d], X_wave[d]])
|
| else:
|
| X_wave[d] = librosa.resample(
|
| X_wave[d + 1],
|
| orig_sr=self.model.model_params.param["band"][d + 1]["sr"],
|
| target_sr=bp["sr"],
|
| res_type=wav_resolution,
|
| )
|
| X_spec_s[d] = spec_utils.wave_to_spectrogram(
|
| X_wave[d],
|
| bp["hl"],
|
| bp["n_fft"],
|
| self.model.model_params,
|
| band=d,
|
| is_v51_model=self.config.model.is_vr5,
|
| )
|
|
|
| if d == bands_n and high_end_process:
|
| input_high_end_h = (bp["n_fft"] // 2 - bp["crop_stop"]) + (
|
| self.model.model_params.param["pre_filter_stop"]
|
| - self.model.model_params.param["pre_filter_start"]
|
| )
|
| input_high_end = X_spec_s[d][
|
| :, bp["n_fft"] // 2 - input_high_end_h : bp["n_fft"] // 2, :
|
| ]
|
|
|
| X_spec = spec_utils.combine_spectrograms(
|
| X_spec_s, self.model.model_params, is_v51_model=self.config.model.is_vr5
|
| )
|
| del X_wave, X_spec_s
|
|
|
| def spec_to_wav(spec, high_end_process, input_high_end, input_high_end_h):
|
| if (
|
| high_end_process
|
| and isinstance(input_high_end, np.ndarray)
|
| and input_high_end_h is not None
|
| ):
|
| input_high_end_ = spec_utils.mirroring(
|
| "mirroring", spec, input_high_end, self.model.model_params
|
| )
|
| wav = spec_utils.cmb_spectrogram_to_wave(
|
| spec,
|
| self.model.model_params,
|
| input_high_end_h,
|
| input_high_end_,
|
| is_v51_model=self.config.model.is_vr5,
|
| )
|
| else:
|
| wav = spec_utils.cmb_spectrogram_to_wave(
|
| spec, self.model.model_params, is_v51_model=self.config.model.is_vr5
|
| )
|
|
|
| return wav
|
|
|
| def _execute(X_mag_pad: np.ndarray, roi_size: int) -> np.ndarray:
|
| X_dataset = []
|
| patches = (X_mag_pad.shape[2] - 2 * self.model.offset) // roi_size
|
| for i in tqdm(range(patches), desc=_i18n("processing") + str(add_text), unit=_i18n("patches")):
|
| start = i * roi_size
|
| X_mag_window = X_mag_pad[:, :, start : start + window_size]
|
| X_dataset.append(X_mag_window)
|
|
|
|
|
| X_dataset = np.asarray(X_dataset)
|
| self.model.eval()
|
| with torch.no_grad():
|
| mask = []
|
|
|
| for i in tqdm(range(0, patches, batch_size), desc=_i18n("processing") + str(add_text), unit=_i18n("chunks")):
|
| X_batch = X_dataset[i : i + batch_size]
|
| X_batch = torch.from_numpy(X_batch).to(self.device)
|
| pred = self.model.predict_mask(X_batch)
|
| if not pred.size()[3] > 0:
|
| raise ValueError(
|
| _i18n("window_size_error")
|
| )
|
| pred = pred.detach().cpu().numpy()
|
| pred = np.concatenate(pred, axis=2)
|
| mask.append(pred)
|
| if len(mask) == 0:
|
| raise ValueError(
|
| _i18n("window_size_error")
|
| )
|
|
|
| mask = np.concatenate(mask, axis=2)
|
| return mask
|
|
|
| def postprocess(
|
| mask: np.ndarray,
|
| X_mag: np.ndarray,
|
| X_phase: np.ndarray
|
| ) -> Tuple[np.ndarray, np.ndarray]:
|
| is_non_accom_stem = False
|
| for stem in NON_ACCOM_STEMS:
|
| if stem == primary_stem.lower():
|
| is_non_accom_stem = True
|
|
|
| mask = spec_utils.adjust_aggr(mask, is_non_accom_stem, aggressiveness)
|
|
|
| if enable_post_process:
|
| mask = spec_utils.merge_artifacts(
|
| mask, thres=post_process_threshold
|
| )
|
|
|
| y_spec = mask * X_mag * np.exp(1.0j * X_phase)
|
| v_spec = (1 - mask) * X_mag * np.exp(1.0j * X_phase)
|
|
|
| return y_spec, v_spec
|
|
|
| X_mag, X_phase = spec_utils.preprocess(X_spec)
|
| n_frame = X_mag.shape[2]
|
| pad_l, pad_r, roi_size = spec_utils.make_padding(
|
| n_frame, window_size, self.model.offset
|
| )
|
| X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
|
| X_mag_pad /= X_mag_pad.max()
|
| mask = _execute(X_mag_pad, roi_size)
|
|
|
| mask = mask[:, :, :n_frame]
|
|
|
| y_spec, v_spec = postprocess(mask, X_mag, X_phase)
|
|
|
| y_spec = np.nan_to_num(y_spec, nan=0.0, posinf=0.0, neginf=0.0)
|
| v_spec = np.nan_to_num(v_spec, nan=0.0, posinf=0.0, neginf=0.0)
|
| primary_stem_array = spec_to_wav(y_spec, high_end_process, input_high_end, input_high_end_h)
|
| primary_stem_array = librosa.resample(
|
| primary_stem_array,
|
| orig_sr=model_sample_rate,
|
| target_sr=sr,
|
| ).T
|
| secondary_stem_array = spec_to_wav(v_spec, high_end_process, input_high_end, input_high_end_h)
|
| secondary_stem_array = librosa.resample(
|
| secondary_stem_array,
|
| orig_sr=model_sample_rate,
|
| target_sr=sr,
|
| ).T
|
| self.output_arrays = {
|
| primary_stem: primary_stem_array,
|
| secondary_stem: secondary_stem_array,
|
| }
|
| del X_spec, X_mag, X_phase, X_mag_pad, mask
|
| del y_spec, v_spec, primary_stem_array, secondary_stem_array
|
| elif self.model_type == "htdemucs":
|
| mix = torch.tensor(self.input_mix, dtype=torch.float32)
|
| segment_sec: int = self.add_params.get("demucs_segment", 10)
|
| denoise: bool = self.add_params.get("demucs_denoise", False)
|
| num_overlap = self.add_params.get("demucs_overlap", 2)
|
| batch_size: int = self.add_params.get("demucs_batch_size", 1)
|
| if self.add_params.get("demucs_override_segment", False):
|
| chunk_size = self.config.training.samplerate * segment_sec
|
| else:
|
| chunk_size = getattr(self.config.training, "segment", 10) * self.config.training.samplerate
|
| num_instruments = len(self.instruments)
|
| step = chunk_size // num_overlap
|
| fade_size = chunk_size // 10
|
| windowing_array = _getWindowingArray(chunk_size, fade_size)
|
| use_amp = getattr(self.config.training, "use_amp", True)
|
|
|
| with torch.inference_mode():
|
| req_shape = (num_instruments,) + mix.shape
|
| result = torch.zeros(req_shape, dtype=torch.float32)
|
| counter = torch.zeros(req_shape, dtype=torch.float32)
|
|
|
| i = 0
|
| batch_data = []
|
| batch_locations = []
|
| denoise_str = " "+_i18n("denoise") if denoise else ""
|
| with tqdm(total=mix.shape[1], desc=_i18n("processing") + denoise_str + str(add_text), unit=_i18n("samples")) as progress_bar:
|
| while i < mix.shape[1]:
|
| part = mix[:, i : i + chunk_size].to(self.device)
|
| chunk_len = part.shape[-1]
|
| pad_mode = "reflect" if chunk_len > chunk_size // 2 else "constant"
|
| part = nn.functional.pad(
|
| part, (0, chunk_size - chunk_len), mode=pad_mode, value=0
|
| )
|
|
|
| batch_data.append(part)
|
| batch_locations.append((i, chunk_len))
|
| i += step
|
|
|
| if len(batch_data) >= batch_size or i >= mix.shape[1]:
|
| arr = torch.stack(batch_data, dim=0)
|
| if denoise:
|
| x1 = self.model(arr)
|
| x2 = self.model(-arr)
|
| x = (x1 + -x2) * 0.5
|
| else:
|
| x = self.model(arr)
|
| window = windowing_array.clone()
|
| if i - step == 0:
|
| window[:fade_size] = 1
|
| elif i >= mix.shape[1]:
|
| window[-fade_size:] = 1
|
|
|
| for j, (start, seg_len) in enumerate(batch_locations):
|
| result[..., start : start + seg_len] += (
|
| x[j, ..., :seg_len].cpu() * window[..., :seg_len]
|
| )
|
| counter[..., start : start + seg_len] += window[..., :seg_len]
|
|
|
| batch_data.clear()
|
| batch_locations.clear()
|
| progress_bar.update(step)
|
| estimated_sources = result / counter
|
| estimated_sources = estimated_sources.detach().cpu().numpy()
|
| np.nan_to_num(estimated_sources, copy=False, nan=0.0)
|
|
|
| if num_instruments <= 1:
|
| self.output_arrays = {self.instruments[0]: estimated_sources[0]}
|
| else:
|
| instruments = self.instruments
|
| self.output_arrays = {k: v for k, v in zip(instruments, estimated_sources)}
|
| del mix, result, counter, batch_data, batch_locations
|
| if denoise:
|
| del estimated_sources, x, x1, x2
|
| else:
|
| del estimated_sources, x
|
| elif self.model_type == "medley_vox":
|
| import pyloudnorm as pyln
|
| from models.medley_vox.loudness_utils import loudnorm, db2linear
|
| if self.add_params.get("mvox_override_segment", False):
|
| segment_sec: int = self.add_params.get("mvox_segment", self.config.model.seq_dur)
|
| else:
|
| segment_sec = self.config.model.seq_dur
|
| overlap: int = self.add_params.get("mvox_overlap", 2)
|
| stems: List[str] = self.instruments
|
|
|
| if self.input_mix.ndim == 1:
|
| self.input_mix = np.expand_dims(self.input_mix, axis=0)
|
| num_channels = 1
|
| elif self.input_mix.ndim == 2:
|
| if self.input_mix.shape[0] <= self.input_mix.shape[1]:
|
| num_channels = self.input_mix.shape[0]
|
| else:
|
| self.input_mix = self.input_mix.T
|
| num_channels = self.input_mix.shape[0]
|
|
|
| samplerate = self.config.model.sample_rate
|
| chunk_size = int(samplerate * segment_sec)
|
| step = chunk_size // overlap
|
| fade_size = chunk_size // 10
|
|
|
| meter = pyln.Meter(samplerate)
|
| try:
|
| if num_channels > 1:
|
| mix_for_loudnorm = self.input_mix.T
|
| else:
|
| mix_for_loudnorm = self.input_mix[0]
|
|
|
| mixture_d, adjusted_gain = loudnorm(mix_for_loudnorm, -24.0, meter)
|
|
|
| if num_channels > 1:
|
| if isinstance(mixture_d, np.ndarray) and mixture_d.ndim == 2:
|
| mixture_d = mixture_d.T
|
| else:
|
| mixture_d = np.tile(mixture_d, (num_channels, 1))
|
| else:
|
| if mixture_d.ndim == 1:
|
| mixture_d = mixture_d.reshape(1, -1)
|
|
|
| except Exception as e:
|
| print(_i18n("loudnorm_error", error=str(e)))
|
| mixture_d = self.input_mix.copy()
|
| rms = np.sqrt(np.mean(self.input_mix**2))
|
| target_rms = 0.1
|
| if rms > 0:
|
| adjusted_gain = 20 * np.log10(target_rms / rms)
|
| mixture_d = self.input_mix * (target_rms / rms)
|
| else:
|
| adjusted_gain = 0
|
|
|
| length_init = mixture_d.shape[1]
|
|
|
| windowing_array = _getWindowingArray(chunk_size, fade_size).to(self.device)
|
|
|
| result_stems = {stem: np.zeros((num_channels, length_init), dtype=np.float32)
|
| for stem in stems}
|
|
|
| mix_tensor = torch.tensor(mixture_d, dtype=torch.float32).to(self.device)
|
|
|
| counters = {stem: torch.zeros((num_channels, length_init), dtype=torch.float32, device=self.device)
|
| for stem in stems}
|
|
|
| i = 0
|
| with tqdm(total=length_init, desc=_i18n("processing") + str(add_text), unit=_i18n("samples")) as progress_bar:
|
| while i < length_init:
|
| end_idx = min(i + chunk_size, length_init)
|
| chunk = mix_tensor[:, i:end_idx]
|
| cur_chunk_len = chunk.shape[1]
|
|
|
| chunk_results = torch.zeros((num_channels, len(stems), cur_chunk_len), dtype=torch.float32, device=self.device)
|
|
|
| for ch in range(num_channels):
|
| channel_chunk = chunk[ch:ch+1, :]
|
|
|
| if cur_chunk_len < chunk_size:
|
| pad_len = chunk_size - cur_chunk_len
|
| channel_chunk = torch.nn.functional.pad(
|
| channel_chunk, (0, pad_len), mode='constant', value=0
|
| )
|
|
|
| channel_chunk = channel_chunk.unsqueeze(0)
|
|
|
| with torch.no_grad():
|
| out_chunk = self.model.separate(channel_chunk)
|
|
|
| chunk_results[ch, :, :cur_chunk_len] = out_chunk[0, :, :cur_chunk_len].cpu()
|
|
|
| window = windowing_array[:cur_chunk_len].clone()
|
| if i == 0:
|
| window[:fade_size] = 1
|
| if end_idx >= length_init:
|
| window[-fade_size:] = 1
|
|
|
| for stem_idx, stem in enumerate(stems):
|
| result_stems[stem][:, i:end_idx] += chunk_results[:, stem_idx, :].cpu().numpy() * window.cpu().numpy()
|
| counters[stem][:, i:end_idx] += window
|
|
|
| i += step
|
| progress_bar.update(step)
|
|
|
| for stem in stems:
|
| counters_np = counters[stem].detach().cpu().numpy()
|
| mask = counters_np > 0
|
| result_stems[stem][mask] /= counters_np[mask]
|
|
|
| result_stems[stem] = result_stems[stem] * db2linear(-adjusted_gain)
|
|
|
| self.output_arrays = result_stems
|
| del mix_tensor, mixture_d, counters
|
| del result_stems, chunk_results
|
| del meter
|
| else:
|
| mix = torch.tensor(self.input_mix, dtype=torch.float32).to(self.device)
|
| segment: int = self.add_params.get("mdxc_segment_size", 256)
|
| hop_length = None
|
| if hasattr(self.config, "model"):
|
| if hasattr(self.config.model, "stft_hop_length"):
|
| hop_length = self.config.model.stft_hop_length
|
| elif hasattr(self.config.model, "hop_size"):
|
| hop_length = self.config.model.hop_size
|
| elif hasattr(self.config.model, "hop_length"):
|
| hop_length = self.config.model.hop_length
|
| if hasattr(self.config, "audio"):
|
| if hasattr(self.config.audio, "hop_length"):
|
| hop_length = self.config.audio.hop_length
|
| if hasattr(self.config, "kwargs"):
|
| if hasattr(self.config.kwargs, "hop_length"):
|
| hop_length = self.config.kwargs.hop_length
|
|
|
| if self.add_params.get("mdxc_override_segment", False):
|
| chunk_size = int(hop_length) * (int(segment) - 1)
|
| else:
|
| chunk_size = self.config.audio.chunk_size
|
| instruments = self.prefer_target_instrument()
|
| num_instruments = len(instruments)
|
| denoise: bool = self.add_params.get("mdxc_denoise", False)
|
| num_overlap: int = self.add_params.get("mdxc_overlap", 2)
|
|
|
| fade_size = chunk_size // 10
|
| step = chunk_size // num_overlap
|
| border = chunk_size - step
|
| length_init = mix.shape[-1]
|
| windowing_array = _getWindowingArray(chunk_size, fade_size)
|
|
|
| if length_init > 2 * border and border > 0:
|
| mix = nn.functional.pad(mix, (border, border), mode="reflect")
|
|
|
| batch_size: int = self.add_params.get("mdxc_batch_size", 1)
|
| use_amp = getattr(self.config.training, "use_amp", True)
|
|
|
| with torch.inference_mode(), get_autocast_context(self.device.type, use_amp):
|
| req_shape = (num_instruments,) + mix.shape
|
| result = torch.zeros(req_shape, dtype=torch.float32)
|
| counter = torch.zeros(req_shape, dtype=torch.float32)
|
|
|
| i = 0
|
| batch_data = []
|
| batch_locations = []
|
| denoise_str = " "+_i18n("denoise") if denoise else ""
|
| with tqdm(total=mix.shape[1], desc=_i18n("processing") + denoise_str + str(add_text), unit=_i18n("samples")) as progress_bar:
|
|
|
| while i < mix.shape[1]:
|
| part = mix[:, i : i + chunk_size].to(self.device)
|
| chunk_len = part.shape[-1]
|
|
|
| pad_mode = "reflect" if chunk_len > chunk_size // 2 else "constant"
|
| part = nn.functional.pad(
|
| part, (0, chunk_size - chunk_len), mode=pad_mode, value=0
|
| )
|
|
|
| batch_data.append(part)
|
| batch_locations.append((i, chunk_len))
|
| i += step
|
|
|
| if len(batch_data) >= batch_size or i >= mix.shape[1]:
|
| arr = torch.stack(batch_data, dim=0)
|
| if denoise:
|
| x1 = self.model(arr)
|
| x2 = self.model(-arr)
|
| x = (x1 + -x2) * 0.5
|
| else:
|
| x = self.model(arr)
|
|
|
| window = windowing_array.clone()
|
| if i - step == 0:
|
| window[:fade_size] = 1
|
| elif i >= mix.shape[1]:
|
| window[-fade_size:] = 1
|
|
|
| for j, (start, seg_len) in enumerate(batch_locations):
|
| result[..., start : start + seg_len] += (
|
| x[j, ..., :seg_len].cpu() * window[..., :seg_len]
|
| )
|
| counter[..., start : start + seg_len] += window[..., :seg_len]
|
|
|
| batch_data.clear()
|
| batch_locations.clear()
|
| progress_bar.update(step)
|
|
|
| estimated_sources = result / counter
|
| estimated_sources = estimated_sources.detach().cpu().numpy()
|
| np.nan_to_num(estimated_sources, copy=False, nan=0.0)
|
|
|
| if length_init > 2 * border and border > 0:
|
| estimated_sources = estimated_sources[..., border:-border]
|
|
|
| self.output_arrays = {k: v for k, v in zip(instruments, estimated_sources)}
|
| del mix, result, counter, batch_data, batch_locations
|
| if denoise:
|
| del estimated_sources, x, x1, x2
|
| else:
|
| del estimated_sources, x
|
| self.add_second_stem()
|
| return
|
|
|
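| # For single-target models, derive the complementary stem by subtracting the
|
| # predicted target from the input mix (waveform or spectrogram inversion,
|
| # depending on use_spec_invert).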
| def add_second_stem(self):
|
| if self.target_instrument:
|
| second_stem = [instrument for instrument in self.instruments if instrument != self.target_instrument][0]
|
| self.output_arrays[second_stem] = subtractor(self.input_mix, self.output_arrays[self.target_instrument], self.sample_rate, self.sample_rate, spectrogram=self.use_spec_invert)[0]
|
| print(_i18n("added_second_stem") + ": " + second_stem)
|
| else:
|
| return
|
|
|
| def delete_unselected_stems(self, selected_stems: list):
|
| if selected_stems:
|
| output_keys = list(self.output_arrays.keys())
|
| deleted_keys = []
|
| for stem in output_keys:
|
| if stem not in selected_stems:
|
| self.output_arrays[stem] = None
|
| del self.output_arrays[stem]
|
| deleted_keys.append(stem)
|
| print(_i18n("deleted_stems") + f": " + ",".join(deleted_keys))
|
| else:
|
| return
|
|
|
| def extract_instrumental(self, extract_instrumental: bool, return_: bool = False):
|
| if extract_instrumental:
|
| if self.output_arrays:
|
| output_keys = [key_ for key_ in self.output_arrays]
|
| self.output_arrays["invert"] = self.input_mix.copy()
|
| for stem in output_keys:
|
| self.output_arrays["invert"] = subtractor(self.output_arrays["invert"], self.output_arrays[stem], self.sample_rate, self.sample_rate, spectrogram=self.use_spec_invert)[0]
|
| if return_:
|
| return self.output_arrays["invert"]
|
|
|
| def write(self, template: str, format_return: str = "name_stems_list"):
|
| results = []
|
| writed_stems = []
|
| model_name = self.ckpt_path.stem
|
| print(_i18n("format_return") + ": " + _i18n(format_return))
|
| for stem, array in tqdm(self.output_arrays.items(), desc=_i18n("writing"), unit=_i18n('files')):
|
| custom_name = Namer.template(
|
| template,
|
| STEM=stem,
|
| MODEL=model_name,
|
| NAME=Namer.short_input_name_template(template, STEM=stem, MODEL=model_name, NAME=self.input_file_name)
|
| )
|
| writed_stems.append([stem, write(Namer.iter(self.output_dir / f"{custom_name}.{self.output_format}"), array, self.sample_rate)])
|
| if writed_stems:
|
| match format_return:
|
| case "name_stems_list":
|
| results = [self.input_file_name, writed_stems]
|
| case "stems_list":
|
| results = writed_stems
|
| case "stems_list_append_self":
|
| self.output_files_list.append(writed_stems)
|
| case "name_stems_list_append_self":
|
| self.output_files_list.append([self.input_file_name, writed_stems])
|
| return results
|
|
|
| def clear_mix(self):
|
| self.input_file_name = None
|
| self.input_mix = None
|
| self.output_arrays.clear()
|
|
|
| def clear_outputs(self):
|
| self.output_files_list.clear()
|
|
|
| def get_outputs(self):
|
| return self.output_files_list
|
|
|
| def _process(self, i: int, total: int, path: str, template: str, selected_stems: list = [], extract_instrumental: bool = True):
|
| template = Namer.sanitize(template)
|
| template = Namer.dedup_template(template, keys=["NAME", "MODEL", "STEM"])
|
| template = Namer.short(template, length=40)
|
| self.clear_mix()
|
| self.load_mix(path)
|
| try:
|
| self.demix(f" | {i}/{total} {_i18n('files')}")
|
| except Exception as e:
|
| self.clear_mix()
|
| raise DemixError(_i18n("demix_error", error=e)) from e
|
| self.delete_unselected_stems(selected_stems)
|
| self.extract_instrumental(extract_instrumental)
|
| self.write(template, "name_stems_list_append_self")
|
| self.clear_mix()
|
|
|
| def _process_array_ensemble(self, i: int, total: int, array: np.ndarray, sr: int, primary_stem: str | None = None, invert: bool = False):
|
| self.clear_mix()
|
| self.load_array(array, sr)
|
| try:
|
| self.demix(f" | {i}/{total} {_i18n('models')} | {self.ckpt_path.stem}")
|
| except Exception as e:
|
| self.clear_mix()
|
| raise DemixError(_i18n("demix_error", error=e)) from e
|
| self.delete_unselected_stems([primary_stem])
|
| if invert:
|
| result = self.extract_instrumental(True, return_=True)
|
| else:
|
| result = self.output_arrays[primary_stem]
|
| return result, self.sample_rate
|
|
|
| def load_model(self, model_type: str, ckpt: str | Path, conf: str | Path):
|
| self.clear_model()
|
| self.load_config(model_type=model_type, conf=conf)
|
| self.load_model_instance()
|
| self.load_checkpoint(ckpt=ckpt)
|
|
|
| def inference(self, input: str | list, /, *inputs, template: str = "NAME_MODEL_STEM", selected_stems: list = [], extract_instrumental: bool = False):
|
| self.clear_outputs()
|
| all_inputs = []
|
| if isinstance(input, list):
|
| all_inputs.extend(input)
|
| else:
|
| all_inputs.append(input)
|
| if inputs:
|
| all_inputs.extend(inputs)
|
| total = len(all_inputs)
|
| for i, input_file in enumerate(all_inputs, start=1):
|
| try:
|
| self._process(i, total, input_file, template=template, selected_stems=selected_stems, extract_instrumental=extract_instrumental)
|
| except Exception as e:
|
| traceback.print_exc()
|
| return self.get_outputs()
|
|
|
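| # ModelManager reads model metadata (stems, target stem, model type, download
|
| # URLs) from models.json and caches checkpoints/configs under separation_cache/.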
| class ModelManager:
|
| def __init__(self):
|
| self.info = {}
|
| self.info_url = "https://huggingface.co/noblebarkrr/mvsepless_resources/resolve/main/models.json?download=true"
|
| self.info_path = Path(BASE_DIR) / "models.json"
|
| self.load_info()
|
| self.cache_dir = Path(BASE_DIR) / "separation_cache"
|
| self.cache_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
| def get_all_models(self):
|
| return [mn for mn in self.info]
|
|
|
| def get_stems(self, model_name):
|
| return [stem for stem in self.info.get(model_name, {}).get("stems", [])]
|
|
|
| def get_target_instrument(self, model_name):
|
| return self.info.get(model_name, {}).get("target_instrument", None)
|
|
|
| def get_model_type(self, model_name):
|
| return self.info.get(model_name, {}).get("model_type", "")
|
|
|
| def get_links(self, model_name):
|
| return (self.info.get(model_name, {}).get("checkpoint_url", None),
|
| self.info.get(model_name, {}).get("config_url", None))
|
|
|
| def generate_local_paths(self, model_name):
|
| return (self.cache_dir / f"{model_name}.ckpt",
|
| self.cache_dir / f"{model_name}_config.yaml")
|
|
|
| def check_installed(self, model_name):
|
| return [path.exists() for path in self.generate_local_paths(model_name)]
|
|
|
| def check_installed2(self, model_name):
|
| return all(self.check_installed(model_name))
|
|
|
| def load_info(self):
|
| self.info = json.loads(self.info_path.read_text("utf-8"))
|
|
|
| def show_info(self, limit: int = None, stem: str = None, only_installed: bool = False):
|
| models = []
|
| if stem:
|
| # Match the requested stem case-insensitively against each model's stem list.
|
| models = [
|
| model for model in self.get_all_models()
|
| if any(s.lower() == stem.lower() for s in self.get_stems(model))
|
| ]
|
| else:
|
| models = self.get_all_models()
|
| if only_installed:
|
| models = [model for model in models if self.check_installed2(model)]
|
| if limit:
|
| models = models[:limit]
|
|
|
| console = rich.console.Console()
|
| table = rich.table.Table(title=_i18n("model_info"), show_lines=True)
|
| table.add_column(_i18n("model_name"), no_wrap=True)
|
| table.add_column(_i18n("output_stems"))
|
| table.add_section()
|
| table.add_row(_i18n("table_model_info_installed_legend"), _i18n("table_model_info_target_instrument_legend"))
|
| table.add_section()
|
| if models:
|
| for model_ in models:
|
| target_instrument = self.get_target_instrument(model_)
|
| stems = self.get_stems(model_)
|
| if target_instrument:
|
| for i, stem in enumerate(stems):
|
| if stem == target_instrument:
|
| stems[i] = f"[green]{stem}[/]"
|
|
|
| stems_str = ", ".join(stems)
|
| table.add_row(f"[green]{model_}[/]" if self.check_installed2(model_) else model_, stems_str)
|
| else:
|
| table.add_row(_i18n("na"), _i18n("na"))
|
|
|
| console.print(table)
|
|
|
| def update_info(self):
|
| dw_file(self.info_url, self.info_path)
|
| print(_i18n("model_info_updated"))
|
|
|
| def download(self, model_name: str):
|
| status = ""
|
| urls = self.get_links(model_name)
|
| local_paths = self.generate_local_paths(model_name)
|
| local_exists = self.check_installed(model_name)
|
| for url, local_path, exists in zip(urls, local_paths, local_exists):
|
| if not exists:
|
| dw_file(url, local_path)
|
| if all(local_exists):
|
| status = _i18n("model_already_downloaded")
|
| else:
|
| status = _i18n("model_downloaded")
|
| print(status)
|
| return status
|
|
|
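| # Ensembler accumulates per-model stem arrays and their sample rates before
|
| # they are merged by ensemble().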
| class Ensembler:
|
| def __init__(self):
|
| self.arrays = []
|
| self.srs = []
|
|
|
| def add_array(self, y: np.ndarray, sr: int):
|
| self.arrays.append(y)
|
| self.srs.append(sr)
|
|
|
| def get_arrays(self):
|
| return self.arrays
|
|
|
| def get_srs(self):
|
| return self.srs
|
|
|
| def clear(self):
|
| self.arrays.clear()
|
|
|
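| # Separator is the user-facing facade: it combines ModelManager (metadata and
|
| # downloads) with an MSSI instance for separation, ensembling and subtraction.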
| class Separator(ModelManager):
|
| def __init__(self):
|
| super().__init__()
|
| self.mssi = MSSI()
|
|
|
| def unload_model(self):
|
| self.mssi.clear_model()
|
|
|
| @hf_spaces_gpu
|
| def separate_base(
|
| self,
|
| input_valid_files: list[str | Path],
|
| model_name: str,
|
| template: str,
|
| checkpoint: str | Path,
|
| config: str | Path,
|
| selected_stems: list,
|
| extract_instrumental: bool
|
| ):
|
| self.mssi.clear_model()
|
| self.mssi.load_model(self.get_model_type(model_name), checkpoint, config)
|
| self.mssi.print_instruments()
|
| results = self.mssi.inference(input_valid_files, template=template, selected_stems=selected_stems, extract_instrumental=extract_instrumental)
|
| self.mssi.clear_model()
|
| return results
|
|
|
| def separate(
|
| self,
|
| input_files: list[str | Path],
|
| output_dir: str | Path = Path("."),
|
| output_format: str = output_formats[0],
|
| template: str = "NAME_(STEM)_MODEL",
|
| model_name: str = "bs_6stem",
|
| extract_instrumental: bool = False,
|
| use_spec_invert: bool = False,
|
| selected_stems: list = [],
|
| add_params: dict = {}
|
| ):
|
| if not output_dir:
|
| output_dir = ""
|
| input_valid_files = get_audio_files_from_list(input_files, only_files=False)
|
| if not input_valid_files:
|
| raise PathsNotSpecified(_i18n("paths_not_specified"))
|
| self.mssi.settings(output_dir=output_dir, output_format=output_format, use_spec_invert=use_spec_invert)
|
| self.mssi.set_add_params(**add_params)
|
| self.download(model_name)
|
| checkpoint, config = self.generate_local_paths(model_name)
|
| results = self.separate_base(input_valid_files, model_name, template, checkpoint, config, selected_stems, extract_instrumental)
|
| return results
|
|
|
| @hf_spaces_gpu
|
| def custom_separate(
|
| self,
|
| input_files: list,
|
| output_dir: str | Path = Path("."),
|
| output_format: str = output_formats[0],
|
| template: str = "NAME_(STEM)_MODEL",
|
| model_type: str = "bs_roformer",
|
| ckpt: str = "model.ckpt",
|
| conf: str = "conf.ckpt",
|
| extract_instrumental: bool = False,
|
| use_spec_invert: bool = False,
|
| selected_stems: list = [],
|
| add_params: dict = {}
|
| ):
|
| if not output_dir:
|
| output_dir = ""
|
| input_valid_files = get_audio_files_from_list(input_files, only_files=False)
|
| if not input_valid_files:
|
| raise PathsNotSpecified(_i18n("paths_not_specified"))
|
| checkpoint, config = Path(ckpt), Path(conf)
|
| model_name = checkpoint.stem
|
| self.mssi.settings(output_dir=output_dir, output_format=output_format, use_spec_invert=use_spec_invert)
|
| self.mssi.set_add_params(**add_params)
|
| self.mssi.clear_model()
|
| self.mssi.load_model(model_type, checkpoint, config)
|
| self.previous_model_name = model_name
|
| self.mssi.print_instruments()
|
| results = self.mssi.inference(input_valid_files, template=template, selected_stems=selected_stems, extract_instrumental=extract_instrumental)
|
| self.mssi.clear_model()
|
| return results
|
|
|
| def print_flow(self, flow):
|
| """Print current ensemble flow in a formatted table (like show_info)"""
|
| if not flow:
|
| return
|
|
|
| console = rich.console.Console()
|
| table = rich.table.Table(title="", show_lines=True)
|
| table.add_column("#", style="cyan", no_wrap=True)
|
| table.add_column(_i18n("model_name"))
|
| table.add_column(_i18n("primary_stem"))
|
| table.add_column(_i18n("invert"))
|
| table.add_column(_i18n("weights"), justify="right")
|
|
|
| for idx, (model_name, primary_stem, invert, weight) in enumerate(flow, start=1):
|
| invert_str = _i18n("yes") if invert else _i18n("no")
|
| table.add_row(
|
| str(idx),
|
| model_name,
|
| primary_stem,
|
| invert_str,
|
| f"{weight:.2f}" if isinstance(weight, (int, float)) else str(weight)
|
| )
|
|
|
| console.print(table)
|
|
|
| @hf_spaces_gpu
|
| def auto_ensemble_base(
|
| self,
|
| model_name: str,
|
| checkpoint: str | Path,
|
| config: str | Path,
|
| i: int,
|
| model_count: int,
|
| mix: np.ndarray,
|
| orig_sr: int,
|
| primary_stem: str,
|
| invert: bool
|
| ):
|
| self.mssi.clear_model()
|
| self.mssi.load_model(self.get_model_type(model_name), checkpoint, config)
|
| self.mssi.print_instruments()
|
| output, model_sr = self.mssi._process_array_ensemble(i, model_count, mix, orig_sr, primary_stem, invert)
|
| self.mssi.clear_model()
|
| return output, model_sr
|
|
|
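| # Expected flow format: a list of [model_name, primary_stem, invert, weight]
|
| # entries. Each model's chosen stem (optionally inverted against the mix) is
|
| # weighted and merged by ensemble(); the merged result is then subtracted from
|
| # the original mix to produce a complementary output.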
| def auto_ensemble(
|
| self,
|
| input_file: str | Path,
|
| output_dir: str | Path = Path("."),
|
| flow: list[list[str | bool | int | float]] = [],
|
| template: str = "NAME_TYPE_COUNT",
|
| etype: str = ensemble_types[0],
|
| output_format: str = output_formats[0],
|
| use_spec_invert: bool = False,
|
| save_primary_stems: bool = False
|
| ) -> tuple[str, str, list[str]] | None:
|
| if not output_dir:
|
| output_dir = ""
|
| if not input_file:
|
| raise PathNotSpecified(_i18n("path_not_specified"))
|
| input_file = Path(input_file)
|
| if not input_file.exists():
|
| raise PathNotExist(_i18n("path_not_exist"))
|
| if not check(input_file):
|
| raise FileIsNotAudio(_i18n("file_is_not_audio", path=input_file))
|
| self.print_flow(flow)
|
| if not flow:
|
| return None
|
| output_dir = Path(output_dir)
|
| output_dir.mkdir(parents=True, exist_ok=True)
|
| model_count = len(flow)
|
| print(_i18n("ensemble_type")+": "+etype)
|
| print(_i18n("ensemble_models_count")+": "+str(model_count))
|
| mix, orig_sr = read(input_file, sr=44100)
|
| template = Namer.sanitize(template)
|
| template = Namer.dedup_template(template, keys=["NAME", "TYPE", "COUNT"])
|
| template = Namer.short(template, length=40)
|
| invert_key = "_invert"
|
| custom_name = Namer.template(
|
| template,
|
| TYPE=etype,
|
| COUNT=model_count,
|
| NAME=Namer.short_input_name_template(template, TYPE=etype, COUNT=model_count, NAME=input_file.stem)
|
| )
|
| auto_ensembler = Ensembler()
|
| weights = []
|
| saved_primary_stems = []
|
| self.mssi.set_add_params(**{"demucs_denoise": True, "mdx_denoise": True})
|
| for i, (model_name, primary_stem, invert, weight) in enumerate(flow, start=1):
|
| try:
|
| self.download(model_name)
|
| checkpoint, config = self.generate_local_paths(model_name)
|
| output, model_sr = self.auto_ensemble_base(model_name, checkpoint, config, i, model_count, mix, orig_sr, primary_stem, invert)
|
| auto_ensembler.add_array(output, model_sr)
|
| weights.append(weight)
|
| if save_primary_stems:
|
| primary_stem_file_name = primary_stem + (invert_key if invert else "")
|
| saved_primary_stems.append(write(Namer.iter(output_dir / model_name / f"{primary_stem_file_name}.flac"), output, model_sr))
|
| except Exception as e:
|
| print(_i18n("error_occured_separation")+": "+str(e))
|
| gr.Warning(message="<b>"+f'{_i18n("error_occured_separation")}'.replace("\n", "<br>")+": "+str(e)+"</b>", title="")
|
| continue
|
| extracted_primary_stems = auto_ensembler.get_arrays()
|
| srs = auto_ensembler.get_srs()
|
| output_array, sr_ = ensemble(extracted_primary_stems, srs, etype, weights)
|
| extracted_primary_stems = None
|
| auto_ensembler.clear()
|
| auto_ensembler, output = None, None
|
| del auto_ensembler, output
|
| inverted_array, i_sr = subtractor(mix, output_array, orig_sr, sr_, spectrogram=use_spec_invert)
|
| return write(Namer.iter(output_dir / f"{custom_name}.{output_format}"), output_array, sr_), write(Namer.iter(output_dir / f"{Namer.short(custom_name+invert_key)}.{output_format}"), inverted_array, i_sr), saved_primary_stems
|
|
|
| def manual_ensemble(
|
| self,
|
| input_files: list[str | Path],
|
| output_dir: str | Path = Path("."),
|
| weights: list[float] | None = None,
|
| template: str = "ensembled_TYPE_COUNT",
|
| etype: str = ensemble_types[0],
|
| output_format: str = output_formats[0],
|
| ) -> str:
|
| if not output_dir:
|
| output_dir = ""
|
| input_valid_files = get_audio_files_from_list(input_files, only_files=True)
|
| if not input_valid_files:
|
| raise PathsNotSpecified(_i18n("paths_not_specified"))
|
| output_dir = Path(output_dir)
|
| output_dir.mkdir(parents=True, exist_ok=True)
|
| arrays, srs = multiread(input_valid_files)
|
| model_count = len(srs)
|
| results, max_sr = ensemble(arrays, srs, etype, weights)
|
| template = Namer.sanitize(template)
|
| template = Namer.dedup_template(template, keys=["TYPE", "COUNT"])
|
| template = Namer.short(template, length=40)
|
| custom_name = Namer.template(
|
| template,
|
| TYPE=etype,
|
| COUNT=model_count
|
| )
|
| return write(Namer.iter(output_dir / f"{custom_name}.{output_format}"), results, max_sr)
|
|
|
| def subtract(self, audio1: str | Path, audio2: str | Path, output_dir: str | Path = Path("."), output_format: str = output_formats[0], use_spec_invert: bool = False, template: str = "invert_TYPE_NAME"):
|
| if not output_dir:
|
| output_dir = ""
|
| output_dir = Path(output_dir)
|
| output_dir.mkdir(parents=True, exist_ok=True)
|
| if not audio1 or not audio2:
|
| raise PathsNotSpecified(_i18n("paths_not_specified"))
|
| audio1, audio2 = Path(audio1), Path(audio2)
|
| if not audio1.exists() or not audio2.exists():
|
| raise PathsNotExist(_i18n("paths_not_exist"))
|
| if not check(audio1) or not check(audio2):
|
| raise FilesIsNotAudio(_i18n("files_is_not_audio"))
|
| template = Namer.sanitize(template)
|
| template = Namer.dedup_template(template, keys=["NAME", "TYPE"])
|
| template = Namer.short(template, length=40)
|
| invert_type_key = ("spectrogram" if use_spec_invert else "waveform")
|
| custom_name = Namer.template(
|
| template,
|
| TYPE=invert_type_key,
|
| NAME=Namer.short_input_name_template(template, TYPE=invert_type_key, NAME=audio1.stem)
|
| )
|
| y1, sr1 = read(audio1)
|
| y2, sr2 = read(audio2)
|
| inverted, min_sr = subtractor(y1, y2, sr1, sr2, spectrogram=use_spec_invert)
|
| return write(Namer.iter(output_dir / f"{custom_name}.{output_format}"), inverted, min_sr)
|
|
|
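| # Minimal programmatic usage sketch (hedged): assumes this file is importable
|
| # as a module and that a "bs_6stem" entry exists in the local models.json; the
|
| # paths below are illustrative, not guaranteed.
|
| #
|
| #     sep = Separator()
|
| #     sep.show_info(limit=10)                  # list available models
|
| #     outputs = sep.separate(
|
| #         input_files=["/path/to/song.wav"],   # hypothetical input file
|
| #         output_dir="separated",
|
| #         model_name="bs_6stem",
|
| #         extract_instrumental=True,
|
| #     )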
| if __name__ == "__main__":
|
| separator = Separator()
|
| args = parse_separator_args(add_params_args)
|
| if args.mode == "separate":
|
| separator.separate(
|
| input_files=args.input,
|
| output_dir=args.output_dir,
|
| output_format=args.output_format,
|
| template=args.template,
|
| model_name=args.model_name,
|
| extract_instrumental=args.extract_instrumental,
|
| use_spec_invert=args.use_spec_invert,
|
| selected_stems=args.selected_stems,
|
| add_params=get_add_params(args)
|
| )
|
| elif args.mode == "custom_separate":
|
| separator.custom_separate(
|
| input_files=args.input,
|
| output_dir=args.output_dir,
|
| output_format=args.output_format,
|
| template=args.template,
|
| model_type=args.model_type,
|
| ckpt=args.checkpoint_path,
|
| conf=args.config_path,
|
| extract_instrumental=args.extract_instrumental,
|
| use_spec_invert=args.use_spec_invert,
|
| selected_stems=args.selected_stems,
|
| add_params=get_add_params(args)
|
| )
|
| elif args.mode == "auto_ensemble":
|
| if args.preset:
|
| flow = json.loads(Path(args.preset).read_text("utf-8"))
|
| elif args.flow:
|
| flow = []
|
| for params in args.flow:
|
| list_values_param = params.split(":")
|
| if len(list_values_param) == 4:
|
| flow.append([str(list_values_param[0]), str(list_values_param[1]), tobool(list_values_param[2]), float(list_values_param[3])])
|
| else:
|
| raise ValueError(f"Invalid flow entry {params!r}: expected 'model_name:primary_stem:invert:weight'")
|
| separator.auto_ensemble(
|
| input_file=args.input,
|
| output_dir=args.output_dir,
|
| flow=flow,
|
| template=args.template,
|
| etype=args.ensemble_type,
|
| output_format=args.output_format,
|
| use_spec_invert=args.use_spec_invert,
|
| save_primary_stems=args.save_primary_stems
|
| )
|
| elif args.mode == "manual_ensemble":
|
| separator.manual_ensemble(
|
| input_files=args.input,
|
| output_dir=args.output_dir,
|
| weights=args.weights,
|
| template=args.template,
|
| etype=args.ensemble_type,
|
| output_format=args.output_format
|
| )
|
| elif args.mode == "subtract":
|
| separator.subtract(
|
| audio1=args.input_1,
|
| audio2=args.input_2,
|
| output_dir=args.output_dir,
|
| output_format=args.output_format,
|
| use_spec_invert=args.spec_invert,
|
| template=args.template
|
| )
|
| elif args.mode == "info":
|
| if args.update:
|
| separator.update_info()
|
| elif args.download:
|
| separator.download(args.model_name)
|
| elif args.clear_cache:
|
| shutil.rmtree(separator.cache_dir, ignore_errors=True)
|
| else:
|
| separator.show_info(args.limit, args.stem, args.only_installed)
|
|
|
|
|