ASesYusuf1 committed on
Commit
defb0b3
·
verified ·
1 Parent(s): 6242fc6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -72
app.py CHANGED
@@ -22,7 +22,7 @@ import gc
22
  import time
23
  from concurrent.futures import ThreadPoolExecutor, as_completed
24
  from threading import Lock
25
- import scipy
26
 
27
  # Logging setup
28
  logging.basicConfig(level=logging.INFO)
@@ -384,7 +384,7 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
384
  scipy.io.wavfile.write(temp_audio_path, sample_rate, data)
385
  audio = temp_audio_path
386
  if seg_size > 512:
387
- logger.warning(f"Segment size {seg_size} is large, this may cause crashes on ZeroGPU.")
388
  override_seg_size = override_seg_size == "True"
389
  if os.path.exists(output_dir):
390
  shutil.rmtree(output_dir)
@@ -429,14 +429,13 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
429
  logger.info("GPU memory cleared")
430
 
431
  @spaces.GPU
432
- def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_format="wav", use_tta="False", model_dir="/tmp/audio-separator-models/", output_dir="output", norm_thresh=0.9, amp_thresh=0.9, batch_size=1, ensemble_method="avg_wave", exclude_stems="", weights_str="", progress=gr.Progress(track_tqdm=True)):
433
  temp_audio_path = None
434
- chunk_paths = []
435
  max_retries = 2
436
  start_time = time.time()
437
- time_budget = 100 # seconds, to stay within ZeroGPU limit
438
- max_models = 6 # Reasonable limit to prevent timeouts
439
- gpu_lock = Lock() # Ensure only one model uses GPU at a time
440
 
441
  try:
442
  if not audio:
@@ -444,12 +443,15 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
444
  if not model_keys:
445
  raise ValueError("No models selected.")
446
  if len(model_keys) > max_models:
447
- logger.warning(f"Selected {len(model_keys)} models, limiting to {max_models} to avoid ZeroGPU timeouts.")
448
  model_keys = model_keys[:max_models]
449
 
450
- # Dynamic batch size adjustment
451
- dynamic_batch_size = max(1, min(4, 1 + (6 - len(model_keys)) // 2))
452
- logger.info(f"Using batch size: {dynamic_batch_size} for {len(model_keys)} models")
 
 
 
453
 
454
  if isinstance(audio, tuple):
455
  sample_rate, data = audio
@@ -457,28 +459,6 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
457
  scipy.io.wavfile.write(temp_audio_path, sample_rate, data)
458
  audio = temp_audio_path
459
 
460
- audio_data, sr = librosa.load(audio, sr=None, mono=False)
461
- duration = librosa.get_duration(y=audio_data, sr=sr)
462
- logger.info(f"Audio duration: {duration:.2f} seconds")
463
-
464
- # Optimize chunking
465
- chunk_duration = 300 if duration > 900 else duration
466
- chunks = []
467
- if duration > 900:
468
- logger.info(f"Audio exceeds 15 minutes, splitting into {chunk_duration}-second chunks")
469
- num_chunks = int(np.ceil(duration / chunk_duration))
470
- for i in range(num_chunks):
471
- start = i * chunk_duration * sr
472
- end = min((i + 1) * chunk_duration * sr, audio_data.shape[-1])
473
- chunk_data = audio_data[:, start:end] if audio_data.ndim == 2 else audio_data[start:end]
474
- chunk_path = os.path.join("/tmp", f"chunk_{i}.wav")
475
- sf.write(chunk_path, chunk_data.T if audio_data.ndim == 2 else chunk_data, sr)
476
- chunks.append(chunk_path)
477
- chunk_paths.append(chunk_path)
478
- logger.info(f"Created chunk {i}: {chunk_path}")
479
- else:
480
- chunks = [audio]
481
-
482
  use_tta = use_tta == "True"
483
  if os.path.exists(output_dir):
484
  shutil.rmtree(output_dir)
@@ -490,9 +470,9 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
490
  model_cache = {}
491
  all_stems = []
492
  model_stems = {model_key: {"vocals": [], "other": []} for model_key in model_keys}
493
- total_tasks = len(model_keys) * len(chunks)
494
 
495
- def process_model_chunk(model_key, chunk_path, chunk_idx, model_idx):
496
  with torch.no_grad():
497
  for attempt in range(max_retries + 1):
498
  try:
@@ -508,8 +488,8 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
508
  # Check time budget
509
  elapsed = time.time() - start_time
510
  if elapsed > time_budget:
511
- logger.error(f"Time budget ({time_budget}s) exceeded, aborting")
512
- raise TimeoutError("Processing exceeded time budget")
513
 
514
  # Initialize separator
515
  model_path = os.path.join(model_dir, model)
@@ -537,9 +517,9 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
537
 
538
  # Process with GPU lock
539
  with gpu_lock:
540
- progress((model_idx + chunk_idx / len(chunks)) / len(model_keys), desc=f"Separating chunk {chunk_idx} with {model_key}")
541
- logger.info(f"Separating chunk {chunk_idx} with {model_key}")
542
- separation = separator.separate(chunk_path)
543
  stems = [os.path.join(output_dir, file_name) for file_name in separation]
544
  result = []
545
  for stem in stems:
@@ -550,35 +530,30 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
550
  result.append(stem)
551
  return result
552
  except Exception as e:
553
- logger.error(f"Error processing {model_key} chunk {chunk_idx}, attempt {attempt + 1}/{max_retries + 1}: {e}")
554
  if attempt == max_retries:
555
- logger.error(f"Max retries reached for {model_key} chunk {chunk_idx}, skipping")
556
  return []
557
  time.sleep(1)
558
  finally:
559
  if torch.cuda.is_available():
560
  torch.cuda.empty_cache()
561
- logger.info(f"Cleared CUDA cache after {model_key} chunk {chunk_idx}")
562
 
563
  # Parallel processing
564
  progress(0.1, desc="Starting model separations...")
565
  with ThreadPoolExecutor(max_workers=min(4, len(model_keys))) as executor:
566
- future_to_task = {}
567
- for model_idx, model_key in enumerate(model_keys):
568
- for chunk_idx, chunk_path in enumerate(chunks):
569
- future = executor.submit(process_model_chunk, model_key, chunk_path, chunk_idx, model_idx)
570
- future_to_task[future] = (model_key, chunk_idx)
571
-
572
  for future in as_completed(future_to_task):
573
- model_key, chunk_idx = future_to_task[future]
574
  try:
575
  stems = future.result()
576
  if stems:
577
- logger.info(f"Completed {model_key} chunk {chunk_idx}")
578
  else:
579
- logger.warning(f"No stems produced for {model_key} chunk {chunk_idx}")
580
  except Exception as e:
581
- logger.error(f"Task {model_key} chunk {chunk_idx} failed: {e}")
582
 
583
  # Clear model cache
584
  model_cache.clear()
@@ -594,10 +569,9 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
594
  if stems_dict[stem_type]:
595
  combined_path = os.path.join(output_dir, f"{base_name}_{stem_type}_{model_key.replace(' | ', '_').replace(' ', '_')}.wav")
596
  try:
597
- with sf.SoundFile(combined_path, 'w', sr, channels=2 if audio_data.ndim == 2 else 1) as f:
598
- for stem_path in stems_dict[stem_type]:
599
- data, _ = librosa.load(stem_path, sr=sr, mono=False)
600
- f.write(data.T if data.ndim == 2 else data)
601
  logger.info(f"Combined {stem_type} for {model_key}: {combined_path}")
602
  if exclude_stems.strip() and stem_type.lower() in [s.strip().lower() for s in exclude_stems.split(',')]:
603
  logger.info(f"Excluding {stem_type} for {model_key}")
@@ -642,19 +616,15 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
642
  raise RuntimeError(error_msg)
643
  except Exception as e:
644
  logger.error(f"Ensemble error: {e}")
645
- if "ZeroGPU" in str(e) or "aborted" in str(e).lower() or isinstance(e, TimeoutError):
646
- error_msg = f"ZeroGPU task aborted or timed out. Try fewer models (max {max_models}), shorter audio, or uploading a local WAV file."
647
- else:
648
- error_msg = f"Ensemble error: {e}"
649
  raise RuntimeError(error_msg)
650
  finally:
651
- for path in chunk_paths + ([temp_audio_path] if temp_audio_path and os.path.exists(temp_audio_path) else []):
652
  try:
653
- if os.path.exists(path):
654
- os.remove(path)
655
- logger.info(f"Temporary file deleted: {path}")
656
  except Exception as e:
657
- logger.warning(f"Failed to delete temporary file {path}: {e}")
658
  if torch.cuda.is_available():
659
  torch.cuda.empty_cache()
660
  logger.info("GPU memory cleared")
@@ -679,8 +649,7 @@ def create_interface():
679
  with gr.Blocks(title="🎡 SESA Fast Separation 🎡", css=CSS, elem_id="app-container") as app:
680
  gr.Markdown("<h1 class='header-text'>🎡 SESA Fast Separation 🎡</h1>")
681
  gr.Markdown("**Note**: If YouTube downloads fail, upload a valid cookies file or a local WAV file. [Cookie Instructions](https://github.com/yt-dlp/yt-dlp/wiki/Extractors#exporting-youtube-cookies)")
682
- gr.Markdown("**Warning**: Audio files longer than 15 minutes are split into 5-minute chunks, increasing processing time.")
683
- gr.Markdown("**ZeroGPU Notice**: Up to 6 models supported for ensemble. For long audio, use fewer models or a local WAV file to avoid timeouts.")
684
  with gr.Tabs():
685
  with gr.Tab("βš™οΈ Settings"):
686
  with gr.Group(elem_classes="dubbing-theme"):
@@ -705,7 +674,7 @@ def create_interface():
705
  roformer_category = gr.Dropdown(label="πŸ“š Category", choices=list(ROFORMER_MODELS.keys()), value="General Purpose", interactive=True)
706
  roformer_model = gr.Dropdown(label="πŸ› οΈ Model", choices=list(ROFORMER_MODELS["General Purpose"].keys()), interactive=True, allow_custom_value=True)
707
  with gr.Row():
708
- roformer_seg_size = gr.Slider(32, 512, value=128, step=32, label="πŸ“ Segment Size", interactive=True)
709
  roformer_overlap = gr.Slider(2, 10, value=8, step=1, label="πŸ”„ Overlap", interactive=True)
710
  with gr.Row():
711
  roformer_pitch_shift = gr.Slider(-12, 12, value=0, step=1, label="🎡 Pitch Shift", interactive=True)
@@ -717,7 +686,7 @@ def create_interface():
717
  with gr.Tab("🎚️ Auto Ensemble"):
718
  with gr.Group(elem_classes="dubbing-theme"):
719
  gr.Markdown("### Ensemble Processing")
720
- gr.Markdown("Note: If weights are not specified, equal weights (1.0) are applied. Max 6 models recommended to avoid ZeroGPU timeouts.")
721
  with gr.Row():
722
  ensemble_audio = gr.Audio(label="🎧 Upload Audio", type="filepath", interactive=True)
723
  url_ensemble = gr.Textbox(label="πŸ”— Or Paste URL", placeholder="YouTube or audio URL", interactive=True)
@@ -729,7 +698,7 @@ def create_interface():
729
  ensemble_category = gr.Dropdown(label="πŸ“š Category", choices=list(ROFORMER_MODELS.keys()), value="Instrumentals", interactive=True)
730
  ensemble_models = gr.Dropdown(label="πŸ› οΈ Models (Max 6)", choices=list(ROFORMER_MODELS["Instrumentals"].keys()), multiselect=True, interactive=True, allow_custom_value=True)
731
  with gr.Row():
732
- ensemble_seg_size = gr.Slider(32, 512, value=128, step=32, label="πŸ“ Segment Size", interactive=True)
733
  ensemble_overlap = gr.Slider(2, 10, value=8, step=1, label="πŸ”„ Overlap", interactive=True)
734
  ensemble_use_tta = gr.Dropdown(choices=["True", "False"], value="False", label="πŸ” Use TTA", interactive=True)
735
  ensemble_method = gr.Dropdown(label="βš™οΈ Ensemble Method", choices=['avg_wave', 'median_wave', 'max_wave', 'min_wave', 'avg_fft', 'median_fft', 'max_fft', 'min_fft'], value='avg_wave', interactive=True)
@@ -763,7 +732,7 @@ def create_interface():
763
  fn=auto_ensemble_process,
764
  inputs=[
765
  ensemble_audio, ensemble_models, ensemble_seg_size, ensemble_overlap,
766
- output_format, ensemble_use_tta, model_file_dir, output_dir,
767
  norm_threshold, amp_threshold, batch_size, ensemble_method,
768
  ensemble_exclude_stems, ensemble_weights
769
  ],
 
22
  import time
23
  from concurrent.futures import ThreadPoolExecutor, as_completed
24
  from threading import Lock
25
+ import scipy.io.wavfile
26
 
27
  # Logging setup
28
  logging.basicConfig(level=logging.INFO)
 
384
  scipy.io.wavfile.write(temp_audio_path, sample_rate, data)
385
  audio = temp_audio_path
386
  if seg_size > 512:
387
+ logger.warning(f"Segment size {seg_size} is large, this may cause issues.")
388
  override_seg_size = override_seg_size == "True"
389
  if os.path.exists(output_dir):
390
  shutil.rmtree(output_dir)
 
429
  logger.info("GPU memory cleared")
430
 
431
  @spaces.GPU
432
+ def auto_ensemble_process(audio, model_keys, seg_size=64, overlap=0.1, out_format="wav", use_tta="False", model_dir="/tmp/audio-separator-models/", output_dir="output", norm_thresh=0.9, amp_thresh=0.9, batch_size=1, ensemble_method="avg_wave", exclude_stems="", weights_str="", progress=gr.Progress(track_tqdm=True)):
433
  temp_audio_path = None
 
434
  max_retries = 2
435
  start_time = time.time()
436
+ time_budget = 100 # seconds
437
+ max_models = 6
438
+ gpu_lock = Lock()
439
 
440
  try:
441
  if not audio:
 
443
  if not model_keys:
444
  raise ValueError("No models selected.")
445
  if len(model_keys) > max_models:
446
+ logger.warning(f"Selected {len(model_keys)} models, limiting to {max_models}.")
447
  model_keys = model_keys[:max_models]
448
 
449
+ # Dynamic batch size based on audio duration and model count
450
+ audio_data, sr = librosa.load(audio, sr=None, mono=False)
451
+ duration = librosa.get_duration(y=audio_data, sr=sr)
452
+ logger.info(f"Audio duration: {duration:.2f} seconds")
453
+ dynamic_batch_size = max(1, min(4, 1 + int(900 / (duration + 1)) - len(model_keys) // 2))
454
+ logger.info(f"Using batch size: {dynamic_batch_size} for {len(model_keys)} models, duration {duration:.2f}s")
455
 
456
  if isinstance(audio, tuple):
457
  sample_rate, data = audio
 
459
  scipy.io.wavfile.write(temp_audio_path, sample_rate, data)
460
  audio = temp_audio_path
461
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
462
  use_tta = use_tta == "True"
463
  if os.path.exists(output_dir):
464
  shutil.rmtree(output_dir)
 
470
  model_cache = {}
471
  all_stems = []
472
  model_stems = {model_key: {"vocals": [], "other": []} for model_key in model_keys}
473
+ total_tasks = len(model_keys)
474
 
475
+ def process_model(model_key, model_idx):
476
  with torch.no_grad():
477
  for attempt in range(max_retries + 1):
478
  try:
 
488
  # Check time budget
489
  elapsed = time.time() - start_time
490
  if elapsed > time_budget:
491
+ logger.error(f"Time budget ({time_budget}s) exceeded")
492
+ raise TimeoutError("Processing took too long")
493
 
494
  # Initialize separator
495
  model_path = os.path.join(model_dir, model)
 
517
 
518
  # Process with GPU lock
519
  with gpu_lock:
520
+ progress(0.3 + (model_idx / total_tasks) * 0.5, desc=f"Separating with {model_key}")
521
+ logger.info(f"Separating with {model_key}")
522
+ separation = separator.separate(audio)
523
  stems = [os.path.join(output_dir, file_name) for file_name in separation]
524
  result = []
525
  for stem in stems:
 
530
  result.append(stem)
531
  return result
532
  except Exception as e:
533
+ logger.error(f"Error processing {model_key}, attempt {attempt + 1}/{max_retries + 1}: {e}")
534
  if attempt == max_retries:
535
+ logger.error(f"Max retries reached for {model_key}, skipping")
536
  return []
537
  time.sleep(1)
538
  finally:
539
  if torch.cuda.is_available():
540
  torch.cuda.empty_cache()
541
+ logger.info(f"Cleared CUDA cache after {model_key}")
542
 
543
  # Parallel processing
544
  progress(0.1, desc="Starting model separations...")
545
  with ThreadPoolExecutor(max_workers=min(4, len(model_keys))) as executor:
546
+ future_to_task = {executor.submit(process_model, model_key, idx): model_key for idx, model_key in enumerate(model_keys)}
 
 
 
 
 
547
  for future in as_completed(future_to_task):
548
+ model_key = future_to_task[future]
549
  try:
550
  stems = future.result()
551
  if stems:
552
+ logger.info(f"Completed {model_key}")
553
  else:
554
+ logger.warning(f"No stems produced for {model_key}")
555
  except Exception as e:
556
+ logger.error(f"Task {model_key} failed: {e}")
557
 
558
  # Clear model cache
559
  model_cache.clear()
 
569
  if stems_dict[stem_type]:
570
  combined_path = os.path.join(output_dir, f"{base_name}_{stem_type}_{model_key.replace(' | ', '_').replace(' ', '_')}.wav")
571
  try:
572
+ data, _ = librosa.load(stems_dict[stem_type][0], sr=sr, mono=False)
573
+ with sf.SoundFile(combined_path, 'w', sr, channels=2 if data.ndim == 2 else 1) as f:
574
+ f.write(data.T if data.ndim == 2 else data)
 
575
  logger.info(f"Combined {stem_type} for {model_key}: {combined_path}")
576
  if exclude_stems.strip() and stem_type.lower() in [s.strip().lower() for s in exclude_stems.split(',')]:
577
  logger.info(f"Excluding {stem_type} for {model_key}")
 
616
  raise RuntimeError(error_msg)
617
  except Exception as e:
618
  logger.error(f"Ensemble error: {e}")
619
+ error_msg = f"Processing failed. Try fewer models (max {max_models}), shorter audio, or uploading a local WAV file."
 
 
 
620
  raise RuntimeError(error_msg)
621
  finally:
622
+ if temp_audio_path and os.path.exists(temp_audio_path):
623
  try:
624
+ os.remove(temp_audio_path)
625
+ logger.info(f"Temporary file deleted: {temp_audio_path}")
 
626
  except Exception as e:
627
+ logger.warning(f"Failed to delete temporary file {temp_audio_path}: {e}")
628
  if torch.cuda.is_available():
629
  torch.cuda.empty_cache()
630
  logger.info("GPU memory cleared")
 
649
  with gr.Blocks(title="🎡 SESA Fast Separation 🎡", css=CSS, elem_id="app-container") as app:
650
  gr.Markdown("<h1 class='header-text'>🎡 SESA Fast Separation 🎡</h1>")
651
  gr.Markdown("**Note**: If YouTube downloads fail, upload a valid cookies file or a local WAV file. [Cookie Instructions](https://github.com/yt-dlp/yt-dlp/wiki/Extractors#exporting-youtube-cookies)")
652
+ gr.Markdown("**Tip**: For best results, use audio shorter than 15 minutes or fewer models (up to 6) to ensure smooth processing.")
 
653
  with gr.Tabs():
654
  with gr.Tab("βš™οΈ Settings"):
655
  with gr.Group(elem_classes="dubbing-theme"):
 
674
  roformer_category = gr.Dropdown(label="πŸ“š Category", choices=list(ROFORMER_MODELS.keys()), value="General Purpose", interactive=True)
675
  roformer_model = gr.Dropdown(label="πŸ› οΈ Model", choices=list(ROFORMER_MODELS["General Purpose"].keys()), interactive=True, allow_custom_value=True)
676
  with gr.Row():
677
+ roformer_seg_size = gr.Slider(32, 512, value=64, step=32, label="πŸ“ Segment Size", interactive=True)
678
  roformer_overlap = gr.Slider(2, 10, value=8, step=1, label="πŸ”„ Overlap", interactive=True)
679
  with gr.Row():
680
  roformer_pitch_shift = gr.Slider(-12, 12, value=0, step=1, label="🎡 Pitch Shift", interactive=True)
 
686
  with gr.Tab("🎚️ Auto Ensemble"):
687
  with gr.Group(elem_classes="dubbing-theme"):
688
  gr.Markdown("### Ensemble Processing")
689
+ gr.Markdown("Note: If weights are not specified, equal weights (1.0) are applied. Use up to 6 models for best results.")
690
  with gr.Row():
691
  ensemble_audio = gr.Audio(label="🎧 Upload Audio", type="filepath", interactive=True)
692
  url_ensemble = gr.Textbox(label="πŸ”— Or Paste URL", placeholder="YouTube or audio URL", interactive=True)
 
698
  ensemble_category = gr.Dropdown(label="πŸ“š Category", choices=list(ROFORMER_MODELS.keys()), value="Instrumentals", interactive=True)
699
  ensemble_models = gr.Dropdown(label="πŸ› οΈ Models (Max 6)", choices=list(ROFORMER_MODELS["Instrumentals"].keys()), multiselect=True, interactive=True, allow_custom_value=True)
700
  with gr.Row():
701
+ ensemble_seg_size = gr.Slider(32, 512, value=64, step=32, label="πŸ“ Segment Size", interactive=True)
702
  ensemble_overlap = gr.Slider(2, 10, value=8, step=1, label="πŸ”„ Overlap", interactive=True)
703
  ensemble_use_tta = gr.Dropdown(choices=["True", "False"], value="False", label="πŸ” Use TTA", interactive=True)
704
  ensemble_method = gr.Dropdown(label="βš™οΈ Ensemble Method", choices=['avg_wave', 'median_wave', 'max_wave', 'min_wave', 'avg_fft', 'median_fft', 'max_fft', 'min_fft'], value='avg_wave', interactive=True)
 
732
  fn=auto_ensemble_process,
733
  inputs=[
734
  ensemble_audio, ensemble_models, ensemble_seg_size, ensemble_overlap,
735
+ output_format, ensemble_use_tta, model_dir, output_dir,
736
  norm_threshold, amp_threshold, batch_size, ensemble_method,
737
  ensemble_exclude_stems, ensemble_weights
738
  ],