ASesYusuf1 commited on
Commit
1d35b52
·
verified ·
1 Parent(s): 4bcaa31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -34
app.py CHANGED
@@ -24,14 +24,6 @@ from threading import Lock
24
  import scipy.io.wavfile
25
  import spaces
26
 
27
- # Global state definition
28
- ensemble_state = {
29
- "current_model_idx": 0,
30
- "current_audio": None,
31
- "processed_stems": [],
32
- "model_outputs": {} # Her modelin stem'lerini saklamak için
33
- }
34
-
35
  # Logging setup
36
  logging.basicConfig(level=logging.INFO)
37
  logger = logging.getLogger(__name__)
@@ -450,7 +442,7 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
450
  logger.info("GPU memory cleared")
451
 
452
  @spaces.GPU
453
- def auto_ensemble_process(audio, model_keys, seg_size=64, overlap=0.1, out_format="wav", use_tta="False", model_dir="/tmp/audio-separator-models/", output_dir="output", norm_thresh=0.9, amp_thresh=0.9, batch_size=1, ensemble_method="avg_wave", exclude_stems="", weights_str="", progress=gr.Progress(track_tqdm=True)):
454
  temp_audio_path = None
455
  start_time = time.time()
456
  try:
@@ -475,12 +467,21 @@ def auto_ensemble_process(audio, model_keys, seg_size=64, overlap=0.1, out_forma
475
  scipy.io.wavfile.write(temp_audio_path, sample_rate, data)
476
  audio = temp_audio_path
477
 
478
- # State yönetimini kontrol et ve sıfırlama
479
- if ensemble_state["current_audio"] != audio or ensemble_state["current_model_idx"] >= len(model_keys):
480
- ensemble_state["current_audio"] = audio
481
- ensemble_state["current_model_idx"] = 0
482
- ensemble_state["processed_stems"] = []
483
- ensemble_state["model_outputs"] = {model_key: {"vocals": [], "other": []} for model_key in model_keys}
 
 
 
 
 
 
 
 
 
484
  logger.info("New audio or completed cycle detected, resetting ensemble state.")
485
 
486
  use_tta = use_tta == "True"
@@ -497,7 +498,7 @@ def auto_ensemble_process(audio, model_keys, seg_size=64, overlap=0.1, out_forma
497
  total_tasks = len(model_keys)
498
 
499
  # Şu anki modeli işle
500
- current_idx = ensemble_state["current_model_idx"]
501
  if current_idx >= len(model_keys):
502
  # Tüm modeller işlendiyse ensemble işlemini yap
503
  logger.info("All models processed, running ensemble...")
@@ -507,7 +508,7 @@ def auto_ensemble_process(audio, model_keys, seg_size=64, overlap=0.1, out_forma
507
  excluded_stems_list = [s.strip().lower() for s in exclude_stems.split(',')] if exclude_stems.strip() else []
508
 
509
  # Tüm stem’leri topla, ama "Exclude Stems" ile belirtilenleri hariç tut
510
- for model_key, stems_dict in ensemble_state["model_outputs"].items():
511
  for stem_type in ["vocals", "other"]:
512
  if stems_dict[stem_type]:
513
  if stem_type.lower() in excluded_stems_list:
@@ -537,10 +538,10 @@ def auto_ensemble_process(audio, model_keys, seg_size=64, overlap=0.1, out_forma
537
  raise RuntimeError(f"Ensemble failed, output file not created: {output_file}")
538
 
539
  # Durumu sıfırla
540
- ensemble_state["current_model_idx"] = 0
541
- ensemble_state["current_audio"] = None
542
- ensemble_state["processed_stems"] = []
543
- ensemble_state["model_outputs"] = {}
544
 
545
  elapsed = time.time() - start_time
546
  logger.info(f"Ensemble completed, output: {output_file}, took {elapsed:.2f}s")
@@ -551,7 +552,7 @@ def auto_ensemble_process(audio, model_keys, seg_size=64, overlap=0.1, out_forma
551
  file_name = os.path.basename(file)
552
  status += f"<li><a href='file={file}' download>{file_name}</a></li>"
553
  status += "</ul>"
554
- return output_file, status, file_list
555
 
556
  # Şu anki modeli işle
557
  model_key = model_keys[current_idx]
@@ -568,8 +569,8 @@ def auto_ensemble_process(audio, model_keys, seg_size=64, overlap=0.1, out_forma
568
  break
569
  else:
570
  logger.warning(f"Model {model_key} not found, skipping")
571
- ensemble_state["current_model_idx"] += 1
572
- return None, f"Model {model_key} not found, proceeding to next model.", []
573
 
574
  # Zaman kontrolü
575
  elapsed = time.time() - start_time
@@ -613,19 +614,19 @@ def auto_ensemble_process(audio, model_keys, seg_size=64, overlap=0.1, out_forma
613
  stem_type = "vocals" if "vocals" in os.path.basename(stem).lower() else "other"
614
  permanent_stem_path = os.path.join(permanent_output_dir, f"{base_name}_{stem_type}_{model_key.replace(' | ', '_').replace(' ', '_')}.{out_format}")
615
  shutil.copy(stem, permanent_stem_path)
616
- ensemble_state["model_outputs"][model_key][stem_type].append(permanent_stem_path)
617
  if stem_type not in exclude_stems.lower():
618
  result.append(permanent_stem_path)
619
 
620
- ensemble_state["processed_stems"].extend(result)
621
  break
622
 
623
  except Exception as e:
624
  logger.error(f"Error processing {model_key}, attempt {attempt + 1}/{max_retries + 1}: {e}")
625
  if attempt == max_retries:
626
  logger.error(f"Max retries reached for {model_key}, skipping")
627
- ensemble_state["current_model_idx"] += 1
628
- return None, f"Failed to process {model_key} after {max_retries} attempts.", []
629
  time.sleep(1)
630
 
631
  finally:
@@ -641,18 +642,18 @@ def auto_ensemble_process(audio, model_keys, seg_size=64, overlap=0.1, out_forma
641
  logger.info("Cleared model cache and GPU memory")
642
 
643
  # Bir sonraki modele geç
644
- ensemble_state["current_model_idx"] += 1
645
  elapsed = time.time() - start_time
646
  logger.info(f"Model {model_key} completed in {elapsed:.2f}s")
647
 
648
  # Çıktılar
649
- file_list = ensemble_state["processed_stems"]
650
  status = f"Model {model_key} (Model {current_idx + 1}/{len(model_keys)}) completed in {elapsed:.2f}s<br>Click 'Run Ensemble!' to process the next model.<br>Processed stems:<ul>"
651
  for file in file_list:
652
  file_name = os.path.basename(file)
653
  status += f"<li><a href='file={file}' download>{file_name}</a></li>"
654
  status += "</ul>"
655
- return file_list[0] if file_list else None, status, file_list
656
 
657
  except Exception as e:
658
  logger.error(f"Ensemble error: {e}")
@@ -668,7 +669,7 @@ def auto_ensemble_process(audio, model_keys, seg_size=64, overlap=0.1, out_forma
668
  logger.warning(f"Failed to delete temporary file {temp_audio_path}: {e}")
669
  if torch.cuda.is_available():
670
  torch.cuda.empty_cache()
671
- logger.info("GPU memory cleared")
672
 
673
  def update_roformer_models(category):
674
  """Update Roformer model dropdown based on selected category."""
@@ -691,6 +692,13 @@ def create_interface():
691
  gr.Markdown("<h1 class='header-text'>🎵 SESA Fast Separation 🎵</h1>")
692
  gr.Markdown("**Note**: If YouTube downloads fail, upload a valid cookies file or a local WAV file. [Cookie Instructions](https://github.com/yt-dlp/yt-dlp/wiki/Extractors#exporting-youtube-cookies)")
693
  gr.Markdown("**Tip**: For best results, use audio shorter than 15 minutes or fewer models (up to 6) to ensure smooth processing.")
 
 
 
 
 
 
 
694
  with gr.Tabs():
695
  with gr.Tab("⚙️ Settings"):
696
  with gr.Group(elem_classes="dubbing-theme"):
@@ -774,12 +782,12 @@ def create_interface():
774
  ensemble_button.click(
775
  fn=auto_ensemble_process,
776
  inputs=[
777
- ensemble_audio, ensemble_models, ensemble_seg_size, ensemble_overlap,
778
  output_format, ensemble_use_tta, model_file_dir, output_dir,
779
  norm_threshold, amp_threshold, batch_size, ensemble_method,
780
  ensemble_exclude_stems, ensemble_weights
781
  ],
782
- outputs=[ensemble_output, ensemble_status, ensemble_files]
783
  )
784
  return app
785
 
 
24
  import scipy.io.wavfile
25
  import spaces
26
 
 
 
 
 
 
 
 
 
27
  # Logging setup
28
  logging.basicConfig(level=logging.INFO)
29
  logger = logging.getLogger(__name__)
 
442
  logger.info("GPU memory cleared")
443
 
444
  @spaces.GPU
445
+ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, out_format="wav", use_tta="False", model_dir="/tmp/audio-separator-models/", output_dir="output", norm_thresh=0.9, amp_thresh=0.9, batch_size=1, ensemble_method="avg_wave", exclude_stems="", weights_str="", progress=gr.Progress(track_tqdm=True)):
446
  temp_audio_path = None
447
  start_time = time.time()
448
  try:
 
467
  scipy.io.wavfile.write(temp_audio_path, sample_rate, data)
468
  audio = temp_audio_path
469
 
470
+ # State kontrolü
471
+ if not state:
472
+ state = {
473
+ "current_audio": None,
474
+ "current_model_idx": 0,
475
+ "processed_stems": [],
476
+ "model_outputs": {}
477
+ }
478
+
479
+ # Yeni audio dosyası kontrolü
480
+ if state["current_audio"] != audio or state["current_model_idx"] >= len(model_keys):
481
+ state["current_audio"] = audio
482
+ state["current_model_idx"] = 0
483
+ state["processed_stems"] = []
484
+ state["model_outputs"] = {model_key: {"vocals": [], "other": []} for model_key in model_keys}
485
  logger.info("New audio or completed cycle detected, resetting ensemble state.")
486
 
487
  use_tta = use_tta == "True"
 
498
  total_tasks = len(model_keys)
499
 
500
  # Şu anki modeli işle
501
+ current_idx = state["current_model_idx"]
502
  if current_idx >= len(model_keys):
503
  # Tüm modeller işlendiyse ensemble işlemini yap
504
  logger.info("All models processed, running ensemble...")
 
508
  excluded_stems_list = [s.strip().lower() for s in exclude_stems.split(',')] if exclude_stems.strip() else []
509
 
510
  # Tüm stem’leri topla, ama "Exclude Stems" ile belirtilenleri hariç tut
511
+ for model_key, stems_dict in state["model_outputs"].items():
512
  for stem_type in ["vocals", "other"]:
513
  if stems_dict[stem_type]:
514
  if stem_type.lower() in excluded_stems_list:
 
538
  raise RuntimeError(f"Ensemble failed, output file not created: {output_file}")
539
 
540
  # Durumu sıfırla
541
+ state["current_model_idx"] = 0
542
+ state["current_audio"] = None
543
+ state["processed_stems"] = []
544
+ state["model_outputs"] = {}
545
 
546
  elapsed = time.time() - start_time
547
  logger.info(f"Ensemble completed, output: {output_file}, took {elapsed:.2f}s")
 
552
  file_name = os.path.basename(file)
553
  status += f"<li><a href='file={file}' download>{file_name}</a></li>"
554
  status += "</ul>"
555
+ return output_file, status, file_list, state
556
 
557
  # Şu anki modeli işle
558
  model_key = model_keys[current_idx]
 
569
  break
570
  else:
571
  logger.warning(f"Model {model_key} not found, skipping")
572
+ state["current_model_idx"] += 1
573
+ return None, f"Model {model_key} not found, proceeding to next model.", [], state
574
 
575
  # Zaman kontrolü
576
  elapsed = time.time() - start_time
 
614
  stem_type = "vocals" if "vocals" in os.path.basename(stem).lower() else "other"
615
  permanent_stem_path = os.path.join(permanent_output_dir, f"{base_name}_{stem_type}_{model_key.replace(' | ', '_').replace(' ', '_')}.{out_format}")
616
  shutil.copy(stem, permanent_stem_path)
617
+ state["model_outputs"][model_key][stem_type].append(permanent_stem_path)
618
  if stem_type not in exclude_stems.lower():
619
  result.append(permanent_stem_path)
620
 
621
+ state["processed_stems"].extend(result)
622
  break
623
 
624
  except Exception as e:
625
  logger.error(f"Error processing {model_key}, attempt {attempt + 1}/{max_retries + 1}: {e}")
626
  if attempt == max_retries:
627
  logger.error(f"Max retries reached for {model_key}, skipping")
628
+ state["current_model_idx"] += 1
629
+ return None, f"Failed to process {model_key} after {max_retries} attempts.", [], state
630
  time.sleep(1)
631
 
632
  finally:
 
642
  logger.info("Cleared model cache and GPU memory")
643
 
644
  # Bir sonraki modele geç
645
+ state["current_model_idx"] += 1
646
  elapsed = time.time() - start_time
647
  logger.info(f"Model {model_key} completed in {elapsed:.2f}s")
648
 
649
  # Çıktılar
650
+ file_list = state["processed_stems"]
651
  status = f"Model {model_key} (Model {current_idx + 1}/{len(model_keys)}) completed in {elapsed:.2f}s<br>Click 'Run Ensemble!' to process the next model.<br>Processed stems:<ul>"
652
  for file in file_list:
653
  file_name = os.path.basename(file)
654
  status += f"<li><a href='file={file}' download>{file_name}</a></li>"
655
  status += "</ul>"
656
+ return file_list[0] if file_list else None, status, file_list, state
657
 
658
  except Exception as e:
659
  logger.error(f"Ensemble error: {e}")
 
669
  logger.warning(f"Failed to delete temporary file {temp_audio_path}: {e}")
670
  if torch.cuda.is_available():
671
  torch.cuda.empty_cache()
672
+ logger.info("GPU memory cleared")
673
 
674
  def update_roformer_models(category):
675
  """Update Roformer model dropdown based on selected category."""
 
692
  gr.Markdown("<h1 class='header-text'>🎵 SESA Fast Separation 🎵</h1>")
693
  gr.Markdown("**Note**: If YouTube downloads fail, upload a valid cookies file or a local WAV file. [Cookie Instructions](https://github.com/yt-dlp/yt-dlp/wiki/Extractors#exporting-youtube-cookies)")
694
  gr.Markdown("**Tip**: For best results, use audio shorter than 15 minutes or fewer models (up to 6) to ensure smooth processing.")
695
+ # Gradio State bileşeni
696
+ ensemble_state = gr.State(value={
697
+ "current_audio": None,
698
+ "current_model_idx": 0,
699
+ "processed_stems": [],
700
+ "model_outputs": {}
701
+ })
702
  with gr.Tabs():
703
  with gr.Tab("⚙️ Settings"):
704
  with gr.Group(elem_classes="dubbing-theme"):
 
782
  ensemble_button.click(
783
  fn=auto_ensemble_process,
784
  inputs=[
785
+ ensemble_audio, ensemble_models, ensemble_state, ensemble_seg_size, ensemble_overlap,
786
  output_format, ensemble_use_tta, model_file_dir, output_dir,
787
  norm_threshold, amp_threshold, batch_size, ensemble_method,
788
  ensemble_exclude_stems, ensemble_weights
789
  ],
790
+ outputs=[ensemble_output, ensemble_status, ensemble_files, ensemble_state]
791
  )
792
  return app
793