ASesYusuf1 committed
Commit 92e9644 · verified · 1 Parent(s): 92125aa

Update app.py

Files changed (1)
  1. app.py +30 -16
app.py CHANGED
@@ -1,10 +1,9 @@
-from typing import Optional, Any  # import needed for Optional and Any
-
-# Other existing imports (e.g., from your earlier code)
+from typing import Optional, Any
 import os
 import sys
 import torch
 import logging
+import yt_dlp
 from yt_dlp import YoutubeDL
 import gradio as gr
 import argparse
@@ -20,8 +19,8 @@ import spaces
 import gdown
 import scipy.io.wavfile
 from pydub import AudioSegment
-import yt_dlp
 import gc
+import time

 # Logging setup (existing)
 logging.basicConfig(level=logging.INFO)
@@ -433,12 +432,15 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
 def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_format="wav", use_tta="False", model_dir="/tmp/audio-separator-models/", output_dir="output", norm_thresh=0.9, amp_thresh=0.9, batch_size=1, ensemble_method="avg_wave", exclude_stems="", weights_str="", progress=gr.Progress(track_tqdm=True)):
     temp_audio_path = None
     chunk_paths = []
-    max_retries = 2  # Retry attempts for ZeroGPU session issues
+    max_retries = 2
     try:
         if not audio:
             raise ValueError("No audio file provided.")
         if not model_keys:
             raise ValueError("No models selected.")
+        if len(model_keys) > 2:
+            logger.warning("Limited to 2 models to avoid ZeroGPU timeouts. Using first two: %s", model_keys[:2])
+            model_keys = model_keys[:2]
         if isinstance(audio, tuple):
             sample_rate, data = audio
             temp_audio_path = os.path.join("/tmp", "temp_audio.wav")
@@ -465,8 +467,7 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
             chunks = [audio]
         use_tta = use_tta == "True"
         if os.path.exists(output_dir):
-            shutil.rmtree(outputwatermark = True)
-            shutil.copyfile(audio, os.path.join(output_dir, os.path.basename(audio)))
+            shutil.rmtree(output_dir)
         os.makedirs(output_dir, exist_ok=True)
         base_name = os.path.splitext(os.path.basename(audio))[0]
         logger.info(f"Ensemble for {base_name} with {model_keys} on {device}")
@@ -487,6 +488,10 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
                 while retry_count <= max_retries:
                     try:
                         progress((model_idx + 0.1) / total_models, desc=f"Loading {model_key} for chunk {chunk_idx}")
+                        # Check if model is cached
+                        model_path = os.path.join(model_dir, model)
+                        if not os.path.exists(model_path):
+                            logger.info(f"Model {model} not cached, will download")
                         separator = Separator(
                             log_level=logging.INFO,
                             model_file_dir=model_dir,
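
Reviewer note: the cache check added above references `model`, which must already be bound to the checkpoint filename at this point in the loop (elsewhere in app.py it appears to come from the `ROFORMER_MODELS` mapping). A minimal sketch of the intended pattern; the helper name and lookup shape are illustrative assumptions, not code from this commit:

import os
import logging

logger = logging.getLogger(__name__)

def ensure_model_cached(model_dir: str, model_filename: str) -> str:
    """Report whether a checkpoint is already in the local cache.

    audio-separator downloads missing models on demand; this helper only
    logs when a download is about to happen. `model_filename` is assumed
    to come from a lookup like ROFORMER_MODELS[category][model_key].
    """
    model_path = os.path.join(model_dir, model_filename)
    if not os.path.exists(model_path):
        logger.info("Model %s not cached, will download", model_filename)
    return model_path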
@@ -508,20 +513,24 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
                                 model_stems[model_key]["vocals"].append(stem)
                             elif "other" in os.path.basename(stem).lower() or "instrumental" in os.path.basename(stem).lower():
                                 model_stems[model_key]["other"].append(stem)
-                        break  # Success, exit retry loop
+                        break
                     except Exception as e:
                         retry_count += 1
                         logger.error(f"Error processing {model_key} chunk {chunk_idx}, attempt {retry_count}/{max_retries}: {e}")
+                        if "ZeroGPU" in str(e) or "aborted" in str(e).lower():
+                            logger.error("ZeroGPU task aborted, attempting recovery")
                         if retry_count > max_retries:
                             logger.error(f"Max retries reached for {model_key} chunk {chunk_idx}, skipping")
                             break
-                        time.sleep(2)  # Wait before retrying
+                        time.sleep(1)  # Reduced delay to minimize overhead
                     finally:
                         separator = None
                         gc.collect()
                         if torch.cuda.is_available():
                             torch.cuda.empty_cache()
                             logger.info(f"Cleared CUDA cache after {model_key} chunk {chunk_idx}")
+                        # Yield control to ZeroGPU scheduler
+                        time.sleep(0.1)
         progress(0.8, desc="Combining stems...")
         for model_key, stems_dict in model_stems.items():
             for stem_type in ["vocals", "other"]:
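
The retry loop above couples each attempt with cleanup in `finally`, so a failed attempt cannot leave a loaded separator or stale CUDA allocations behind before the next try. A condensed, self-contained sketch of that pattern; the `run_separation` callable and the error substrings are placeholders, not the commit's API:

import gc
import time
import logging

import torch

logger = logging.getLogger(__name__)

def run_with_retries(run_separation, max_retries: int = 2):
    """Retry a GPU-bound callable, freeing CUDA memory between attempts."""
    retry_count = 0
    while retry_count <= max_retries:
        try:
            return run_separation()  # success exits the loop
        except Exception as e:
            retry_count += 1
            logger.error("Attempt %d/%d failed: %s", retry_count, max_retries, e)
            if "ZeroGPU" in str(e) or "aborted" in str(e).lower():
                logger.error("ZeroGPU task aborted, attempting recovery")
            if retry_count > max_retries:
                raise
            time.sleep(1)  # brief back-off before the next attempt
        finally:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # release cached GPU blocks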
@@ -541,7 +550,7 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
                 logger.error(f"Error combining {stem_type} for {model_key}: {e}")
         all_stems = [stem for stem in all_stems if os.path.exists(stem)]
         if not all_stems:
-            raise ValueError("No valid stems found for ensemble.")
+            raise ValueError("No valid stems found for ensemble. Try uploading a local WAV file.")
         weights = [float(w.strip()) for w in weights_str.split(',')] if weights_str.strip() else [1.0] * len(all_stems)
         if len(weights) != len(all_stems):
             weights = [1.0] * len(all_stems)
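
On the weight handling above: a missing or mismatched `weights_str` falls back to equal weights instead of aborting the run. A short sketch of how such weights can drive an `avg_wave`-style combination, assuming equal-length NumPy waveforms; this illustrates the idea, not the code behind `ensemble_method`:

import numpy as np

def parse_weights(weights_str: str, n_stems: int) -> list:
    """Parse 'w1, w2, ...'; fall back to equal weights on any mismatch."""
    try:
        weights = [float(w.strip()) for w in weights_str.split(",")] if weights_str.strip() else []
    except ValueError:
        weights = []
    return weights if len(weights) == n_stems else [1.0] * n_stems

def weighted_avg_wave(stems, weights):
    """Weighted mean of aligned waveforms, the idea behind 'avg_wave'."""
    stacked = np.stack(stems)                        # (n_stems, n_samples)
    w = np.asarray(weights, dtype=np.float64)[:, None]
    return (stacked * w).sum(axis=0) / w.sum()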
@@ -565,7 +574,11 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
             raise RuntimeError(f"Ensemble processing error: {e}")
     except Exception as e:
         logger.error(f"Ensemble error: {e}")
-        raise RuntimeError(f"Ensemble error: {e}")
+        if "ZeroGPU" in str(e) or "aborted" in str(e).lower():
+            error_msg = "ZeroGPU task aborted. Try using fewer models (max 2), lowering segment size, or uploading a local WAV file."
+        else:
+            error_msg = f"Ensemble error: {e}"
+        raise RuntimeError(error_msg)
     finally:
         for path in chunk_paths + ([temp_audio_path] if temp_audio_path and os.path.exists(temp_audio_path) else []):
             try:
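
One possible refinement, not what this commit does: Gradio can surface the classified message directly in the UI via `gr.Error` instead of wrapping it in a `RuntimeError`. A hedged sketch:

import gradio as gr

def fail_with_ui_message(e: Exception):
    """Translate ZeroGPU aborts into an actionable message shown in the UI."""
    if "ZeroGPU" in str(e) or "aborted" in str(e).lower():
        raise gr.Error("ZeroGPU task aborted. Try fewer models (max 2), a lower segment size, or a local WAV file.")
    raise gr.Error(f"Ensemble error: {e}")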
@@ -599,6 +612,7 @@ def create_interface():
         gr.Markdown("<h1 class='header-text'>🎵 SESA Fast Separation 🎵</h1>")
         gr.Markdown("**Note**: If YouTube downloads fail, upload an audio file directly or use a valid cookies file. [Cookie Instructions](https://github.com/yt-dlp/yt-dlp/wiki/Extractors#exporting-youtube-cookies)")
         gr.Markdown("**Warning**: Audio files longer than 15 minutes are automatically split into chunks, which may require more time and resources.")
+        gr.Markdown("**ZeroGPU Notice**: Use up to 2 models for ensemble to avoid timeouts. For large tasks, upload a local WAV file.")
         with gr.Tabs():
             with gr.Tab("⚙️ Settings"):
                 with gr.Group(elem_classes="dubbing-theme"):
@@ -608,7 +622,7 @@ def create_interface():
                     output_format = gr.Dropdown(value="wav", choices=OUTPUT_FORMATS, label="🎶 Output Format", interactive=True)
                     norm_threshold = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="🔊 Normalization Threshold", interactive=True)
                     amp_threshold = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="📈 Amplification Threshold", interactive=True)
-                    batch_size = gr.Slider(1, 16, value=1, step=1, label="⚡ Batch Size", interactive=True)
+                    batch_size = gr.Slider(1, 8, value=1, step=1, label="⚡ Batch Size", interactive=True)
             with gr.Tab("🎤 Roformer"):
                 with gr.Group(elem_classes="dubbing-theme"):
                     gr.Markdown("### Audio Separation")
@@ -623,7 +637,7 @@ def create_interface():
                         roformer_category = gr.Dropdown(label="📚 Category", choices=list(ROFORMER_MODELS.keys()), value="General Purpose", interactive=True)
                         roformer_model = gr.Dropdown(label="🛠️ Model", choices=list(ROFORMER_MODELS["General Purpose"].keys()), interactive=True, allow_custom_value=True)
                     with gr.Row():
-                        roformer_seg_size = gr.Slider(32, 4000, value=256, step=32, label="📏 Segment Size", interactive=True)
+                        roformer_seg_size = gr.Slider(32, 512, value=128, step=32, label="📏 Segment Size", interactive=True)
                         roformer_overlap = gr.Slider(2, 10, value=8, step=1, label="🔄 Overlap", interactive=True)
                     with gr.Row():
                         roformer_pitch_shift = gr.Slider(-12, 12, value=0, step=1, label="🎵 Pitch Shift", interactive=True)
@@ -635,7 +649,7 @@ def create_interface():
             with gr.Tab("🎚️ Auto Ensemble"):
                 with gr.Group(elem_classes="dubbing-theme"):
                     gr.Markdown("### Ensemble Processing")
-                    gr.Markdown("Note: If weights are not specified, equal weights (1.0) are applied to all models.")
+                    gr.Markdown("Note: If weights are not specified, equal weights (1.0) are applied to all models. Max 2 models recommended.")
                     with gr.Row():
                         ensemble_audio = gr.Audio(label="🎧 Upload Audio", type="filepath", interactive=True)
                         url_ensemble = gr.Textbox(label="🔗 Or Paste URL", placeholder="YouTube or audio URL", interactive=True)
@@ -645,9 +659,9 @@ def create_interface():
                         ensemble_exclude_stems = gr.Textbox(label="🚫 Exclude Stems", placeholder="e.g., vocals, drums (comma-separated)", interactive=True)
                     with gr.Row():
                         ensemble_category = gr.Dropdown(label="📚 Category", choices=list(ROFORMER_MODELS.keys()), value="Instrumentals", interactive=True)
-                        ensemble_models = gr.Dropdown(label="🛠️ Models", choices=list(ROFORMER_MODELS["Instrumentals"].keys()), multiselect=True, interactive=True, allow_custom_value=True)
+                        ensemble_models = gr.Dropdown(label="🛠️ Models (Max 2)", choices=list(ROFORMER_MODELS["Instrumentals"].keys()), multiselect=True, interactive=True, allow_custom_value=True)
                     with gr.Row():
-                        ensemble_seg_size = gr.Slider(32, 4000, value=256, step=32, label="📏 Segment Size", interactive=True)
+                        ensemble_seg_size = gr.Slider(32, 512, value=128, step=32, label="📏 Segment Size", interactive=True)
                         ensemble_overlap = gr.Slider(2, 10, value=8, step=1, label="🔄 Overlap", interactive=True)
                         ensemble_use_tta = gr.Dropdown(choices=["True", "False"], value="False", label="🔁 Use TTA", interactive=True)
                         ensemble_method = gr.Dropdown(label="⚙️ Ensemble Method", choices=['avg_wave', 'median_wave', 'max_wave', 'min_wave', 'avg_fft', 'median_fft', 'max_fft', 'min_fft'], value='avg_wave', interactive=True)
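
The paired category/model dropdowns in these tabs imply a dependency: changing the category should repopulate the model list. A minimal sketch of that wiring, assuming the `ROFORMER_MODELS` mapping used throughout app.py; the sample entries and handler name are illustrative, not part of the commit:

import gradio as gr

# Illustrative shape of the mapping; the real entries live in app.py.
ROFORMER_MODELS = {
    "General Purpose": {"Model A": "model_a.ckpt"},
    "Instrumentals": {"Model B": "model_b.ckpt"},
}

def update_model_choices(category):
    """Repopulate the model dropdown when the category changes."""
    return gr.update(choices=list(ROFORMER_MODELS[category].keys()), value=[])

with gr.Blocks() as demo:
    category = gr.Dropdown(label="📚 Category", choices=list(ROFORMER_MODELS.keys()), value="Instrumentals")
    models = gr.Dropdown(label="🛠️ Models (Max 2)", choices=list(ROFORMER_MODELS["Instrumentals"].keys()), multiselect=True)
    category.change(update_model_choices, inputs=category, outputs=models)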
 