Spaces:
Running
Running
Update gui.py
Browse files
gui.py
CHANGED
|
@@ -11,6 +11,28 @@ import librosa
|
|
| 11 |
import soundfile as sf
|
| 12 |
from ensemble import ensemble_files
|
| 13 |
import shutil
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
# Device and autocast setup
|
| 16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
@@ -132,8 +154,7 @@ CSS = """
|
|
| 132 |
overflow: hidden;
|
| 133 |
}
|
| 134 |
body {
|
| 135 |
-
background:
|
| 136 |
-
background-size: cover;
|
| 137 |
margin: 0;
|
| 138 |
padding: 0;
|
| 139 |
font-family: 'Roboto', sans-serif;
|
|
@@ -335,7 +356,6 @@ def download_audio(url, out_dir="ytdl"):
|
|
| 335 |
if not url:
|
| 336 |
raise ValueError("No URL provided.")
|
| 337 |
|
| 338 |
-
# Clear ytdl directory
|
| 339 |
if os.path.exists(out_dir):
|
| 340 |
shutil.rmtree(out_dir)
|
| 341 |
os.makedirs(out_dir, exist_ok=True)
|
|
@@ -358,10 +378,8 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
|
|
| 358 |
if not audio:
|
| 359 |
raise ValueError("No audio file provided.")
|
| 360 |
|
| 361 |
-
# Convert override_seg_size to boolean
|
| 362 |
override_seg_size = override_seg_size == "True"
|
| 363 |
|
| 364 |
-
# Clear output directory
|
| 365 |
if os.path.exists(output_dir):
|
| 366 |
shutil.rmtree(output_dir)
|
| 367 |
os.makedirs(output_dir, exist_ok=True)
|
|
@@ -392,7 +410,6 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
|
|
| 392 |
separation = separator.separate(audio)
|
| 393 |
stems = [os.path.join(output_dir, file_name) for file_name in separation]
|
| 394 |
|
| 395 |
-
# Filter excluded stems
|
| 396 |
if exclude_stems.strip():
|
| 397 |
excluded = [s.strip().lower() for s in exclude_stems.split(',')]
|
| 398 |
filtered_stems = [stem for stem in stems if not any(ex in os.path.basename(stem).lower() for ex in excluded)]
|
|
@@ -402,15 +419,13 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
|
|
| 402 |
logger.error(f"Separation failed: {e}")
|
| 403 |
raise RuntimeError(f"Separation failed: {e}")
|
| 404 |
|
| 405 |
-
def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_tta, model_dir, output_dir, norm_thresh, amp_thresh, batch_size, ensemble_method, exclude_stems="",
|
| 406 |
"""Perform ensemble processing on audio using multiple Roformer models."""
|
| 407 |
if not audio or not model_keys:
|
| 408 |
raise ValueError("Audio or models missing.")
|
| 409 |
|
| 410 |
-
# Convert use_tta to boolean
|
| 411 |
use_tta = use_tta == "True"
|
| 412 |
|
| 413 |
-
# Clear output directory
|
| 414 |
if os.path.exists(output_dir):
|
| 415 |
shutil.rmtree(output_dir)
|
| 416 |
os.makedirs(output_dir, exist_ok=True)
|
|
@@ -421,7 +436,6 @@ def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_
|
|
| 421 |
all_stems = []
|
| 422 |
total_models = len(model_keys)
|
| 423 |
|
| 424 |
-
# Separate audio with each model
|
| 425 |
for i, model_key in enumerate(model_keys):
|
| 426 |
for category, models in ROFORMER_MODELS.items():
|
| 427 |
if model_key in models:
|
|
@@ -446,7 +460,6 @@ def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_
|
|
| 446 |
separation = separator.separate(audio)
|
| 447 |
stems = [os.path.join(output_dir, file_name) for file_name in separation]
|
| 448 |
|
| 449 |
-
# Filter excluded stems
|
| 450 |
if exclude_stems.strip():
|
| 451 |
excluded = [s.strip().lower() for s in exclude_stems.split(',')]
|
| 452 |
filtered_stems = [stem for stem in stems if not any(ex in os.path.basename(stem).lower() for ex in excluded)]
|
|
@@ -457,11 +470,10 @@ def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_
|
|
| 457 |
if not all_stems:
|
| 458 |
raise ValueError("No valid stems for ensemble after exclusion.")
|
| 459 |
|
| 460 |
-
|
| 461 |
-
if
|
| 462 |
weights = [1.0] * len(all_stems)
|
| 463 |
|
| 464 |
-
# Perform ensemble
|
| 465 |
output_file = os.path.join(output_dir, f"{base_name}_ensemble_{ensemble_method}.{out_format}")
|
| 466 |
ensemble_args = [
|
| 467 |
"--files", *all_stems,
|
|
@@ -477,11 +489,11 @@ def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_
|
|
| 477 |
|
| 478 |
def update_roformer_models(category):
|
| 479 |
"""Update Roformer model dropdown based on selected category."""
|
| 480 |
-
return gr.update(choices=list(ROFORMER_MODELS.get(category, {}).keys()))
|
| 481 |
|
| 482 |
def update_ensemble_models(category):
|
| 483 |
"""Update ensemble model dropdown based on selected category."""
|
| 484 |
-
return gr.update(choices=list(ROFORMER_MODELS.get(category, {}).keys()))
|
| 485 |
|
| 486 |
# Interface creation
|
| 487 |
def create_interface():
|
|
@@ -497,8 +509,8 @@ def create_interface():
|
|
| 497 |
model_file_dir = gr.Textbox(value="/tmp/audio-separator-models/", label="π Model Cache", placeholder="Path to model directory", interactive=True)
|
| 498 |
output_dir = gr.Textbox(value="output", label="π€ Output Directory", placeholder="Where to save results", interactive=True)
|
| 499 |
output_format = gr.Dropdown(value="wav", choices=OUTPUT_FORMATS, label="πΆ Output Format", interactive=True)
|
| 500 |
-
norm_threshold = gr.Slider(0.1, 1, value=0.9, step=0.1, label="π Normalization Threshold", interactive=True)
|
| 501 |
-
amp_threshold = gr.Slider(0.1, 1, value=0.3, step=0.1, label="π Amplification Threshold", interactive=True)
|
| 502 |
batch_size = gr.Slider(1, 16, value=4, step=1, label="β‘ Batch Size", interactive=True)
|
| 503 |
|
| 504 |
# Roformer Tab
|
|
@@ -563,13 +575,7 @@ def create_interface():
|
|
| 563 |
ensemble_category.change(update_ensemble_models, inputs=[ensemble_category], outputs=[ensemble_models])
|
| 564 |
download_ensemble.click(fn=download_audio, inputs=[url_ensemble], outputs=[ensemble_audio])
|
| 565 |
ensemble_button.click(
|
| 566 |
-
fn=
|
| 567 |
-
norm_thresh, amp_thresh, batch_size, method, exclude_stems, weights_str:
|
| 568 |
-
auto_ensemble_process(
|
| 569 |
-
audio, models, seg_size, overlap, out_format, use_tta, model_dir, output_dir,
|
| 570 |
-
norm_thresh, amp_thresh, batch_size, method, exclude_stems,
|
| 571 |
-
[float(w.strip()) for w in weights_str.split(',')] if weights_str.strip() else None
|
| 572 |
-
),
|
| 573 |
inputs=[
|
| 574 |
ensemble_audio, ensemble_models, ensemble_seg_size, ensemble_overlap,
|
| 575 |
output_format, ensemble_use_tta, model_file_dir, output_dir,
|
|
@@ -588,8 +594,8 @@ if __name__ == "__main__":
|
|
| 588 |
|
| 589 |
app = create_interface()
|
| 590 |
try:
|
| 591 |
-
#
|
| 592 |
-
app.launch(server_name="0.0.0.0", server_port=args.port, share=
|
| 593 |
except Exception as e:
|
| 594 |
logger.error(f"Failed to launch app: {e}")
|
| 595 |
raise
|
|
|
|
| 11 |
import soundfile as sf
|
| 12 |
from ensemble import ensemble_files
|
| 13 |
import shutil
|
| 14 |
+
import gradio_client.utils as client_utils
|
| 15 |
+
|
| 16 |
+
# Patch gradio_client.utils.get_type to handle boolean schemas
|
| 17 |
+
def patched_get_type(schema):
    """Describe the type of a JSON-schema fragment as a string.

    Installed over ``gradio_client.utils.get_type`` so that boolean schemas
    do not crash type resolution: membership tests such as
    ``"const" in schema`` raise ``TypeError`` when ``schema`` is a bool.

    Args:
        schema: A JSON-schema fragment. Normally a dict, but may be a bare
            bool or, defensively, any other non-dict value.

    Returns:
        str: ``"boolean"`` for bool schemas; ``"Any"`` for other non-dict
        values and for dicts without a ``"type"`` key; the ``repr`` of a
        ``"const"`` value; ``"Union[...]"`` for ``"enum"`` values and for
        type lists (``"null"`` members are dropped); ``"List[...]"`` for
        arrays; ``"Dict[str, Any]"`` for objects; otherwise the raw
        ``"type"`` value.
    """
    if isinstance(schema, bool):
        return "boolean"
    # Any other non-dict value cannot be inspected with the key lookups
    # below, so report it as unconstrained instead of raising TypeError.
    if not isinstance(schema, dict):
        return "Any"

    if "const" in schema:
        return repr(schema["const"])
    if "enum" in schema:
        return f"Union[{', '.join(repr(e) for e in schema['enum'])}]"
    if "type" not in schema:
        return "Any"

    type_ = schema["type"]
    if isinstance(type_, list):
        non_null = [t for t in type_ if t != "null"]
        if not non_null:
            # Avoid emitting the malformed "Union[]": a list holding only
            # "null" means None, an empty list carries no constraint.
            return "None" if "null" in type_ else "Any"
        return f"Union[{', '.join(non_null)}]"
    if type_ == "array":
        # Delegate the element type to gradio_client's own converter.
        item_type = client_utils._json_schema_to_python_type(
            schema.get("items", {}), schema.get("$defs", {})
        )
        return f"List[{item_type}]"
    if type_ == "object":
        return "Dict[str, Any]"
    return type_
|
| 34 |
+
|
| 35 |
+
client_utils.get_type = patched_get_type
|
| 36 |
|
| 37 |
# Device and autocast setup
|
| 38 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 154 |
overflow: hidden;
|
| 155 |
}
|
| 156 |
body {
|
| 157 |
+
background: none;
|
|
|
|
| 158 |
margin: 0;
|
| 159 |
padding: 0;
|
| 160 |
font-family: 'Roboto', sans-serif;
|
|
|
|
| 356 |
if not url:
|
| 357 |
raise ValueError("No URL provided.")
|
| 358 |
|
|
|
|
| 359 |
if os.path.exists(out_dir):
|
| 360 |
shutil.rmtree(out_dir)
|
| 361 |
os.makedirs(out_dir, exist_ok=True)
|
|
|
|
| 378 |
if not audio:
|
| 379 |
raise ValueError("No audio file provided.")
|
| 380 |
|
|
|
|
| 381 |
override_seg_size = override_seg_size == "True"
|
| 382 |
|
|
|
|
| 383 |
if os.path.exists(output_dir):
|
| 384 |
shutil.rmtree(output_dir)
|
| 385 |
os.makedirs(output_dir, exist_ok=True)
|
|
|
|
| 410 |
separation = separator.separate(audio)
|
| 411 |
stems = [os.path.join(output_dir, file_name) for file_name in separation]
|
| 412 |
|
|
|
|
| 413 |
if exclude_stems.strip():
|
| 414 |
excluded = [s.strip().lower() for s in exclude_stems.split(',')]
|
| 415 |
filtered_stems = [stem for stem in stems if not any(ex in os.path.basename(stem).lower() for ex in excluded)]
|
|
|
|
| 419 |
logger.error(f"Separation failed: {e}")
|
| 420 |
raise RuntimeError(f"Separation failed: {e}")
|
| 421 |
|
| 422 |
+
def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_tta, model_dir, output_dir, norm_thresh, amp_thresh, batch_size, ensemble_method, exclude_stems="", weights_str="", progress=gr.Progress()):
|
| 423 |
"""Perform ensemble processing on audio using multiple Roformer models."""
|
| 424 |
if not audio or not model_keys:
|
| 425 |
raise ValueError("Audio or models missing.")
|
| 426 |
|
|
|
|
| 427 |
use_tta = use_tta == "True"
|
| 428 |
|
|
|
|
| 429 |
if os.path.exists(output_dir):
|
| 430 |
shutil.rmtree(output_dir)
|
| 431 |
os.makedirs(output_dir, exist_ok=True)
|
|
|
|
| 436 |
all_stems = []
|
| 437 |
total_models = len(model_keys)
|
| 438 |
|
|
|
|
| 439 |
for i, model_key in enumerate(model_keys):
|
| 440 |
for category, models in ROFORMER_MODELS.items():
|
| 441 |
if model_key in models:
|
|
|
|
| 460 |
separation = separator.separate(audio)
|
| 461 |
stems = [os.path.join(output_dir, file_name) for file_name in separation]
|
| 462 |
|
|
|
|
| 463 |
if exclude_stems.strip():
|
| 464 |
excluded = [s.strip().lower() for s in exclude_stems.split(',')]
|
| 465 |
filtered_stems = [stem for stem in stems if not any(ex in os.path.basename(stem).lower() for ex in excluded)]
|
|
|
|
| 470 |
if not all_stems:
|
| 471 |
raise ValueError("No valid stems for ensemble after exclusion.")
|
| 472 |
|
| 473 |
+
weights = [float(w.strip()) for w in weights_str.split(',')] if weights_str.strip() else [1.0] * len(all_stems)
|
| 474 |
+
if len(weights) != len(all_stems):
|
| 475 |
weights = [1.0] * len(all_stems)
|
| 476 |
|
|
|
|
| 477 |
output_file = os.path.join(output_dir, f"{base_name}_ensemble_{ensemble_method}.{out_format}")
|
| 478 |
ensemble_args = [
|
| 479 |
"--files", *all_stems,
|
|
|
|
| 489 |
|
| 490 |
def update_roformer_models(category):
    """Refresh the Roformer model dropdown with the models of *category*."""
    # A category missing from ROFORMER_MODELS yields an empty choice list.
    models_in_category = ROFORMER_MODELS.get(category, {})
    model_names = list(models_in_category.keys())
    return gr.update(choices=model_names)
|
| 493 |
|
| 494 |
def update_ensemble_models(category):
    """Refresh the ensemble model dropdown with the models of *category*."""
    # A category missing from ROFORMER_MODELS yields an empty choice list.
    models_in_category = ROFORMER_MODELS.get(category, {})
    model_names = list(models_in_category.keys())
    return gr.update(choices=model_names)
|
| 497 |
|
| 498 |
# Interface creation
|
| 499 |
def create_interface():
|
|
|
|
| 509 |
model_file_dir = gr.Textbox(value="/tmp/audio-separator-models/", label="π Model Cache", placeholder="Path to model directory", interactive=True)
|
| 510 |
output_dir = gr.Textbox(value="output", label="π€ Output Directory", placeholder="Where to save results", interactive=True)
|
| 511 |
output_format = gr.Dropdown(value="wav", choices=OUTPUT_FORMATS, label="πΆ Output Format", interactive=True)
|
| 512 |
+
norm_threshold = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="π Normalization Threshold", interactive=True)
|
| 513 |
+
amp_threshold = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="π Amplification Threshold", interactive=True)
|
| 514 |
batch_size = gr.Slider(1, 16, value=4, step=1, label="β‘ Batch Size", interactive=True)
|
| 515 |
|
| 516 |
# Roformer Tab
|
|
|
|
| 575 |
ensemble_category.change(update_ensemble_models, inputs=[ensemble_category], outputs=[ensemble_models])
|
| 576 |
download_ensemble.click(fn=download_audio, inputs=[url_ensemble], outputs=[ensemble_audio])
|
| 577 |
ensemble_button.click(
|
| 578 |
+
fn=auto_ensemble_process,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
inputs=[
|
| 580 |
ensemble_audio, ensemble_models, ensemble_seg_size, ensemble_overlap,
|
| 581 |
output_format, ensemble_use_tta, model_file_dir, output_dir,
|
|
|
|
| 594 |
|
| 595 |
app = create_interface()
|
| 596 |
try:
|
| 597 |
+
# For Hugging Face Spaces or local testing
|
| 598 |
+
app.launch(server_name="0.0.0.0", server_port=args.port, share=False)
|
| 599 |
except Exception as e:
|
| 600 |
logger.error(f"Failed to launch app: {e}")
|
| 601 |
raise
|