Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -11,16 +11,17 @@ from audio_separator.separator import Separator
|
|
| 11 |
import numpy as np
|
| 12 |
import librosa
|
| 13 |
import soundfile as sf
|
| 14 |
-
from ensemble import ensemble_files
|
| 15 |
import shutil
|
| 16 |
import gradio_client.utils as client_utils
|
| 17 |
import matchering as mg
|
| 18 |
import spaces
|
| 19 |
import gdown
|
| 20 |
-
import scipy.io.wavfile
|
| 21 |
from pydub import AudioSegment
|
| 22 |
import gc
|
| 23 |
import time
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# Logging setup
|
| 26 |
logging.basicConfig(level=logging.INFO)
|
|
@@ -230,8 +231,7 @@ button:hover {
|
|
| 230 |
box-shadow: 0 2px 8px rgba(255, 107, 107, 0.4) !important;
|
| 231 |
}
|
| 232 |
.compact-dropdown select, .compact-dropdown .gr-dropdown {
|
| 233 |
-
background: transparent !
|
| 234 |
-
|
| 235 |
color: #e0e0e0 !important;
|
| 236 |
border: none !important;
|
| 237 |
width: 100% !important;
|
|
@@ -340,15 +340,14 @@ def download_audio(url, cookie_file=None):
|
|
| 340 |
gdown.download(download_url, temp_output_path, quiet=False)
|
| 341 |
if not os.path.exists(temp_output_path):
|
| 342 |
return None, "Downloaded file not found", None
|
| 343 |
-
from mimetypes import guess_type
|
| 344 |
-
mime_type, _ = guess_type(temp_output_path)
|
| 345 |
-
if not mime_type or not mime_type.startswith('audio'):
|
| 346 |
-
return None, "Downloaded file is not an audio file", None
|
| 347 |
output_path = 'ytdl/gdrive_audio.wav'
|
| 348 |
-
|
| 349 |
-
|
|
|
|
|
|
|
|
|
|
| 350 |
sample_rate, data = scipy.io.wavfile.read(output_path)
|
| 351 |
-
return output_path, "Download successful", (sample_rate, data)
|
| 352 |
else:
|
| 353 |
os.makedirs('ytdl', exist_ok=True)
|
| 354 |
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
|
@@ -433,23 +432,36 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
|
|
| 433 |
temp_audio_path = None
|
| 434 |
chunk_paths = []
|
| 435 |
max_retries = 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
try:
|
| 437 |
if not audio:
|
| 438 |
raise ValueError("No audio file provided.")
|
| 439 |
if not model_keys:
|
| 440 |
raise ValueError("No models selected.")
|
| 441 |
-
if len(model_keys) >
|
| 442 |
-
logger.warning("
|
| 443 |
-
model_keys = model_keys[:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 444 |
if isinstance(audio, tuple):
|
| 445 |
sample_rate, data = audio
|
| 446 |
temp_audio_path = os.path.join("/tmp", "temp_audio.wav")
|
| 447 |
scipy.io.wavfile.write(temp_audio_path, sample_rate, data)
|
| 448 |
audio = temp_audio_path
|
|
|
|
| 449 |
audio_data, sr = librosa.load(audio, sr=None, mono=False)
|
| 450 |
duration = librosa.get_duration(y=audio_data, sr=sr)
|
| 451 |
logger.info(f"Audio duration: {duration:.2f} seconds")
|
| 452 |
-
|
|
|
|
|
|
|
| 453 |
chunks = []
|
| 454 |
if duration > 900:
|
| 455 |
logger.info(f"Audio exceeds 15 minutes, splitting into {chunk_duration}-second chunks")
|
|
@@ -465,70 +477,116 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
|
|
| 465 |
logger.info(f"Created chunk {i}: {chunk_path}")
|
| 466 |
else:
|
| 467 |
chunks = [audio]
|
|
|
|
| 468 |
use_tta = use_tta == "True"
|
| 469 |
if os.path.exists(output_dir):
|
| 470 |
shutil.rmtree(output_dir)
|
| 471 |
os.makedirs(output_dir, exist_ok=True)
|
| 472 |
base_name = os.path.splitext(os.path.basename(audio))[0]
|
| 473 |
logger.info(f"Ensemble for {base_name} with {model_keys} on {device}")
|
|
|
|
|
|
|
|
|
|
| 474 |
all_stems = []
|
| 475 |
-
model_stems = {}
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
model = models[model_key]
|
| 482 |
-
break
|
| 483 |
-
else:
|
| 484 |
-
logger.warning(f"Model {model_key} not found, skipping")
|
| 485 |
-
continue
|
| 486 |
-
for chunk_idx, chunk_path in enumerate(chunks):
|
| 487 |
-
retry_count = 0
|
| 488 |
-
while retry_count <= max_retries:
|
| 489 |
try:
|
| 490 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 491 |
model_path = os.path.join(model_dir, model)
|
| 492 |
-
if not
|
| 493 |
-
logger.info(f"
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 516 |
except Exception as e:
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
if "ZeroGPU" in str(e) or "aborted" in str(e).lower():
|
| 520 |
-
logger.error("ZeroGPU task aborted, attempting recovery")
|
| 521 |
-
if retry_count > max_retries:
|
| 522 |
logger.error(f"Max retries reached for {model_key} chunk {chunk_idx}, skipping")
|
| 523 |
-
|
| 524 |
time.sleep(1)
|
| 525 |
finally:
|
| 526 |
-
separator = None
|
| 527 |
-
gc.collect()
|
| 528 |
if torch.cuda.is_available():
|
| 529 |
torch.cuda.empty_cache()
|
| 530 |
logger.info(f"Cleared CUDA cache after {model_key} chunk {chunk_idx}")
|
| 531 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
progress(0.8, desc="Combining stems...")
|
| 533 |
for model_key, stems_dict in model_stems.items():
|
| 534 |
for stem_type in ["vocals", "other"]:
|
|
@@ -546,13 +604,16 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
|
|
| 546 |
all_stems.append(combined_path)
|
| 547 |
except Exception as e:
|
| 548 |
logger.error(f"Error combining {stem_type} for {model_key}: {e}")
|
|
|
|
| 549 |
all_stems = [stem for stem in all_stems if os.path.exists(stem)]
|
| 550 |
if not all_stems:
|
| 551 |
raise ValueError("No valid stems found for ensemble. Try uploading a local WAV file.")
|
|
|
|
|
|
|
| 552 |
weights = [float(w.strip()) for w in weights_str.split(',')] if weights_str.strip() else [1.0] * len(all_stems)
|
| 553 |
if len(weights) != len(all_stems):
|
| 554 |
weights = [1.0] * len(all_stems)
|
| 555 |
-
logger.info("Weights mismatched,
|
| 556 |
output_file = os.path.join(output_dir, f"{base_name}_ensemble_{ensemble_method}.{out_format}")
|
| 557 |
ensemble_args = [
|
| 558 |
"--files", *all_stems,
|
|
@@ -563,12 +624,14 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
|
|
| 563 |
progress(0.9, desc="Running ensemble...")
|
| 564 |
logger.info(f"Running ensemble with args: {ensemble_args}")
|
| 565 |
try:
|
| 566 |
-
result = ensemble_files(ensemble_args)
|
| 567 |
if result is None or not os.path.exists(output_file):
|
| 568 |
raise RuntimeError(f"Ensemble failed, output file not created: {output_file}")
|
| 569 |
logger.info(f"Ensemble completed, output: {output_file}")
|
| 570 |
progress(1.0, desc="Ensemble completed")
|
| 571 |
-
|
|
|
|
|
|
|
| 572 |
except Exception as e:
|
| 573 |
logger.error(f"Ensemble processing error: {e}")
|
| 574 |
if "numpy" in str(e).lower() or "copy" in str(e).lower():
|
|
@@ -578,8 +641,8 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
|
|
| 578 |
raise RuntimeError(error_msg)
|
| 579 |
except Exception as e:
|
| 580 |
logger.error(f"Ensemble error: {e}")
|
| 581 |
-
if "ZeroGPU" in str(e) or "aborted" in str(e).lower():
|
| 582 |
-
error_msg = "ZeroGPU task aborted. Try
|
| 583 |
else:
|
| 584 |
error_msg = f"Ensemble error: {e}"
|
| 585 |
raise RuntimeError(error_msg)
|
|
@@ -615,8 +678,8 @@ def create_interface():
|
|
| 615 |
with gr.Blocks(title="π΅ SESA Fast Separation π΅", css=CSS, elem_id="app-container") as app:
|
| 616 |
gr.Markdown("<h1 class='header-text'>π΅ SESA Fast Separation π΅</h1>")
|
| 617 |
gr.Markdown("**Note**: If YouTube downloads fail, upload a valid cookies file or a local WAV file. [Cookie Instructions](https://github.com/yt-dlp/yt-dlp/wiki/Extractors#exporting-youtube-cookies)")
|
| 618 |
-
gr.Markdown("**Warning**: Audio files longer than 15 minutes are split into 5-minute chunks,
|
| 619 |
-
gr.Markdown("**ZeroGPU Notice**:
|
| 620 |
with gr.Tabs():
|
| 621 |
with gr.Tab("βοΈ Settings"):
|
| 622 |
with gr.Group(elem_classes="dubbing-theme"):
|
|
@@ -653,7 +716,7 @@ def create_interface():
|
|
| 653 |
with gr.Tab("ποΈ Auto Ensemble"):
|
| 654 |
with gr.Group(elem_classes="dubbing-theme"):
|
| 655 |
gr.Markdown("### Ensemble Processing")
|
| 656 |
-
gr.Markdown("Note: If weights are not specified, equal weights (1.0) are applied. Max
|
| 657 |
with gr.Row():
|
| 658 |
ensemble_audio = gr.Audio(label="π§ Upload Audio", type="filepath", interactive=True)
|
| 659 |
url_ensemble = gr.Textbox(label="π Or Paste URL", placeholder="YouTube or audio URL", interactive=True)
|
|
@@ -663,13 +726,13 @@ def create_interface():
|
|
| 663 |
ensemble_exclude_stems = gr.Textbox(label="π« Exclude Stems", placeholder="e.g., vocals, drums (comma-separated)", interactive=True)
|
| 664 |
with gr.Row():
|
| 665 |
ensemble_category = gr.Dropdown(label="π Category", choices=list(ROFORMER_MODELS.keys()), value="Instrumentals", interactive=True)
|
| 666 |
-
ensemble_models = gr.Dropdown(label="π οΈ Models (Max
|
| 667 |
with gr.Row():
|
| 668 |
ensemble_seg_size = gr.Slider(32, 512, value=128, step=32, label="π Segment Size", interactive=True)
|
| 669 |
ensemble_overlap = gr.Slider(2, 10, value=8, step=1, label="π Overlap", interactive=True)
|
| 670 |
ensemble_use_tta = gr.Dropdown(choices=["True", "False"], value="False", label="π Use TTA", interactive=True)
|
| 671 |
ensemble_method = gr.Dropdown(label="βοΈ Ensemble Method", choices=['avg_wave', 'median_wave', 'max_wave', 'min_wave', 'avg_fft', 'median_fft', 'max_fft', 'min_fft'], value='avg_wave', interactive=True)
|
| 672 |
-
ensemble_weights = gr.Textbox(label="βοΈ Weights", placeholder="e.g., 1.0, 1.0 (comma-separated)", interactive=True)
|
| 673 |
ensemble_button = gr.Button("ποΈ Run Ensemble!", variant="primary")
|
| 674 |
ensemble_output = gr.Audio(label="πΆ Ensemble Result", type="filepath", interactive=False)
|
| 675 |
ensemble_status = gr.Textbox(label="π’ Status", interactive=False)
|
|
@@ -699,7 +762,7 @@ def create_interface():
|
|
| 699 |
fn=auto_ensemble_process,
|
| 700 |
inputs=[
|
| 701 |
ensemble_audio, ensemble_models, ensemble_seg_size, ensemble_overlap,
|
| 702 |
-
output_format, ensemble_use_tta,
|
| 703 |
norm_threshold, amp_threshold, batch_size, ensemble_method,
|
| 704 |
ensemble_exclude_stems, ensemble_weights
|
| 705 |
],
|
|
|
|
| 11 |
import numpy as np
|
| 12 |
import librosa
|
| 13 |
import soundfile as sf
|
| 14 |
+
from ensemble import ensemble_files
|
| 15 |
import shutil
|
| 16 |
import gradio_client.utils as client_utils
|
| 17 |
import matchering as mg
|
| 18 |
import spaces
|
| 19 |
import gdown
|
|
|
|
| 20 |
from pydub import AudioSegment
|
| 21 |
import gc
|
| 22 |
import time
|
| 23 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 24 |
+
from threading import Lock
|
| 25 |
|
| 26 |
# Logging setup
|
| 27 |
logging.basicConfig(level=logging.INFO)
|
|
|
|
| 231 |
box-shadow: 0 2px 8px rgba(255, 107, 107, 0.4) !important;
|
| 232 |
}
|
| 233 |
.compact-dropdown select, .compact-dropdown .gr-dropdown {
|
| 234 |
+
background: transparent !important;
|
|
|
|
| 235 |
color: #e0e0e0 !important;
|
| 236 |
border: none !important;
|
| 237 |
width: 100% !important;
|
|
|
|
| 340 |
gdown.download(download_url, temp_output_path, quiet=False)
|
| 341 |
if not os.path.exists(temp_output_path):
|
| 342 |
return None, "Downloaded file not found", None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
output_path = 'ytdl/gdrive_audio.wav'
|
| 344 |
+
try:
|
| 345 |
+
audio = AudioSegment.from_file(temp_output_path)
|
| 346 |
+
audio.export(output_path, format="wav")
|
| 347 |
+
except Exception as e:
|
| 348 |
+
return None, f"Failed to process Google Drive file as audio: {str(e)}. Ensure the file contains audio (e.g., MP3, WAV, or video with audio track).", None
|
| 349 |
sample_rate, data = scipy.io.wavfile.read(output_path)
|
| 350 |
+
return output_path, "Download and audio conversion successful", (sample_rate, data)
|
| 351 |
else:
|
| 352 |
os.makedirs('ytdl', exist_ok=True)
|
| 353 |
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
|
|
|
| 432 |
temp_audio_path = None
|
| 433 |
chunk_paths = []
|
| 434 |
max_retries = 2
|
| 435 |
+
start_time = time.time()
|
| 436 |
+
time_budget = 100 # seconds, to stay within ZeroGPU limit
|
| 437 |
+
max_models = 6 # Reasonable limit to prevent timeouts
|
| 438 |
+
gpu_lock = Lock() # Ensure only one model uses GPU at a time
|
| 439 |
+
|
| 440 |
try:
|
| 441 |
if not audio:
|
| 442 |
raise ValueError("No audio file provided.")
|
| 443 |
if not model_keys:
|
| 444 |
raise ValueError("No models selected.")
|
| 445 |
+
if len(model_keys) > max_models:
|
| 446 |
+
logger.warning(f"Selected {len(model_keys)} models, limiting to {max_models} to avoid ZeroGPU timeouts.")
|
| 447 |
+
model_keys = model_keys[:max_models]
|
| 448 |
+
|
| 449 |
+
# Dynamic batch size adjustment
|
| 450 |
+
dynamic_batch_size = max(1, min(4, 1 + (6 - len(model_keys)) // 2))
|
| 451 |
+
logger.info(f"Using batch size: {dynamic_batch_size} for {len(model_keys)} models")
|
| 452 |
+
|
| 453 |
if isinstance(audio, tuple):
|
| 454 |
sample_rate, data = audio
|
| 455 |
temp_audio_path = os.path.join("/tmp", "temp_audio.wav")
|
| 456 |
scipy.io.wavfile.write(temp_audio_path, sample_rate, data)
|
| 457 |
audio = temp_audio_path
|
| 458 |
+
|
| 459 |
audio_data, sr = librosa.load(audio, sr=None, mono=False)
|
| 460 |
duration = librosa.get_duration(y=audio_data, sr=sr)
|
| 461 |
logger.info(f"Audio duration: {duration:.2f} seconds")
|
| 462 |
+
|
| 463 |
+
# Optimize chunking
|
| 464 |
+
chunk_duration = 300 if duration > 900 else duration
|
| 465 |
chunks = []
|
| 466 |
if duration > 900:
|
| 467 |
logger.info(f"Audio exceeds 15 minutes, splitting into {chunk_duration}-second chunks")
|
|
|
|
| 477 |
logger.info(f"Created chunk {i}: {chunk_path}")
|
| 478 |
else:
|
| 479 |
chunks = [audio]
|
| 480 |
+
|
| 481 |
use_tta = use_tta == "True"
|
| 482 |
if os.path.exists(output_dir):
|
| 483 |
shutil.rmtree(output_dir)
|
| 484 |
os.makedirs(output_dir, exist_ok=True)
|
| 485 |
base_name = os.path.splitext(os.path.basename(audio))[0]
|
| 486 |
logger.info(f"Ensemble for {base_name} with {model_keys} on {device}")
|
| 487 |
+
|
| 488 |
+
# Model cache
|
| 489 |
+
model_cache = {}
|
| 490 |
all_stems = []
|
| 491 |
+
model_stems = {model_key: {"vocals": [], "other": []} for model_key in model_keys}
|
| 492 |
+
total_tasks = len(model_keys) * len(chunks)
|
| 493 |
+
|
| 494 |
+
def process_model_chunk(model_key, chunk_path, chunk_idx, model_idx):
|
| 495 |
+
with torch.no_grad():
|
| 496 |
+
for attempt in range(max_retries + 1):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 497 |
try:
|
| 498 |
+
# Find model
|
| 499 |
+
for category, models in ROFORMER_MODELS.items():
|
| 500 |
+
if model_key in models:
|
| 501 |
+
model = models[model_key]
|
| 502 |
+
break
|
| 503 |
+
else:
|
| 504 |
+
logger.warning(f"Model {model_key} not found, skipping")
|
| 505 |
+
return []
|
| 506 |
+
|
| 507 |
+
# Check time budget
|
| 508 |
+
elapsed = time.time() - start_time
|
| 509 |
+
if elapsed > time_budget:
|
| 510 |
+
logger.error(f"Time budget ({time_budget}s) exceeded, aborting")
|
| 511 |
+
raise TimeoutError("Processing exceeded time budget")
|
| 512 |
+
|
| 513 |
+
# Initialize separator
|
| 514 |
model_path = os.path.join(model_dir, model)
|
| 515 |
+
if model_key not in model_cache:
|
| 516 |
+
logger.info(f"Loading {model_key} into cache")
|
| 517 |
+
separator = Separator(
|
| 518 |
+
log_level=logging.INFO,
|
| 519 |
+
model_file_dir=model_dir,
|
| 520 |
+
output_dir=output_dir,
|
| 521 |
+
output_format=out_format,
|
| 522 |
+
normalization_threshold=norm_thresh,
|
| 523 |
+
amplification_threshold=amp_thresh,
|
| 524 |
+
use_autocast=use_autocast,
|
| 525 |
+
mdxc_params={
|
| 526 |
+
"segment_size": seg_size,
|
| 527 |
+
"overlap": overlap,
|
| 528 |
+
"use_tta": use_tta,
|
| 529 |
+
"batch_size": dynamic_batch_size
|
| 530 |
+
}
|
| 531 |
+
)
|
| 532 |
+
separator.load_model(model_filename=model)
|
| 533 |
+
model_cache[model_key] = separator
|
| 534 |
+
else:
|
| 535 |
+
separator = model_cache[model_key]
|
| 536 |
+
|
| 537 |
+
# Process with GPU lock
|
| 538 |
+
with gpu_lock:
|
| 539 |
+
progress((model_idx + chunk_idx / len(chunks)) / len(model_keys), desc=f"Separating chunk {chunk_idx} with {model_key}")
|
| 540 |
+
logger.info(f"Separating chunk {chunk_idx} with {model_key}")
|
| 541 |
+
separation = separator.separate(chunk_path)
|
| 542 |
+
stems = [os.path.join(output_dir, file_name) for file_name in separation]
|
| 543 |
+
result = []
|
| 544 |
+
for stem in stems:
|
| 545 |
+
if "vocals" in os.path.basename(stem).lower():
|
| 546 |
+
model_stems[model_key]["vocals"].append(stem)
|
| 547 |
+
elif "other" in os.path.basename(stem).lower() or "instrumental" in os.path.basename(stem).lower():
|
| 548 |
+
model_stems[model_key]["other"].append(stem)
|
| 549 |
+
result.append(stem)
|
| 550 |
+
return result
|
| 551 |
except Exception as e:
|
| 552 |
+
logger.error(f"Error processing {model_key} chunk {chunk_idx}, attempt {attempt + 1}/{max_retries + 1}: {e}")
|
| 553 |
+
if attempt == max_retries:
|
|
|
|
|
|
|
|
|
|
| 554 |
logger.error(f"Max retries reached for {model_key} chunk {chunk_idx}, skipping")
|
| 555 |
+
return []
|
| 556 |
time.sleep(1)
|
| 557 |
finally:
|
|
|
|
|
|
|
| 558 |
if torch.cuda.is_available():
|
| 559 |
torch.cuda.empty_cache()
|
| 560 |
logger.info(f"Cleared CUDA cache after {model_key} chunk {chunk_idx}")
|
| 561 |
+
|
| 562 |
+
# Parallel processing
|
| 563 |
+
progress(0.1, desc="Starting model separations...")
|
| 564 |
+
with ThreadPoolExecutor(max_workers=min(4, len(model_keys))) as executor:
|
| 565 |
+
future_to_task = {}
|
| 566 |
+
for model_idx, model_key in enumerate(model_keys):
|
| 567 |
+
for chunk_idx, chunk_path in enumerate(chunks):
|
| 568 |
+
future = executor.submit(process_model_chunk, model_key, chunk_path, chunk_idx, model_idx)
|
| 569 |
+
future_to_task[future] = (model_key, chunk_idx)
|
| 570 |
+
|
| 571 |
+
for future in as_completed(future_to_task):
|
| 572 |
+
model_key, chunk_idx = future_to_task[future]
|
| 573 |
+
try:
|
| 574 |
+
stems = future.result()
|
| 575 |
+
if stems:
|
| 576 |
+
logger.info(f"Completed {model_key} chunk {chunk_idx}")
|
| 577 |
+
else:
|
| 578 |
+
logger.warning(f"No stems produced for {model_key} chunk {chunk_idx}")
|
| 579 |
+
except Exception as e:
|
| 580 |
+
logger.error(f"Task {model_key} chunk {chunk_idx} failed: {e}")
|
| 581 |
+
|
| 582 |
+
# Clear model cache
|
| 583 |
+
model_cache.clear()
|
| 584 |
+
gc.collect()
|
| 585 |
+
if torch.cuda.is_available():
|
| 586 |
+
torch.cuda.empty_cache()
|
| 587 |
+
logger.info("Cleared model cache and GPU memory")
|
| 588 |
+
|
| 589 |
+
# Combine stems
|
| 590 |
progress(0.8, desc="Combining stems...")
|
| 591 |
for model_key, stems_dict in model_stems.items():
|
| 592 |
for stem_type in ["vocals", "other"]:
|
|
|
|
| 604 |
all_stems.append(combined_path)
|
| 605 |
except Exception as e:
|
| 606 |
logger.error(f"Error combining {stem_type} for {model_key}: {e}")
|
| 607 |
+
|
| 608 |
all_stems = [stem for stem in all_stems if os.path.exists(stem)]
|
| 609 |
if not all_stems:
|
| 610 |
raise ValueError("No valid stems found for ensemble. Try uploading a local WAV file.")
|
| 611 |
+
|
| 612 |
+
# Ensemble
|
| 613 |
weights = [float(w.strip()) for w in weights_str.split(',')] if weights_str.strip() else [1.0] * len(all_stems)
|
| 614 |
if len(weights) != len(all_stems):
|
| 615 |
weights = [1.0] * len(all_stems)
|
| 616 |
+
logger.info("Weights mismatched, defaulting to 1.0")
|
| 617 |
output_file = os.path.join(output_dir, f"{base_name}_ensemble_{ensemble_method}.{out_format}")
|
| 618 |
ensemble_args = [
|
| 619 |
"--files", *all_stems,
|
|
|
|
| 624 |
progress(0.9, desc="Running ensemble...")
|
| 625 |
logger.info(f"Running ensemble with args: {ensemble_args}")
|
| 626 |
try:
|
| 627 |
+
result = ensemble_files(ensemble_args)
|
| 628 |
if result is None or not os.path.exists(output_file):
|
| 629 |
raise RuntimeError(f"Ensemble failed, output file not created: {output_file}")
|
| 630 |
logger.info(f"Ensemble completed, output: {output_file}")
|
| 631 |
progress(1.0, desc="Ensemble completed")
|
| 632 |
+
elapsed = time.time() - start_time
|
| 633 |
+
logger.info(f"Total processing time: {elapsed:.2f}s")
|
| 634 |
+
return output_file, f"Ensemble completed with {ensemble_method}, excluded: {exclude_stems if exclude_stems else 'None'}, {len(model_keys)} models in {elapsed:.2f}s"
|
| 635 |
except Exception as e:
|
| 636 |
logger.error(f"Ensemble processing error: {e}")
|
| 637 |
if "numpy" in str(e).lower() or "copy" in str(e).lower():
|
|
|
|
| 641 |
raise RuntimeError(error_msg)
|
| 642 |
except Exception as e:
|
| 643 |
logger.error(f"Ensemble error: {e}")
|
| 644 |
+
if "ZeroGPU" in str(e) or "aborted" in str(e).lower() or isinstance(e, TimeoutError):
|
| 645 |
+
error_msg = f"ZeroGPU task aborted or timed out. Try fewer models (max {max_models}), shorter audio, or uploading a local WAV file."
|
| 646 |
else:
|
| 647 |
error_msg = f"Ensemble error: {e}"
|
| 648 |
raise RuntimeError(error_msg)
|
|
|
|
| 678 |
with gr.Blocks(title="π΅ SESA Fast Separation π΅", css=CSS, elem_id="app-container") as app:
|
| 679 |
gr.Markdown("<h1 class='header-text'>π΅ SESA Fast Separation π΅</h1>")
|
| 680 |
gr.Markdown("**Note**: If YouTube downloads fail, upload a valid cookies file or a local WAV file. [Cookie Instructions](https://github.com/yt-dlp/yt-dlp/wiki/Extractors#exporting-youtube-cookies)")
|
| 681 |
+
gr.Markdown("**Warning**: Audio files longer than 15 minutes are split into 5-minute chunks, increasing processing time.")
|
| 682 |
+
gr.Markdown("**ZeroGPU Notice**: Up to 6 models supported for ensemble. For long audio, use fewer models or a local WAV file to avoid timeouts.")
|
| 683 |
with gr.Tabs():
|
| 684 |
with gr.Tab("βοΈ Settings"):
|
| 685 |
with gr.Group(elem_classes="dubbing-theme"):
|
|
|
|
| 716 |
with gr.Tab("ποΈ Auto Ensemble"):
|
| 717 |
with gr.Group(elem_classes="dubbing-theme"):
|
| 718 |
gr.Markdown("### Ensemble Processing")
|
| 719 |
+
gr.Markdown("Note: If weights are not specified, equal weights (1.0) are applied. Max 6 models recommended to avoid ZeroGPU timeouts.")
|
| 720 |
with gr.Row():
|
| 721 |
ensemble_audio = gr.Audio(label="π§ Upload Audio", type="filepath", interactive=True)
|
| 722 |
url_ensemble = gr.Textbox(label="π Or Paste URL", placeholder="YouTube or audio URL", interactive=True)
|
|
|
|
| 726 |
ensemble_exclude_stems = gr.Textbox(label="π« Exclude Stems", placeholder="e.g., vocals, drums (comma-separated)", interactive=True)
|
| 727 |
with gr.Row():
|
| 728 |
ensemble_category = gr.Dropdown(label="π Category", choices=list(ROFORMER_MODELS.keys()), value="Instrumentals", interactive=True)
|
| 729 |
+
ensemble_models = gr.Dropdown(label="π οΈ Models (Max 6)", choices=list(ROFORMER_MODELS["Instrumentals"].keys()), multiselect=True, interactive=True, allow_custom_value=True)
|
| 730 |
with gr.Row():
|
| 731 |
ensemble_seg_size = gr.Slider(32, 512, value=128, step=32, label="π Segment Size", interactive=True)
|
| 732 |
ensemble_overlap = gr.Slider(2, 10, value=8, step=1, label="π Overlap", interactive=True)
|
| 733 |
ensemble_use_tta = gr.Dropdown(choices=["True", "False"], value="False", label="π Use TTA", interactive=True)
|
| 734 |
ensemble_method = gr.Dropdown(label="βοΈ Ensemble Method", choices=['avg_wave', 'median_wave', 'max_wave', 'min_wave', 'avg_fft', 'median_fft', 'max_fft', 'min_fft'], value='avg_wave', interactive=True)
|
| 735 |
+
ensemble_weights = gr.Textbox(label="βοΈ Weights", placeholder="e.g., 1.0, 1.0, 1.0 (comma-separated)", interactive=True)
|
| 736 |
ensemble_button = gr.Button("ποΈ Run Ensemble!", variant="primary")
|
| 737 |
ensemble_output = gr.Audio(label="πΆ Ensemble Result", type="filepath", interactive=False)
|
| 738 |
ensemble_status = gr.Textbox(label="π’ Status", interactive=False)
|
|
|
|
| 762 |
fn=auto_ensemble_process,
|
| 763 |
inputs=[
|
| 764 |
ensemble_audio, ensemble_models, ensemble_seg_size, ensemble_overlap,
|
| 765 |
+
output_format, ensemble_use_tta, model_dir, output_dir,
|
| 766 |
norm_threshold, amp_threshold, batch_size, ensemble_method,
|
| 767 |
ensemble_exclude_stems, ensemble_weights
|
| 768 |
],
|