Spaces:
Running
Running
Update audio_separator/separator/separator.py
Browse files
audio_separator/separator/separator.py
CHANGED
|
@@ -1,27 +1,3 @@
|
|
| 1 |
-
""" This file contains the Separator class, to facilitate the separation of stems from audio. """
|
| 2 |
-
|
| 3 |
-
from importlib import metadata, resources
|
| 4 |
-
import os
|
| 5 |
-
import sys
|
| 6 |
-
import platform
|
| 7 |
-
import subprocess
|
| 8 |
-
import time
|
| 9 |
-
import logging
|
| 10 |
-
import warnings
|
| 11 |
-
import importlib
|
| 12 |
-
import io
|
| 13 |
-
from typing import Optional
|
| 14 |
-
|
| 15 |
-
import hashlib
|
| 16 |
-
import json
|
| 17 |
-
import yaml
|
| 18 |
-
import requests
|
| 19 |
-
import torch
|
| 20 |
-
import torch.amp.autocast_mode as autocast_mode
|
| 21 |
-
import onnxruntime as ort
|
| 22 |
-
from tqdm import tqdm
|
| 23 |
-
|
| 24 |
-
|
| 25 |
import os
|
| 26 |
import logging
|
| 27 |
import requests
|
|
@@ -44,7 +20,7 @@ class Separator:
|
|
| 44 |
"""
|
| 45 |
Optimized Separator class for audio source separation on Hugging Face Zero GPU.
|
| 46 |
Supports MDX, VR, Demucs, and MDXC architectures with ONNX Runtime and PyTorch.
|
| 47 |
-
|
| 48 |
"""
|
| 49 |
def __init__(
|
| 50 |
self,
|
|
@@ -167,52 +143,41 @@ class Separator:
|
|
| 167 |
raise RuntimeError(f"Failed to download {url}: {response.status_code}")
|
| 168 |
|
| 169 |
def list_supported_model_files(self):
|
| 170 |
-
"""Fetch supported model files
|
| 171 |
download_checks_url = "https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json"
|
| 172 |
download_checks_path = os.path.join(self.model_file_dir, "download_checks.json")
|
| 173 |
self.download_file_if_not_exists(download_checks_url, download_checks_path)
|
| 174 |
model_downloads_list = json.load(open(download_checks_path, encoding="utf-8"))
|
| 175 |
|
| 176 |
-
#
|
| 177 |
-
|
| 178 |
-
"
|
| 179 |
-
"
|
| 180 |
-
|
| 181 |
-
"instrumental": {"SDR": 15.2149, "SIR": 25.6075, "SAR": 17.1363, "ISR": 17.7893}
|
| 182 |
-
},
|
| 183 |
"stems": ["vocals", "instrumental"],
|
| 184 |
-
"target_stem": "vocals"
|
| 185 |
-
|
| 186 |
-
"htdemucs_ft.yaml": {
|
| 187 |
-
"median_scores": {
|
| 188 |
-
"vocals": {"SDR": 11.2685, "SIR": 21.257, "SAR": 11.0359, "ISR": 19.3753},
|
| 189 |
-
"drums": {"SDR": 13.235, "SIR": 23.3053, "SAR": 13.0313, "ISR": 17.2889},
|
| 190 |
-
"bass": {"SDR": 9.72743, "SIR": 19.5435, "SAR": 9.20801, "ISR": 13.5037}
|
| 191 |
-
},
|
| 192 |
-
"stems": ["vocals", "drums", "bass"],
|
| 193 |
-
"target_stem": "vocals"
|
| 194 |
},
|
| 195 |
-
"
|
| 196 |
-
"
|
| 197 |
-
|
| 198 |
-
"instrumental": {"SDR": 16.3035, "SIR": 26.6161, "SAR": 18.5167, "ISR": 18.3939}
|
| 199 |
-
},
|
| 200 |
"stems": ["vocals", "instrumental"],
|
| 201 |
-
"target_stem": "vocals"
|
| 202 |
-
|
|
|
|
|
|
|
| 203 |
}
|
| 204 |
|
| 205 |
public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
|
| 206 |
audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"
|
| 207 |
|
| 208 |
-
# Simplified model list for MDX, VR, Demucs, MDXC
|
| 209 |
model_files_grouped_by_type = {
|
| 210 |
"MDX": {
|
| 211 |
"MDX-Net Model: UVR-MDX-NET-Inst_full_292": {
|
| 212 |
"filename": "UVR-MDX-NET-Inst_full_292.onnx",
|
| 213 |
-
"scores":
|
| 214 |
-
"stems":
|
| 215 |
-
"target_stem":
|
| 216 |
"download_files": ["UVR-MDX-NET-Inst_full_292.onnx"]
|
| 217 |
}
|
| 218 |
},
|
|
@@ -228,27 +193,16 @@ class Separator:
|
|
| 228 |
"Demucs": {
|
| 229 |
"Demucs v4: htdemucs_ft": {
|
| 230 |
"filename": "htdemucs_ft.yaml",
|
| 231 |
-
"scores":
|
| 232 |
-
"stems":
|
| 233 |
-
"target_stem":
|
| 234 |
"download_files": [
|
| 235 |
f"{public_model_repo_url_prefix}/htdemucs_ft.yaml",
|
| 236 |
"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/f7e0c4bc-ba3fe64a.th"
|
| 237 |
]
|
| 238 |
}
|
| 239 |
},
|
| 240 |
-
"MDXC":
|
| 241 |
-
"MDX23C Model: MDX23C-InstVoc HQ": {
|
| 242 |
-
"filename": "MDX23C-8KFFT-InstVoc_HQ.ckpt",
|
| 243 |
-
"scores": model_scores.get("MDX23C-8KFFT-InstVoc_HQ.ckpt", {}).get("median_scores", {}),
|
| 244 |
-
"stems": model_scores.get("MDX23C-8KFFT-InstVoc_HQ.ckpt", {}).get("stems", []),
|
| 245 |
-
"target_stem": model_scores.get("MDX23C-8KFFT-InstVoc_HQ.ckpt", {}).get("target_stem"),
|
| 246 |
-
"download_files": [
|
| 247 |
-
"MDX23C-8KFFT-InstVoc_HQ.ckpt",
|
| 248 |
-
f"{audio_separator_models_repo_url_prefix}/model_2_stem_full_band_8k.yaml"
|
| 249 |
-
]
|
| 250 |
-
}
|
| 251 |
-
}
|
| 252 |
}
|
| 253 |
return model_files_grouped_by_type
|
| 254 |
|
|
@@ -289,6 +243,7 @@ class Separator:
|
|
| 289 |
if file_to_download.endswith(".yaml"):
|
| 290 |
yaml_config_filename = file_to_download
|
| 291 |
return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
|
|
|
|
| 292 |
raise ValueError(f"Model file {model_filename} not found")
|
| 293 |
|
| 294 |
def load_model_data_from_yaml(self, yaml_config_filename):
|
|
@@ -321,6 +276,7 @@ class Separator:
|
|
| 321 |
if model_hash in model_data:
|
| 322 |
self.logger.debug(f"Model data loaded for hash {model_hash}")
|
| 323 |
return model_data[model_hash]
|
|
|
|
| 324 |
raise ValueError(f"No model data for hash {model_hash}")
|
| 325 |
|
| 326 |
def load_model(self, model_filename="UVR-MDX-NET-Inst_full_292.onnx"):
|
|
@@ -328,9 +284,14 @@ class Separator:
|
|
| 328 |
self.logger.info(f"Loading model {model_filename}")
|
| 329 |
start_time = time.perf_counter()
|
| 330 |
|
| 331 |
-
|
| 332 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
|
|
|
|
| 334 |
model_data = self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename else self.load_model_data_using_hash(model_path)
|
| 335 |
|
| 336 |
common_params = {
|
|
@@ -461,9 +422,13 @@ class Separator:
|
|
| 461 |
def download_model_and_data(self, model_filename):
|
| 462 |
"""Download model files without loading into memory."""
|
| 463 |
self.logger.info(f"Downloading model {model_filename}")
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 467 |
|
| 468 |
def get_simplified_model_list(self, filter_sort_by: Optional[str] = None):
|
| 469 |
"""Return a simplified list of models."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import logging
|
| 3 |
import requests
|
|
|
|
| 20 |
"""
|
| 21 |
Optimized Separator class for audio source separation on Hugging Face Zero GPU.
|
| 22 |
Supports MDX, VR, Demucs, and MDXC architectures with ONNX Runtime and PyTorch.
|
| 23 |
+
Handles MelBand Roformer models and ensures robust model downloading.
|
| 24 |
"""
|
| 25 |
def __init__(
|
| 26 |
self,
|
|
|
|
| 143 |
raise RuntimeError(f"Failed to download {url}: {response.status_code}")
|
| 144 |
|
| 145 |
def list_supported_model_files(self):
|
| 146 |
+
"""Fetch supported model files, including MelBand Roformer models."""
|
| 147 |
download_checks_url = "https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json"
|
| 148 |
download_checks_path = os.path.join(self.model_file_dir, "download_checks.json")
|
| 149 |
self.download_file_if_not_exists(download_checks_url, download_checks_path)
|
| 150 |
model_downloads_list = json.load(open(download_checks_path, encoding="utf-8"))
|
| 151 |
|
| 152 |
+
# Custom model list from the Gradio log
|
| 153 |
+
roformer_models = {
|
| 154 |
+
"MelBand Roformer Kim | Inst V1 (E) Plus by Unwa": {
|
| 155 |
+
"filename": "melband_roformer_inst_v1e_plus.ckpt",
|
| 156 |
+
"scores": {"vocals": {"SDR": 11.95}, "instrumental": {"SDR": 16.30}},
|
|
|
|
|
|
|
| 157 |
"stems": ["vocals", "instrumental"],
|
| 158 |
+
"target_stem": "vocals",
|
| 159 |
+
"download_files": ["melband_roformer_inst_v1e_plus.ckpt", "model_2_stem_full_band_8k.yaml"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
},
|
| 161 |
+
"MelBand Roformer Kim | Inst V1 Plus by Unwa": {
|
| 162 |
+
"filename": "melband_roformer_inst_v1_plus.ckpt",
|
| 163 |
+
"scores": {"vocals": {"SDR": 11.80}, "instrumental": {"SDR": 16.20}},
|
|
|
|
|
|
|
| 164 |
"stems": ["vocals", "instrumental"],
|
| 165 |
+
"target_stem": "vocals",
|
| 166 |
+
"download_files": ["melband_roformer_inst_v1_plus.ckpt", "model_2_stem_full_band_8k.yaml"]
|
| 167 |
+
},
|
| 168 |
+
# Add other models from the log as needed
|
| 169 |
}
|
| 170 |
|
| 171 |
public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
|
| 172 |
audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"
|
| 173 |
|
|
|
|
| 174 |
model_files_grouped_by_type = {
|
| 175 |
"MDX": {
|
| 176 |
"MDX-Net Model: UVR-MDX-NET-Inst_full_292": {
|
| 177 |
"filename": "UVR-MDX-NET-Inst_full_292.onnx",
|
| 178 |
+
"scores": {"vocals": {"SDR": 10.6497}, "instrumental": {"SDR": 15.2149}},
|
| 179 |
+
"stems": ["vocals", "instrumental"],
|
| 180 |
+
"target_stem": "vocals",
|
| 181 |
"download_files": ["UVR-MDX-NET-Inst_full_292.onnx"]
|
| 182 |
}
|
| 183 |
},
|
|
|
|
| 193 |
"Demucs": {
|
| 194 |
"Demucs v4: htdemucs_ft": {
|
| 195 |
"filename": "htdemucs_ft.yaml",
|
| 196 |
+
"scores": {"vocals": {"SDR": 11.2685}, "drums": {"SDR": 13.235}, "bass": {"SDR": 9.72743}},
|
| 197 |
+
"stems": ["vocals", "drums", "bass"],
|
| 198 |
+
"target_stem": "vocals",
|
| 199 |
"download_files": [
|
| 200 |
f"{public_model_repo_url_prefix}/htdemucs_ft.yaml",
|
| 201 |
"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/f7e0c4bc-ba3fe64a.th"
|
| 202 |
]
|
| 203 |
}
|
| 204 |
},
|
| 205 |
+
"MDXC": roformer_models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
}
|
| 207 |
return model_files_grouped_by_type
|
| 208 |
|
|
|
|
| 243 |
if file_to_download.endswith(".yaml"):
|
| 244 |
yaml_config_filename = file_to_download
|
| 245 |
return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
|
| 246 |
+
self.logger.error(f"Model {model_filename} not found in supported models")
|
| 247 |
raise ValueError(f"Model file {model_filename} not found")
|
| 248 |
|
| 249 |
def load_model_data_from_yaml(self, yaml_config_filename):
|
|
|
|
| 276 |
if model_hash in model_data:
|
| 277 |
self.logger.debug(f"Model data loaded for hash {model_hash}")
|
| 278 |
return model_data[model_hash]
|
| 279 |
+
self.logger.error(f"No model data for hash {model_hash}")
|
| 280 |
raise ValueError(f"No model data for hash {model_hash}")
|
| 281 |
|
| 282 |
def load_model(self, model_filename="UVR-MDX-NET-Inst_full_292.onnx"):
|
|
|
|
| 284 |
self.logger.info(f"Loading model {model_filename}")
|
| 285 |
start_time = time.perf_counter()
|
| 286 |
|
| 287 |
+
try:
|
| 288 |
+
model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
|
| 289 |
+
except ValueError as e:
|
| 290 |
+
self.logger.error(f"Failed to load model: {e}")
|
| 291 |
+
self.logger.info("Falling back to default model: UVR-MDX-NET-Inst_full_292.onnx")
|
| 292 |
+
model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self\nSystem: .download_model_files("UVR-MDX-NET-Inst_full_292.onnx")
|
| 293 |
|
| 294 |
+
model_name = model_filename.split(".")[0]
|
| 295 |
model_data = self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename else self.load_model_data_using_hash(model_path)
|
| 296 |
|
| 297 |
common_params = {
|
|
|
|
| 422 |
def download_model_and_data(self, model_filename):
|
| 423 |
"""Download model files without loading into memory."""
|
| 424 |
self.logger.info(f"Downloading model {model_filename}")
|
| 425 |
+
try:
|
| 426 |
+
model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
|
| 427 |
+
model_data = self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename else self.load_model_data_using_hash(model_path)
|
| 428 |
+
self.logger.info(f"Model downloaded: {model_friendly_name}, type: {model_type}, path: {model_path}, data items: {len(model_data)}")
|
| 429 |
+
except ValueError as e:
|
| 430 |
+
self.logger.error(f"Failed to download model: {e}")
|
| 431 |
+
raise
|
| 432 |
|
| 433 |
def get_simplified_model_list(self, filter_sort_by: Optional[str] = None):
|
| 434 |
"""Return a simplified list of models."""
|