ASesYusuf1 committed on
Commit
a5643d8
Β·
verified Β·
1 Parent(s): b00852e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -74
app.py CHANGED
@@ -11,16 +11,17 @@ from audio_separator.separator import Separator
11
  import numpy as np
12
  import librosa
13
  import soundfile as sf
14
- from ensemble import ensemble_files # Correct import
15
  import shutil
16
  import gradio_client.utils as client_utils
17
  import matchering as mg
18
  import spaces
19
  import gdown
20
- import scipy.io.wavfile
21
  from pydub import AudioSegment
22
  import gc
23
  import time
 
 
24
 
25
  # Logging setup
26
  logging.basicConfig(level=logging.INFO)
@@ -230,8 +231,7 @@ button:hover {
230
  box-shadow: 0 2px 8px rgba(255, 107, 107, 0.4) !important;
231
  }
232
  .compact-dropdown select, .compact-dropdown .gr-dropdown {
233
- background: transparent ! thαΊ­n
234
-
235
  color: #e0e0e0 !important;
236
  border: none !important;
237
  width: 100% !important;
@@ -340,15 +340,14 @@ def download_audio(url, cookie_file=None):
340
  gdown.download(download_url, temp_output_path, quiet=False)
341
  if not os.path.exists(temp_output_path):
342
  return None, "Downloaded file not found", None
343
- from mimetypes import guess_type
344
- mime_type, _ = guess_type(temp_output_path)
345
- if not mime_type or not mime_type.startswith('audio'):
346
- return None, "Downloaded file is not an audio file", None
347
  output_path = 'ytdl/gdrive_audio.wav'
348
- audio = AudioSegment.from_file(temp_output_path)
349
- audio.export(output_path, format="wav")
 
 
 
350
  sample_rate, data = scipy.io.wavfile.read(output_path)
351
- return output_path, "Download successful", (sample_rate, data)
352
  else:
353
  os.makedirs('ytdl', exist_ok=True)
354
  with yt_dlp.YoutubeDL(ydl_opts) as ydl:
@@ -433,23 +432,36 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
433
  temp_audio_path = None
434
  chunk_paths = []
435
  max_retries = 2
 
 
 
 
 
436
  try:
437
  if not audio:
438
  raise ValueError("No audio file provided.")
439
  if not model_keys:
440
  raise ValueError("No models selected.")
441
- if len(model_keys) > 2:
442
- logger.warning("Limited to 2 models to avoid ZeroGPU timeouts. Using first two: %s", model_keys[:2])
443
- model_keys = model_keys[:2]
 
 
 
 
 
444
  if isinstance(audio, tuple):
445
  sample_rate, data = audio
446
  temp_audio_path = os.path.join("/tmp", "temp_audio.wav")
447
  scipy.io.wavfile.write(temp_audio_path, sample_rate, data)
448
  audio = temp_audio_path
 
449
  audio_data, sr = librosa.load(audio, sr=None, mono=False)
450
  duration = librosa.get_duration(y=audio_data, sr=sr)
451
  logger.info(f"Audio duration: {duration:.2f} seconds")
452
- chunk_duration = 300
 
 
453
  chunks = []
454
  if duration > 900:
455
  logger.info(f"Audio exceeds 15 minutes, splitting into {chunk_duration}-second chunks")
@@ -465,70 +477,116 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
465
  logger.info(f"Created chunk {i}: {chunk_path}")
466
  else:
467
  chunks = [audio]
 
468
  use_tta = use_tta == "True"
469
  if os.path.exists(output_dir):
470
  shutil.rmtree(output_dir)
471
  os.makedirs(output_dir, exist_ok=True)
472
  base_name = os.path.splitext(os.path.basename(audio))[0]
473
  logger.info(f"Ensemble for {base_name} with {model_keys} on {device}")
 
 
 
474
  all_stems = []
475
- model_stems = {}
476
- total_models = len(model_keys)
477
- for model_idx, model_key in enumerate(model_keys):
478
- model_stems[model_key] = {"vocals": [], "other": []}
479
- for category, models in ROFORMER_MODELS.items():
480
- if model_key in models:
481
- model = models[model_key]
482
- break
483
- else:
484
- logger.warning(f"Model {model_key} not found, skipping")
485
- continue
486
- for chunk_idx, chunk_path in enumerate(chunks):
487
- retry_count = 0
488
- while retry_count <= max_retries:
489
  try:
490
- progress((model_idx + 0.1) / total_models, desc=f"Loading {model_key} for chunk {chunk_idx}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
491
  model_path = os.path.join(model_dir, model)
492
- if not os.path.exists(model_path):
493
- logger.info(f"Model {model} not cached, will download")
494
- separator = Separator(
495
- log_level=logging.INFO,
496
- model_file_dir=model_dir,
497
- output_dir=output_dir,
498
- output_format=out_format,
499
- normalization_threshold=norm_thresh,
500
- amplification_threshold=amp_thresh,
501
- use_autocast=use_autocast,
502
- mdxc_params={"segment_size": seg_size, "overlap": overlap, "use_tta": use_tta, "batch_size": batch_size}
503
- )
504
- logger.info(f"Loading {model_key} for chunk {chunk_idx}")
505
- separator.load_model(model_filename=model)
506
- progress((model_idx + 0.5) / total_models, desc=f"Separating chunk {chunk_idx} with {model_key}")
507
- logger.info(f"Separating chunk {chunk_idx} with {model_key}")
508
- separation = separator.separate(chunk_path)
509
- stems = [os.path.join(output_dir, file_name) for file_name in separation]
510
- for stem in stems:
511
- if "vocals" in os.path.basename(stem).lower():
512
- model_stems[model_key]["vocals"].append(stem)
513
- elif "other" in os.path.basename(stem).lower() or "instrumental" in os.path.basename(stem).lower():
514
- model_stems[model_key]["other"].append(stem)
515
- break
 
 
 
 
 
 
 
 
 
 
 
 
516
  except Exception as e:
517
- retry_count += 1
518
- logger.error(f"Error processing {model_key} chunk {chunk_idx}, attempt {retry_count}/{max_retries}: {e}")
519
- if "ZeroGPU" in str(e) or "aborted" in str(e).lower():
520
- logger.error("ZeroGPU task aborted, attempting recovery")
521
- if retry_count > max_retries:
522
  logger.error(f"Max retries reached for {model_key} chunk {chunk_idx}, skipping")
523
- break
524
  time.sleep(1)
525
  finally:
526
- separator = None
527
- gc.collect()
528
  if torch.cuda.is_available():
529
  torch.cuda.empty_cache()
530
  logger.info(f"Cleared CUDA cache after {model_key} chunk {chunk_idx}")
531
- time.sleep(0.1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
532
  progress(0.8, desc="Combining stems...")
533
  for model_key, stems_dict in model_stems.items():
534
  for stem_type in ["vocals", "other"]:
@@ -546,13 +604,16 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
546
  all_stems.append(combined_path)
547
  except Exception as e:
548
  logger.error(f"Error combining {stem_type} for {model_key}: {e}")
 
549
  all_stems = [stem for stem in all_stems if os.path.exists(stem)]
550
  if not all_stems:
551
  raise ValueError("No valid stems found for ensemble. Try uploading a local WAV file.")
 
 
552
  weights = [float(w.strip()) for w in weights_str.split(',')] if weights_str.strip() else [1.0] * len(all_stems)
553
  if len(weights) != len(all_stems):
554
  weights = [1.0] * len(all_stems)
555
- logger.info("Weights mismatched, safest option is to default to 1.0")
556
  output_file = os.path.join(output_dir, f"{base_name}_ensemble_{ensemble_method}.{out_format}")
557
  ensemble_args = [
558
  "--files", *all_stems,
@@ -563,12 +624,14 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
563
  progress(0.9, desc="Running ensemble...")
564
  logger.info(f"Running ensemble with args: {ensemble_args}")
565
  try:
566
- result = ensemble_files(ensemble_args) # Correct function call
567
  if result is None or not os.path.exists(output_file):
568
  raise RuntimeError(f"Ensemble failed, output file not created: {output_file}")
569
  logger.info(f"Ensemble completed, output: {output_file}")
570
  progress(1.0, desc="Ensemble completed")
571
- return output_file, f"Ensemble completed with {ensemble_method}, excluded: {exclude_stems if exclude_stems else 'None'}"
 
 
572
  except Exception as e:
573
  logger.error(f"Ensemble processing error: {e}")
574
  if "numpy" in str(e).lower() or "copy" in str(e).lower():
@@ -578,8 +641,8 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
578
  raise RuntimeError(error_msg)
579
  except Exception as e:
580
  logger.error(f"Ensemble error: {e}")
581
- if "ZeroGPU" in str(e) or "aborted" in str(e).lower():
582
- error_msg = "ZeroGPU task aborted. Try using fewer models (max 2), lowering segment size, or uploading a local WAV file."
583
  else:
584
  error_msg = f"Ensemble error: {e}"
585
  raise RuntimeError(error_msg)
@@ -615,8 +678,8 @@ def create_interface():
615
  with gr.Blocks(title="🎡 SESA Fast Separation 🎡", css=CSS, elem_id="app-container") as app:
616
  gr.Markdown("<h1 class='header-text'>🎡 SESA Fast Separation 🎡</h1>")
617
  gr.Markdown("**Note**: If YouTube downloads fail, upload a valid cookies file or a local WAV file. [Cookie Instructions](https://github.com/yt-dlp/yt-dlp/wiki/Extractors#exporting-youtube-cookies)")
618
- gr.Markdown("**Warning**: Audio files longer than 15 minutes are split into 5-minute chunks, which may increase processing time.")
619
- gr.Markdown("**ZeroGPU Notice**: Use up to 2 models for ensemble to avoid timeouts. For large tasks, upload a local WAV file.")
620
  with gr.Tabs():
621
  with gr.Tab("βš™οΈ Settings"):
622
  with gr.Group(elem_classes="dubbing-theme"):
@@ -653,7 +716,7 @@ def create_interface():
653
  with gr.Tab("🎚️ Auto Ensemble"):
654
  with gr.Group(elem_classes="dubbing-theme"):
655
  gr.Markdown("### Ensemble Processing")
656
- gr.Markdown("Note: If weights are not specified, equal weights (1.0) are applied. Max 2 models recommended to avoid ZeroGPU timeouts.")
657
  with gr.Row():
658
  ensemble_audio = gr.Audio(label="🎧 Upload Audio", type="filepath", interactive=True)
659
  url_ensemble = gr.Textbox(label="πŸ”— Or Paste URL", placeholder="YouTube or audio URL", interactive=True)
@@ -663,13 +726,13 @@ def create_interface():
663
  ensemble_exclude_stems = gr.Textbox(label="🚫 Exclude Stems", placeholder="e.g., vocals, drums (comma-separated)", interactive=True)
664
  with gr.Row():
665
  ensemble_category = gr.Dropdown(label="πŸ“š Category", choices=list(ROFORMER_MODELS.keys()), value="Instrumentals", interactive=True)
666
- ensemble_models = gr.Dropdown(label="πŸ› οΈ Models (Max 2)", choices=list(ROFORMER_MODELS["Instrumentals"].keys()), multiselect=True, interactive=True, allow_custom_value=True)
667
  with gr.Row():
668
  ensemble_seg_size = gr.Slider(32, 512, value=128, step=32, label="πŸ“ Segment Size", interactive=True)
669
  ensemble_overlap = gr.Slider(2, 10, value=8, step=1, label="πŸ”„ Overlap", interactive=True)
670
  ensemble_use_tta = gr.Dropdown(choices=["True", "False"], value="False", label="πŸ” Use TTA", interactive=True)
671
  ensemble_method = gr.Dropdown(label="βš™οΈ Ensemble Method", choices=['avg_wave', 'median_wave', 'max_wave', 'min_wave', 'avg_fft', 'median_fft', 'max_fft', 'min_fft'], value='avg_wave', interactive=True)
672
- ensemble_weights = gr.Textbox(label="βš–οΈ Weights", placeholder="e.g., 1.0, 1.0 (comma-separated)", interactive=True)
673
  ensemble_button = gr.Button("πŸŽ›οΈ Run Ensemble!", variant="primary")
674
  ensemble_output = gr.Audio(label="🎢 Ensemble Result", type="filepath", interactive=False)
675
  ensemble_status = gr.Textbox(label="πŸ“’ Status", interactive=False)
@@ -699,7 +762,7 @@ def create_interface():
699
  fn=auto_ensemble_process,
700
  inputs=[
701
  ensemble_audio, ensemble_models, ensemble_seg_size, ensemble_overlap,
702
- output_format, ensemble_use_tta, model_file_dir, output_dir,
703
  norm_threshold, amp_threshold, batch_size, ensemble_method,
704
  ensemble_exclude_stems, ensemble_weights
705
  ],
 
11
  import numpy as np
12
  import librosa
13
  import soundfile as sf
14
+ from ensemble import ensemble_files
15
  import shutil
16
  import gradio_client.utils as client_utils
17
  import matchering as mg
18
  import spaces
19
  import gdown
 
20
  from pydub import AudioSegment
21
  import gc
22
  import time
23
+ from concurrent.futures import ThreadPoolExecutor, as_completed
24
+ from threading import Lock
25
 
26
  # Logging setup
27
  logging.basicConfig(level=logging.INFO)
 
231
  box-shadow: 0 2px 8px rgba(255, 107, 107, 0.4) !important;
232
  }
233
  .compact-dropdown select, .compact-dropdown .gr-dropdown {
234
+ background: transparent !important;
 
235
  color: #e0e0e0 !important;
236
  border: none !important;
237
  width: 100% !important;
 
340
  gdown.download(download_url, temp_output_path, quiet=False)
341
  if not os.path.exists(temp_output_path):
342
  return None, "Downloaded file not found", None
 
 
 
 
343
  output_path = 'ytdl/gdrive_audio.wav'
344
+ try:
345
+ audio = AudioSegment.from_file(temp_output_path)
346
+ audio.export(output_path, format="wav")
347
+ except Exception as e:
348
+ return None, f"Failed to process Google Drive file as audio: {str(e)}. Ensure the file contains audio (e.g., MP3, WAV, or video with audio track).", None
349
  sample_rate, data = scipy.io.wavfile.read(output_path)
350
+ return output_path, "Download and audio conversion successful", (sample_rate, data)
351
  else:
352
  os.makedirs('ytdl', exist_ok=True)
353
  with yt_dlp.YoutubeDL(ydl_opts) as ydl:
 
432
  temp_audio_path = None
433
  chunk_paths = []
434
  max_retries = 2
435
+ start_time = time.time()
436
+ time_budget = 100 # seconds, to stay within ZeroGPU limit
437
+ max_models = 6 # Reasonable limit to prevent timeouts
438
+ gpu_lock = Lock() # Ensure only one model uses GPU at a time
439
+
440
  try:
441
  if not audio:
442
  raise ValueError("No audio file provided.")
443
  if not model_keys:
444
  raise ValueError("No models selected.")
445
+ if len(model_keys) > max_models:
446
+ logger.warning(f"Selected {len(model_keys)} models, limiting to {max_models} to avoid ZeroGPU timeouts.")
447
+ model_keys = model_keys[:max_models]
448
+
449
+ # Dynamic batch size adjustment
450
+ dynamic_batch_size = max(1, min(4, 1 + (6 - len(model_keys)) // 2))
451
+ logger.info(f"Using batch size: {dynamic_batch_size} for {len(model_keys)} models")
452
+
453
  if isinstance(audio, tuple):
454
  sample_rate, data = audio
455
  temp_audio_path = os.path.join("/tmp", "temp_audio.wav")
456
  scipy.io.wavfile.write(temp_audio_path, sample_rate, data)
457
  audio = temp_audio_path
458
+
459
  audio_data, sr = librosa.load(audio, sr=None, mono=False)
460
  duration = librosa.get_duration(y=audio_data, sr=sr)
461
  logger.info(f"Audio duration: {duration:.2f} seconds")
462
+
463
+ # Optimize chunking
464
+ chunk_duration = 300 if duration > 900 else duration
465
  chunks = []
466
  if duration > 900:
467
  logger.info(f"Audio exceeds 15 minutes, splitting into {chunk_duration}-second chunks")
 
477
  logger.info(f"Created chunk {i}: {chunk_path}")
478
  else:
479
  chunks = [audio]
480
+
481
  use_tta = use_tta == "True"
482
  if os.path.exists(output_dir):
483
  shutil.rmtree(output_dir)
484
  os.makedirs(output_dir, exist_ok=True)
485
  base_name = os.path.splitext(os.path.basename(audio))[0]
486
  logger.info(f"Ensemble for {base_name} with {model_keys} on {device}")
487
+
488
+ # Model cache
489
+ model_cache = {}
490
  all_stems = []
491
+ model_stems = {model_key: {"vocals": [], "other": []} for model_key in model_keys}
492
+ total_tasks = len(model_keys) * len(chunks)
493
+
494
+ def process_model_chunk(model_key, chunk_path, chunk_idx, model_idx):
495
+ with torch.no_grad():
496
+ for attempt in range(max_retries + 1):
 
 
 
 
 
 
 
 
497
  try:
498
+ # Find model
499
+ for category, models in ROFORMER_MODELS.items():
500
+ if model_key in models:
501
+ model = models[model_key]
502
+ break
503
+ else:
504
+ logger.warning(f"Model {model_key} not found, skipping")
505
+ return []
506
+
507
+ # Check time budget
508
+ elapsed = time.time() - start_time
509
+ if elapsed > time_budget:
510
+ logger.error(f"Time budget ({time_budget}s) exceeded, aborting")
511
+ raise TimeoutError("Processing exceeded time budget")
512
+
513
+ # Initialize separator
514
  model_path = os.path.join(model_dir, model)
515
+ if model_key not in model_cache:
516
+ logger.info(f"Loading {model_key} into cache")
517
+ separator = Separator(
518
+ log_level=logging.INFO,
519
+ model_file_dir=model_dir,
520
+ output_dir=output_dir,
521
+ output_format=out_format,
522
+ normalization_threshold=norm_thresh,
523
+ amplification_threshold=amp_thresh,
524
+ use_autocast=use_autocast,
525
+ mdxc_params={
526
+ "segment_size": seg_size,
527
+ "overlap": overlap,
528
+ "use_tta": use_tta,
529
+ "batch_size": dynamic_batch_size
530
+ }
531
+ )
532
+ separator.load_model(model_filename=model)
533
+ model_cache[model_key] = separator
534
+ else:
535
+ separator = model_cache[model_key]
536
+
537
+ # Process with GPU lock
538
+ with gpu_lock:
539
+ progress((model_idx + chunk_idx / len(chunks)) / len(model_keys), desc=f"Separating chunk {chunk_idx} with {model_key}")
540
+ logger.info(f"Separating chunk {chunk_idx} with {model_key}")
541
+ separation = separator.separate(chunk_path)
542
+ stems = [os.path.join(output_dir, file_name) for file_name in separation]
543
+ result = []
544
+ for stem in stems:
545
+ if "vocals" in os.path.basename(stem).lower():
546
+ model_stems[model_key]["vocals"].append(stem)
547
+ elif "other" in os.path.basename(stem).lower() or "instrumental" in os.path.basename(stem).lower():
548
+ model_stems[model_key]["other"].append(stem)
549
+ result.append(stem)
550
+ return result
551
  except Exception as e:
552
+ logger.error(f"Error processing {model_key} chunk {chunk_idx}, attempt {attempt + 1}/{max_retries + 1}: {e}")
553
+ if attempt == max_retries:
 
 
 
554
  logger.error(f"Max retries reached for {model_key} chunk {chunk_idx}, skipping")
555
+ return []
556
  time.sleep(1)
557
  finally:
 
 
558
  if torch.cuda.is_available():
559
  torch.cuda.empty_cache()
560
  logger.info(f"Cleared CUDA cache after {model_key} chunk {chunk_idx}")
561
+
562
+ # Parallel processing
563
+ progress(0.1, desc="Starting model separations...")
564
+ with ThreadPoolExecutor(max_workers=min(4, len(model_keys))) as executor:
565
+ future_to_task = {}
566
+ for model_idx, model_key in enumerate(model_keys):
567
+ for chunk_idx, chunk_path in enumerate(chunks):
568
+ future = executor.submit(process_model_chunk, model_key, chunk_path, chunk_idx, model_idx)
569
+ future_to_task[future] = (model_key, chunk_idx)
570
+
571
+ for future in as_completed(future_to_task):
572
+ model_key, chunk_idx = future_to_task[future]
573
+ try:
574
+ stems = future.result()
575
+ if stems:
576
+ logger.info(f"Completed {model_key} chunk {chunk_idx}")
577
+ else:
578
+ logger.warning(f"No stems produced for {model_key} chunk {chunk_idx}")
579
+ except Exception as e:
580
+ logger.error(f"Task {model_key} chunk {chunk_idx} failed: {e}")
581
+
582
+ # Clear model cache
583
+ model_cache.clear()
584
+ gc.collect()
585
+ if torch.cuda.is_available():
586
+ torch.cuda.empty_cache()
587
+ logger.info("Cleared model cache and GPU memory")
588
+
589
+ # Combine stems
590
  progress(0.8, desc="Combining stems...")
591
  for model_key, stems_dict in model_stems.items():
592
  for stem_type in ["vocals", "other"]:
 
604
  all_stems.append(combined_path)
605
  except Exception as e:
606
  logger.error(f"Error combining {stem_type} for {model_key}: {e}")
607
+
608
  all_stems = [stem for stem in all_stems if os.path.exists(stem)]
609
  if not all_stems:
610
  raise ValueError("No valid stems found for ensemble. Try uploading a local WAV file.")
611
+
612
+ # Ensemble
613
  weights = [float(w.strip()) for w in weights_str.split(',')] if weights_str.strip() else [1.0] * len(all_stems)
614
  if len(weights) != len(all_stems):
615
  weights = [1.0] * len(all_stems)
616
+ logger.info("Weights mismatched, defaulting to 1.0")
617
  output_file = os.path.join(output_dir, f"{base_name}_ensemble_{ensemble_method}.{out_format}")
618
  ensemble_args = [
619
  "--files", *all_stems,
 
624
  progress(0.9, desc="Running ensemble...")
625
  logger.info(f"Running ensemble with args: {ensemble_args}")
626
  try:
627
+ result = ensemble_files(ensemble_args)
628
  if result is None or not os.path.exists(output_file):
629
  raise RuntimeError(f"Ensemble failed, output file not created: {output_file}")
630
  logger.info(f"Ensemble completed, output: {output_file}")
631
  progress(1.0, desc="Ensemble completed")
632
+ elapsed = time.time() - start_time
633
+ logger.info(f"Total processing time: {elapsed:.2f}s")
634
+ return output_file, f"Ensemble completed with {ensemble_method}, excluded: {exclude_stems if exclude_stems else 'None'}, {len(model_keys)} models in {elapsed:.2f}s"
635
  except Exception as e:
636
  logger.error(f"Ensemble processing error: {e}")
637
  if "numpy" in str(e).lower() or "copy" in str(e).lower():
 
641
  raise RuntimeError(error_msg)
642
  except Exception as e:
643
  logger.error(f"Ensemble error: {e}")
644
+ if "ZeroGPU" in str(e) or "aborted" in str(e).lower() or isinstance(e, TimeoutError):
645
+ error_msg = f"ZeroGPU task aborted or timed out. Try fewer models (max {max_models}), shorter audio, or uploading a local WAV file."
646
  else:
647
  error_msg = f"Ensemble error: {e}"
648
  raise RuntimeError(error_msg)
 
678
  with gr.Blocks(title="🎡 SESA Fast Separation 🎡", css=CSS, elem_id="app-container") as app:
679
  gr.Markdown("<h1 class='header-text'>🎡 SESA Fast Separation 🎡</h1>")
680
  gr.Markdown("**Note**: If YouTube downloads fail, upload a valid cookies file or a local WAV file. [Cookie Instructions](https://github.com/yt-dlp/yt-dlp/wiki/Extractors#exporting-youtube-cookies)")
681
+ gr.Markdown("**Warning**: Audio files longer than 15 minutes are split into 5-minute chunks, increasing processing time.")
682
+ gr.Markdown("**ZeroGPU Notice**: Up to 6 models supported for ensemble. For long audio, use fewer models or a local WAV file to avoid timeouts.")
683
  with gr.Tabs():
684
  with gr.Tab("βš™οΈ Settings"):
685
  with gr.Group(elem_classes="dubbing-theme"):
 
716
  with gr.Tab("🎚️ Auto Ensemble"):
717
  with gr.Group(elem_classes="dubbing-theme"):
718
  gr.Markdown("### Ensemble Processing")
719
+ gr.Markdown("Note: If weights are not specified, equal weights (1.0) are applied. Max 6 models recommended to avoid ZeroGPU timeouts.")
720
  with gr.Row():
721
  ensemble_audio = gr.Audio(label="🎧 Upload Audio", type="filepath", interactive=True)
722
  url_ensemble = gr.Textbox(label="πŸ”— Or Paste URL", placeholder="YouTube or audio URL", interactive=True)
 
726
  ensemble_exclude_stems = gr.Textbox(label="🚫 Exclude Stems", placeholder="e.g., vocals, drums (comma-separated)", interactive=True)
727
  with gr.Row():
728
  ensemble_category = gr.Dropdown(label="πŸ“š Category", choices=list(ROFORMER_MODELS.keys()), value="Instrumentals", interactive=True)
729
+ ensemble_models = gr.Dropdown(label="πŸ› οΈ Models (Max 6)", choices=list(ROFORMER_MODELS["Instrumentals"].keys()), multiselect=True, interactive=True, allow_custom_value=True)
730
  with gr.Row():
731
  ensemble_seg_size = gr.Slider(32, 512, value=128, step=32, label="πŸ“ Segment Size", interactive=True)
732
  ensemble_overlap = gr.Slider(2, 10, value=8, step=1, label="πŸ”„ Overlap", interactive=True)
733
  ensemble_use_tta = gr.Dropdown(choices=["True", "False"], value="False", label="πŸ” Use TTA", interactive=True)
734
  ensemble_method = gr.Dropdown(label="βš™οΈ Ensemble Method", choices=['avg_wave', 'median_wave', 'max_wave', 'min_wave', 'avg_fft', 'median_fft', 'max_fft', 'min_fft'], value='avg_wave', interactive=True)
735
+ ensemble_weights = gr.Textbox(label="βš–οΈ Weights", placeholder="e.g., 1.0, 1.0, 1.0 (comma-separated)", interactive=True)
736
  ensemble_button = gr.Button("πŸŽ›οΈ Run Ensemble!", variant="primary")
737
  ensemble_output = gr.Audio(label="🎢 Ensemble Result", type="filepath", interactive=False)
738
  ensemble_status = gr.Textbox(label="πŸ“’ Status", interactive=False)
 
762
  fn=auto_ensemble_process,
763
  inputs=[
764
  ensemble_audio, ensemble_models, ensemble_seg_size, ensemble_overlap,
765
+ output_format, ensemble_use_tta, model_dir, output_dir,
766
  norm_threshold, amp_threshold, batch_size, ensemble_method,
767
  ensemble_exclude_stems, ensemble_weights
768
  ],