ASesYusuf1 commited on
Commit
8fc1631
Β·
verified Β·
1 Parent(s): bcc6e08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -121
app.py CHANGED
@@ -25,6 +25,13 @@ import scipy.io.wavfile
25
  import subprocess
26
  import spaces
27
  import torchaudio
 
 
 
 
 
 
 
28
 
29
  # Logging setup
30
  logging.basicConfig(level=logging.INFO)
@@ -63,97 +70,8 @@ max_retries = 2
63
  time_budget = 300 # ZeroGPU iΓ§in işlem sΔ±nΔ±rΔ±
64
  gpu_lock = Lock()
65
 
66
- # ROFORMER_MODELS and OUTPUT_FORMATS
67
- ROFORMER_MODELS = {
68
- "Vocals": {
69
- 'MelBand Roformer | Big Beta 6X by unwa': 'melband_roformer_big_beta6x.ckpt',
70
- 'MelBand Roformer Kim | Big Beta 4 FT by unwa': 'melband_roformer_big_beta4.ckpt',
71
- 'MelBand Roformer Kim | Big Beta 5e FT by unwa': 'melband_roformer_big_beta5e.ckpt',
72
- 'MelBand Roformer | Big Beta 6 by unwa': 'melband_roformer_big_beta6.ckpt',
73
- 'MelBand Roformer | Vocals by Kimberley Jensen': 'vocals_mel_band_roformer.ckpt',
74
- 'MelBand Roformer Kim | FT 3 by unwa': 'mel_band_roformer_kim_ft3_unwa.ckpt',
75
- 'MelBand Roformer Kim | FT by unwa': 'mel_band_roformer_kim_ft_unwa.ckpt',
76
- 'MelBand Roformer Kim | FT 2 by unwa': 'mel_band_roformer_kim_ft2_unwa.ckpt',
77
- 'MelBand Roformer Kim | FT 2 Bleedless by unwa': 'mel_band_roformer_kim_ft2_bleedless_unwa.ckpt',
78
- 'MelBand Roformer | Vocals by becruily': 'mel_band_roformer_vocals_becruily.ckpt',
79
- 'MelBand Roformer | Vocals Fullness by Aname': 'mel_band_roformer_vocal_fullness_aname.ckpt',
80
- 'BS Roformer | Vocals by Gabox': 'bs_roformer_vocals_gabox.ckpt',
81
- 'MelBand Roformer | Vocals by Gabox': 'mel_band_roformer_vocals_gabox.ckpt',
82
- 'MelBand Roformer | Vocals FV1 by Gabox': 'mel_band_roformer_vocals_fv1_gabox.ckpt',
83
- 'MelBand Roformer | Vocals FV2 by Gabox': 'mel_band_roformer_vocals_fv2_gabox.ckpt',
84
- 'MelBand Roformer | Vocals FV3 by Gabox': 'mel_band_roformer_vocals_fv3_gabox.ckpt',
85
- 'MelBand Roformer | Vocals FV4 by Gabox': 'mel_band_roformer_vocals_fv4_gabox.ckpt',
86
- 'BS Roformer | Chorus Male-Female by Sucial': 'model_chorus_bs_roformer_ep_267_sdr_24.1275.ckpt',
87
- 'BS Roformer | Male-Female by aufr33': 'bs_roformer_male_female_by_aufr33_sdr_7.2889.ckpt',
88
- },
89
- "Instrumentals": {
90
- 'MelBand Roformer | FVX by Gabox': 'mel_band_roformer_instrumental_fvx_gabox.ckpt',
91
- 'MelBand Roformer | INSTV8N by Gabox': 'mel_band_roformer_instrumental_instv8n_gabox.ckpt',
92
- 'MelBand Roformer | INSTV8 by Gabox': 'mel_band_roformer_instrumental_instv8_gabox.ckpt',
93
- 'MelBand Roformer | INSTV7N by Gabox': 'mel_band_roformer_instrumental_instv7n_gabox.ckpt',
94
- 'MelBand Roformer | Instrumental Bleedless V3 by Gabox': 'mel_band_roformer_instrumental_bleedless_v3_gabox.ckpt',
95
- 'MelBand Roformer Kim | Inst V1 (E) Plus by Unwa': 'melband_roformer_inst_v1e_plus.ckpt',
96
- 'MelBand Roformer Kim | Inst V1 Plus by Unwa': 'melband_roformer_inst_v1_plus.ckpt',
97
- 'MelBand Roformer Kim | Inst V1 by Unwa': 'melband_roformer_inst_v1.ckpt',
98
- 'MelBand Roformer Kim | Inst V1 (E) by Unwa': 'melband_roformer_inst_v1e.ckpt',
99
- 'MelBand Roformer Kim | Inst V2 by Unwa': 'melband_roformer_inst_v2.ckpt',
100
- 'MelBand Roformer | Instrumental by becruily': 'mel_band_roformer_instrumental_becruily.ckpt',
101
- 'MelBand Roformer | Instrumental by Gabox': 'mel_band_roformer_instrumental_gabox.ckpt',
102
- 'MelBand Roformer | Instrumental 2 by Gabox': 'mel_band_roformer_instrumental_2_gabox.ckpt',
103
- 'MelBand Roformer | Instrumental 3 by Gabox': 'mel_band_roformer_instrumental_3_gabox.ckpt',
104
- 'MelBand Roformer | Instrumental Bleedless V1 by Gabox': 'mel_band_roformer_instrumental_bleedless_v1_gabox.ckpt',
105
- 'MelBand Roformer | Instrumental Bleedless V2 by Gabox': 'mel_band_roformer_instrumental_bleedless_v2_gabox.ckpt',
106
- 'MelBand Roformer | Instrumental Fullness V1 by Gabox': 'mel_band_roformer_instrumental_fullness_v1_gabox.ckpt',
107
- 'MelBand Roformer | Instrumental Fullness V2 by Gabox': 'mel_band_roformer_instrumental_fullness_v2_gabox.ckpt',
108
- 'MelBand Roformer | Instrumental Fullness V3 by Gabox': 'mel_band_roformer_instrumental_fullness_v3_gabox.ckpt',
109
- 'MelBand Roformer | Instrumental Fullness Noisy V4 by Gabox': 'mel_band_roformer_instrumental_fullness_noise_v4_gabox.ckpt',
110
- 'MelBand Roformer | INSTV5 by Gabox': 'mel_band_roformer_instrumental_instv5_gabox.ckpt',
111
- 'MelBand Roformer | INSTV5N by Gabox': 'mel_band_roformer_instrumental_instv5n_gabox.ckpt',
112
- 'MelBand Roformer | INSTV6 by Gabox': 'mel_band_roformer_instrumental_instv6_gabox.ckpt',
113
- 'MelBand Roformer | INSTV6N by Gabox': 'mel_band_roformer_instrumental_instv6n_gabox.ckpt',
114
- 'MelBand Roformer | INSTV7 by Gabox': 'mel_band_roformer_instrumental_instv7_gabox.ckpt',
115
- },
116
- "InstVoc Duality": {
117
- 'MelBand Roformer Kim | InstVoc Duality V1 by Unwa': 'melband_roformer_instvoc_duality_v1.ckpt',
118
- 'MelBand Roformer Kim | InstVoc Duality V2 by Unwa': 'melband_roformer_instvox_duality_v2.ckpt',
119
- },
120
- "De-Reverb": {
121
- 'BS-Roformer-De-Reverb': 'deverb_bs_roformer_8_384dim_10depth.ckpt',
122
- 'MelBand Roformer | De-Reverb by anvuew': 'dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt',
123
- 'MelBand Roformer | De-Reverb Less Aggressive by anvuew': 'dereverb_mel_band_roformer_less_aggressive_anvuew_sdr_18.8050.ckpt',
124
- 'MelBand Roformer | De-Reverb Mono by anvuew': 'dereverb_mel_band_roformer_mono_anvuew.ckpt',
125
- 'MelBand Roformer | De-Reverb Big by Sucial': 'dereverb_big_mbr_ep_362.ckpt',
126
- 'MelBand Roformer | De-Reverb Super Big by Sucial': 'dereverb_super_big_mbr_ep_346.ckpt',
127
- 'MelBand Roformer | De-Reverb-Echo by Sucial': 'dereverb-echo_mel_band_roformer_sdr_10.0169.ckpt',
128
- 'MelBand Roformer | De-Reverb-Echo V2 by Sucial': 'dereverb-echo_mel_band_roformer_sdr_13.4843_v2.ckpt',
129
- 'MelBand Roformer | De-Reverb-Echo Fused by Sucial': 'dereverb_echo_mbr_fused.ckpt',
130
- },
131
- "Denoise": {
132
- 'Mel-Roformer-Denoise-Aufr33': 'denoise_mel_band_roformer_aufr33_sdr_27.9959.ckpt',
133
- 'Mel-Roformer-Denoise-Aufr33-Aggr': 'denoise_mel_band_roformer_aufr33_aggr_sdr_27.9768.ckpt',
134
- 'MelBand Roformer | Denoise-Debleed by Gabox': 'mel_band_roformer_denoise_debleed_gabox.ckpt',
135
- 'MelBand Roformer | Bleed Suppressor V1 by unwa-97chris': 'mel_band_roformer_bleed_suppressor_v1.ckpt',
136
- },
137
- "Karaoke": {
138
- 'Mel-Roformer-Karaoke-Aufr33-Viperx': 'mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956.ckpt',
139
- 'MelBand Roformer | Karaoke by Gabox': 'mel_band_roformer_karaoke_gabox.ckpt',
140
- 'MelBand Roformer | Karaoke by becruily': 'mel_band_roformer_karaoke_becruily.ckpt',
141
- },
142
- "General Purpose": {
143
- 'BS-Roformer-Viperx-1297': 'model_bs_roformer_ep_317_sdr_12.9755.ckpt',
144
- 'BS-Roformer-Viperx-1296': 'model_bs_roformer_ep_368_sdr_12.9628.ckpt',
145
- 'BS-Roformer-Viperx-1053': 'model_bs_roformer_ep_937_sdr_10.5309.ckpt',
146
- 'Mel-Roformer-Viperx-1143': 'model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt',
147
- 'Mel-Roformer-Crowd-Aufr33-Viperx': 'mel_band_roformer_crowd_aufr33_viperx_sdr_8.7144.ckpt',
148
- 'MelBand Roformer Kim | SYHFT by SYH99999': 'MelBandRoformerSYHFT.ckpt',
149
- 'MelBand Roformer Kim | SYHFT V2 by SYH99999': 'MelBandRoformerSYHFTV2.ckpt',
150
- 'MelBand Roformer Kim | SYHFT V2.5 by SYH99999': 'MelBandRoformerSYHFTV2.5.ckpt',
151
- 'MelBand Roformer Kim | SYHFT V3 by SYH99999': 'MelBandRoformerSYHFTV3Epsilon.ckpt',
152
- 'MelBand Roformer Kim | Big SYHFT V1 by SYH99999': 'MelBandRoformerBigSYHFTV1.ckpt',
153
- 'MelBand Roformer | Aspiration by Sucial': 'aspiration_mel_band_roformer_sdr_18.9845.ckpt',
154
- 'MelBand Roformer | Aspiration Less Aggressive by Sucial': 'aspiration_mel_band_roformer_less_aggr_sdr_18.1201.ckpt',
155
- }
156
- }
157
 
158
  OUTPUT_FORMATS = ['wav', 'flac', 'mp3', 'ogg', 'opus', 'm4a', 'aiff', 'ac3']
159
 
@@ -509,12 +427,13 @@ def download_from_google_drive(url):
509
  except Exception as e:
510
  logger.warning(f"Failed to delete temporary file {temp_output_path}: {str(e)}")
511
 
512
- @spaces.GPU(duration=60)
513
  def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, pitch_shift, model_dir, output_dir, out_format, norm_thresh, amp_thresh, batch_size, exclude_stems="", progress=gr.Progress(track_tqdm=True)):
514
  if not audio:
515
  raise ValueError("No audio or video file provided.")
516
  temp_audio_path = None
517
  extracted_audio_path = None
 
518
  try:
519
  file_extension = os.path.splitext(audio)[1].lower().lstrip('.')
520
  supported_formats = ['wav', 'mp3', 'flac', 'ogg', 'opus', 'm4a', 'aiff', 'ac3', 'mp4', 'mov', 'avi', 'mkv', 'flv', 'wmv', 'webm', 'mpeg', 'mpg', 'ts', 'vob']
@@ -554,29 +473,88 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
554
  if os.path.exists(output_dir):
555
  shutil.rmtree(output_dir)
556
  os.makedirs(output_dir, exist_ok=True)
557
- base_name = os.path.splitext(os.path.basename(audio))[0].replace(' ', '_') # Boşlukları alt çizgi ile değiştir
558
- for category, models in ROFORMER_MODELS.items():
559
- if model_key in models:
560
- model = models[model_key]
561
- break
562
- else:
563
  raise ValueError(f"Model '{model_key}' not found.")
 
 
 
 
 
 
 
564
  logger.info(f"Separating {base_name} with {model_key} on {device}")
565
- separator = Separator(
566
- log_level=logging.INFO,
567
- model_file_dir=model_dir,
568
- output_dir=output_dir,
569
- output_format=out_format,
570
- normalization_threshold=norm_thresh,
571
- amplification_threshold=amp_thresh,
572
- use_autocast=use_autocast,
573
- mdxc_params={"segment_size": seg_size, "override_model_segment_size": override_seg_size, "batch_size": batch_size, "overlap": overlap, "pitch_shift": pitch_shift}
574
- )
575
- progress(0.2, desc="Loading model...")
576
- separator.load_model(model_filename=model)
577
- progress(0.7, desc="Separating audio...")
578
- separation = separator.separate(audio_to_process)
579
- stems = [os.path.join(output_dir, file_name) for file_name in separation]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
580
  file_list = []
581
  if exclude_stems.strip():
582
  excluded = [s.strip().lower() for s in exclude_stems.split(',')]
@@ -586,7 +564,7 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
586
  stem2 = filtered_stems[1] if len(filtered_stems) > 1 else None
587
  else:
588
  file_list = stems
589
- stem1 = stems[0]
590
  stem2 = stems[1] if len(stems) > 1 else None
591
 
592
  return stem1, stem2, file_list
@@ -611,7 +589,7 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
611
  torch.cuda.empty_cache()
612
  logger.info("GPU memory cleared")
613
 
614
- @spaces.GPU(duration=60)
615
  def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, out_format="wav", use_tta="False", model_dir="/tmp/audio-separator-models/", output_dir="output", norm_thresh=0.9, amp_thresh=0.9, batch_size=1, ensemble_method="avg_wave", exclude_stems="", weights_str="", progress=gr.Progress(track_tqdm=True)):
616
  temp_audio_path = None
617
  extracted_audio_path = None
@@ -801,6 +779,10 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
801
 
802
  if model_key not in model_cache:
803
  logger.info(f"Loading {model_key} into cache")
 
 
 
 
804
  separator = Separator(
805
  log_level=logging.INFO,
806
  model_file_dir=model_dir,
@@ -895,12 +877,14 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
895
  logger.info("GPU memory cleared")
896
 
897
  def update_roformer_models(category):
898
- choices = list(ROFORMER_MODELS.get(category, {}).keys()) or []
 
899
  logger.debug(f"Updating roformer models for category {category}: {choices}")
900
  return gr.update(choices=choices, value=choices[0] if choices else None)
901
 
902
  def update_ensemble_models(category):
903
- choices = list(ROFORMER_MODELS.get(category, {}).keys()) or []
 
904
  logger.debug(f"Updating ensemble models for category {category}: {choices}")
905
  return gr.update(choices=choices, value=[])
906
 
@@ -908,6 +892,59 @@ def download_audio_wrapper(url, cookie_file):
908
  file_path, status, audio_data = download_audio(url, cookie_file)
909
  return file_path, status # Return file_path instead of audio_data
910
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
911
  def create_interface():
912
  with gr.Blocks(title="🎡 SESA Fast Separation 🎡", css=CSS, elem_id="app-container") as app:
913
  gr.Markdown("<h1 class='header-text'>🎡 SESA Fast Separation 🎡</h1>")
@@ -940,8 +977,8 @@ def create_interface():
940
  roformer_download_status = gr.Textbox(label="πŸ“’ Download Status", interactive=False)
941
  roformer_exclude_stems = gr.Textbox(label="🚫 Exclude Stems", placeholder="e.g., vocals, drums (comma-separated)", interactive=True)
942
  with gr.Row():
943
- roformer_category = gr.Dropdown(label="πŸ“š Category", choices=list(ROFORMER_MODELS.keys()), value="General Purpose", interactive=True)
944
- roformer_model = gr.Dropdown(label="πŸ› οΈ Model", choices=list(ROFORMER_MODELS["General Purpose"].keys()), interactive=True, allow_custom_value=True)
945
  with gr.Row():
946
  roformer_seg_size = gr.Slider(32, 512, value=64, step=32, label="πŸ“ Segment Size", interactive=True)
947
  roformer_overlap = gr.Slider(2, 10, value=8, step=1, label="πŸ”„ Overlap", interactive=True)
@@ -965,8 +1002,8 @@ def create_interface():
965
  ensemble_download_status = gr.Textbox(label="πŸ“’ Download Status", interactive=False)
966
  ensemble_exclude_stems = gr.Textbox(label="🚫 Exclude Stems", placeholder="e.g., vocals, drums (comma-separated)", interactive=True)
967
  with gr.Row():
968
- ensemble_category = gr.Dropdown(label="πŸ“š Category", choices=list(ROFORMER_MODELS.keys()), value="Instrumentals", interactive=True)
969
- ensemble_models = gr.Dropdown(label="πŸ› οΈ Models (Max 6)", choices=list(ROFORMER_MODELS["Instrumentals"].keys()), multiselect=True, interactive=True, allow_custom_value=True)
970
  with gr.Row():
971
  ensemble_seg_size = gr.Slider(32, 512, value=64, step=32, label="πŸ“ Segment Size", interactive=True)
972
  ensemble_overlap = gr.Slider(2, 10, value=8, step=1, label="πŸ”„ Overlap", interactive=True)
@@ -977,6 +1014,42 @@ def create_interface():
977
  ensemble_output = gr.Audio(label="🎢 Ensemble Result", type="filepath", interactive=False)
978
  ensemble_status = gr.HTML(label="πŸ“’ Status")
979
  ensemble_files = gr.File(label="πŸ“₯ Download Ensemble and Stems", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
980
  gr.HTML("<div class='footer'>Powered by Audio-Separator 🌟🎢 | Made with ❀️</div>")
981
  roformer_category.change(update_roformer_models, inputs=[roformer_category], outputs=[roformer_model])
982
  download_roformer.click(
@@ -1009,6 +1082,28 @@ def create_interface():
1009
  ],
1010
  outputs=[ensemble_output, ensemble_status, ensemble_files, ensemble_state]
1011
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1012
  return app
1013
 
1014
  if __name__ == "__main__":
 
25
  import subprocess
26
  import spaces
27
  import torchaudio
28
+ from models_config import (
29
+ EXTENDED_MODELS, get_all_models, get_categories, get_model_choices,
30
+ find_model_filename, add_custom_model, delete_custom_model, load_custom_models,
31
+ get_custom_models_list, ensure_model_files_downloaded,
32
+ get_audio_duration, split_audio_segments, concatenate_segment_outputs,
33
+ MAX_UNSPLIT_DURATION, SEGMENT_DURATION
34
+ )
35
 
36
  # Logging setup
37
  logging.basicConfig(level=logging.INFO)
 
70
  time_budget = 300 # ZeroGPU iΓ§in işlem sΔ±nΔ±rΔ±
71
  gpu_lock = Lock()
72
 
73
+ # ROFORMER_MODELS - now using EXTENDED_MODELS from models_config
74
+ ROFORMER_MODELS = get_all_models()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  OUTPUT_FORMATS = ['wav', 'flac', 'mp3', 'ogg', 'opus', 'm4a', 'aiff', 'ac3']
77
 
 
427
  except Exception as e:
428
  logger.warning(f"Failed to delete temporary file {temp_output_path}: {str(e)}")
429
 
430
+ @spaces.GPU(duration=300)
431
  def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, pitch_shift, model_dir, output_dir, out_format, norm_thresh, amp_thresh, batch_size, exclude_stems="", progress=gr.Progress(track_tqdm=True)):
432
  if not audio:
433
  raise ValueError("No audio or video file provided.")
434
  temp_audio_path = None
435
  extracted_audio_path = None
436
+ segment_temp_dir = None
437
  try:
438
  file_extension = os.path.splitext(audio)[1].lower().lstrip('.')
439
  supported_formats = ['wav', 'mp3', 'flac', 'ogg', 'opus', 'm4a', 'aiff', 'ac3', 'mp4', 'mov', 'avi', 'mkv', 'flv', 'wmv', 'webm', 'mpeg', 'mpg', 'ts', 'vob']
 
473
  if os.path.exists(output_dir):
474
  shutil.rmtree(output_dir)
475
  os.makedirs(output_dir, exist_ok=True)
476
+ base_name = os.path.splitext(os.path.basename(audio))[0].replace(' ', '_')
477
+
478
+ # Find model from EXTENDED_MODELS + custom models
479
+ model = find_model_filename(model_key)
480
+ if not model:
 
481
  raise ValueError(f"Model '{model_key}' not found.")
482
+
483
+ # Pre-download model files (checkpoint + config YAML) before loading
484
+ # This is required for the separator.py bypass to work
485
+ dl_success, dl_msg = ensure_model_files_downloaded(model, model_dir)
486
+ if not dl_success:
487
+ logger.warning(f"Pre-download warning for {model}: {dl_msg}")
488
+
489
  logger.info(f"Separating {base_name} with {model_key} on {device}")
490
+
491
+ # ── Large file segmentation ──
492
+ audio_duration = get_audio_duration(audio_to_process)
493
+ was_segmented = False
494
+ if audio_duration > MAX_UNSPLIT_DURATION:
495
+ duration_min = audio_duration / 60
496
+ logger.info(f"⚠️ Large audio detected: {duration_min:.0f} min. Splitting to prevent OOM...")
497
+ progress(0.05, desc=f"Splitting {duration_min:.0f} min audio into segments...")
498
+ segment_temp_dir = os.path.join("/tmp", f"sesa_segments_{base_name}")
499
+ os.makedirs(segment_temp_dir, exist_ok=True)
500
+ segments = split_audio_segments(audio_to_process, segment_temp_dir, SEGMENT_DURATION)
501
+ if segments:
502
+ was_segmented = True
503
+ logger.info(f"Split into {len(segments)} segments")
504
+ # Process each segment
505
+ seg_output_dir = os.path.join("/tmp", f"sesa_seg_output_{base_name}")
506
+ os.makedirs(seg_output_dir, exist_ok=True)
507
+ for i, seg_path in enumerate(segments):
508
+ progress(0.1 + 0.7 * (i / len(segments)), desc=f"Processing segment {i+1}/{len(segments)}...")
509
+ separator = Separator(
510
+ log_level=logging.INFO,
511
+ model_file_dir=model_dir,
512
+ output_dir=seg_output_dir,
513
+ output_format=out_format,
514
+ normalization_threshold=norm_thresh,
515
+ amplification_threshold=amp_thresh,
516
+ use_autocast=use_autocast,
517
+ mdxc_params={"segment_size": seg_size, "override_model_segment_size": override_seg_size, "batch_size": batch_size, "overlap": overlap, "pitch_shift": pitch_shift}
518
+ )
519
+ separator.load_model(model_filename=model)
520
+ separator.separate(seg_path)
521
+ # Free GPU memory between segments
522
+ del separator
523
+ if torch.cuda.is_available():
524
+ torch.cuda.empty_cache()
525
+ gc.collect()
526
+ # Concatenate segment outputs
527
+ progress(0.85, desc="Concatenating segments...")
528
+ concatenate_segment_outputs(seg_output_dir, out_format)
529
+ # Move final concatenated files to output_dir
530
+ for f in os.listdir(seg_output_dir):
531
+ if '_seg' not in f.lower(): # Only move final merged files
532
+ shutil.move(os.path.join(seg_output_dir, f), os.path.join(output_dir, f))
533
+ # Cleanup temp dirs
534
+ shutil.rmtree(segment_temp_dir, ignore_errors=True)
535
+ shutil.rmtree(seg_output_dir, ignore_errors=True)
536
+ segment_temp_dir = None
537
+
538
+ if not was_segmented:
539
+ # Normal processing (no segmentation)
540
+ separator = Separator(
541
+ log_level=logging.INFO,
542
+ model_file_dir=model_dir,
543
+ output_dir=output_dir,
544
+ output_format=out_format,
545
+ normalization_threshold=norm_thresh,
546
+ amplification_threshold=amp_thresh,
547
+ use_autocast=use_autocast,
548
+ mdxc_params={"segment_size": seg_size, "override_model_segment_size": override_seg_size, "batch_size": batch_size, "overlap": overlap, "pitch_shift": pitch_shift}
549
+ )
550
+ progress(0.2, desc="Loading model...")
551
+ separator.load_model(model_filename=model)
552
+ progress(0.7, desc="Separating audio...")
553
+ separator.separate(audio_to_process)
554
+
555
+ # Collect all output stems
556
+ output_files = os.listdir(output_dir)
557
+ stems = [os.path.join(output_dir, f) for f in output_files if os.path.isfile(os.path.join(output_dir, f))]
558
  file_list = []
559
  if exclude_stems.strip():
560
  excluded = [s.strip().lower() for s in exclude_stems.split(',')]
 
564
  stem2 = filtered_stems[1] if len(filtered_stems) > 1 else None
565
  else:
566
  file_list = stems
567
+ stem1 = stems[0] if stems else None
568
  stem2 = stems[1] if len(stems) > 1 else None
569
 
570
  return stem1, stem2, file_list
 
589
  torch.cuda.empty_cache()
590
  logger.info("GPU memory cleared")
591
 
592
+ @spaces.GPU(duration=300)
593
  def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, out_format="wav", use_tta="False", model_dir="/tmp/audio-separator-models/", output_dir="output", norm_thresh=0.9, amp_thresh=0.9, batch_size=1, ensemble_method="avg_wave", exclude_stems="", weights_str="", progress=gr.Progress(track_tqdm=True)):
594
  temp_audio_path = None
595
  extracted_audio_path = None
 
779
 
780
  if model_key not in model_cache:
781
  logger.info(f"Loading {model_key} into cache")
782
+ # Pre-download model files for bypass
783
+ dl_ok, dl_msg = ensure_model_files_downloaded(model, model_dir)
784
+ if not dl_ok:
785
+ logger.warning(f"Pre-download warning: {dl_msg}")
786
  separator = Separator(
787
  log_level=logging.INFO,
788
  model_file_dir=model_dir,
 
877
  logger.info("GPU memory cleared")
878
 
879
  def update_roformer_models(category):
880
+ all_models = get_all_models()
881
+ choices = list(all_models.get(category, {}).keys()) or []
882
  logger.debug(f"Updating roformer models for category {category}: {choices}")
883
  return gr.update(choices=choices, value=choices[0] if choices else None)
884
 
885
  def update_ensemble_models(category):
886
+ all_models = get_all_models()
887
+ choices = list(all_models.get(category, {}).keys()) or []
888
  logger.debug(f"Updating ensemble models for category {category}: {choices}")
889
  return gr.update(choices=choices, value=[])
890
 
 
892
  file_path, status, audio_data = download_audio(url, cookie_file)
893
  return file_path, status # Return file_path instead of audio_data
894
 
895
+ # ─── Batch Processing ────────────────────────────────────────────────────────
896
+ @spaces.GPU(duration=300)
897
+ def batch_separator(audio_files, model_key, seg_size, override_seg_size, overlap, pitch_shift, model_dir, output_dir, out_format, norm_thresh, amp_thresh, batch_size, exclude_stems="", progress=gr.Progress(track_tqdm=True)):
898
+ """Process up to 10 audio files sequentially."""
899
+ if not audio_files:
900
+ raise ValueError("No audio files provided.")
901
+ if len(audio_files) > 10:
902
+ raise ValueError("Maximum 10 files per batch.")
903
+
904
+ all_output_files = []
905
+ status_lines = []
906
+ for i, audio in enumerate(audio_files):
907
+ # Handle gr.File objects
908
+ audio_path = audio.name if hasattr(audio, 'name') else audio
909
+ base = os.path.splitext(os.path.basename(audio_path))[0]
910
+ progress((i) / len(audio_files), desc=f"Processing file {i+1}/{len(audio_files)}: {base}")
911
+ try:
912
+ stem1, stem2, files = roformer_separator(
913
+ audio_path, model_key, seg_size, override_seg_size, overlap, pitch_shift,
914
+ model_dir, output_dir, out_format, norm_thresh, amp_thresh, batch_size,
915
+ exclude_stems, progress
916
+ )
917
+ all_output_files.extend(files)
918
+ status_lines.append(f"βœ… {base}: {len(files)} stems")
919
+ except Exception as e:
920
+ status_lines.append(f"❌ {base}: {str(e)[:100]}")
921
+ logger.error(f"Batch processing error for {base}: {e}")
922
+
923
+ status_text = "\n".join(status_lines)
924
+ return status_text, all_output_files
925
+
926
+ # ─── Custom Model Management UI handlers ─────────────────────────────────────
927
+ def add_custom_model_handler(name, checkpoint_url, config_url, custom_py_url):
928
+ success, msg = add_custom_model(name, checkpoint_url, config_url, custom_py_url)
929
+ # Refresh ROFORMER_MODELS
930
+ global ROFORMER_MODELS
931
+ ROFORMER_MODELS = get_all_models()
932
+ # Get updated custom model list
933
+ custom_list_data = get_custom_models_list()
934
+ custom_list = "\n".join([f"β€’ {n}: {u}" for n, u in custom_list_data]) if custom_list_data else "No custom models"
935
+ # Return updated categories
936
+ cats = get_categories()
937
+ return msg, custom_list, gr.update(choices=cats), gr.update(choices=cats)
938
+
939
+ def delete_custom_model_handler(name):
940
+ success, msg = delete_custom_model(name)
941
+ global ROFORMER_MODELS
942
+ ROFORMER_MODELS = get_all_models()
943
+ custom_list_data = get_custom_models_list()
944
+ custom_list = "\n".join([f"β€’ {n}: {u}" for n, u in custom_list_data]) if custom_list_data else "No custom models"
945
+ cats = get_categories()
946
+ return msg, custom_list, gr.update(choices=cats), gr.update(choices=cats)
947
+
948
  def create_interface():
949
  with gr.Blocks(title="🎡 SESA Fast Separation 🎡", css=CSS, elem_id="app-container") as app:
950
  gr.Markdown("<h1 class='header-text'>🎡 SESA Fast Separation 🎡</h1>")
 
977
  roformer_download_status = gr.Textbox(label="πŸ“’ Download Status", interactive=False)
978
  roformer_exclude_stems = gr.Textbox(label="🚫 Exclude Stems", placeholder="e.g., vocals, drums (comma-separated)", interactive=True)
979
  with gr.Row():
980
+ roformer_category = gr.Dropdown(label="πŸ“š Category", choices=get_categories(), value="Vocals", interactive=True)
981
+ roformer_model = gr.Dropdown(label="πŸ› οΈ Model", choices=get_model_choices("Vocals"), interactive=True, allow_custom_value=True)
982
  with gr.Row():
983
  roformer_seg_size = gr.Slider(32, 512, value=64, step=32, label="πŸ“ Segment Size", interactive=True)
984
  roformer_overlap = gr.Slider(2, 10, value=8, step=1, label="πŸ”„ Overlap", interactive=True)
 
1002
  ensemble_download_status = gr.Textbox(label="πŸ“’ Download Status", interactive=False)
1003
  ensemble_exclude_stems = gr.Textbox(label="🚫 Exclude Stems", placeholder="e.g., vocals, drums (comma-separated)", interactive=True)
1004
  with gr.Row():
1005
+ ensemble_category = gr.Dropdown(label="πŸ“š Category", choices=get_categories(), value="Instrumentals", interactive=True)
1006
+ ensemble_models = gr.Dropdown(label="πŸ› οΈ Models (Max 6)", choices=get_model_choices("Instrumentals"), multiselect=True, interactive=True, allow_custom_value=True)
1007
  with gr.Row():
1008
  ensemble_seg_size = gr.Slider(32, 512, value=64, step=32, label="πŸ“ Segment Size", interactive=True)
1009
  ensemble_overlap = gr.Slider(2, 10, value=8, step=1, label="πŸ”„ Overlap", interactive=True)
 
1014
  ensemble_output = gr.Audio(label="🎢 Ensemble Result", type="filepath", interactive=False)
1015
  ensemble_status = gr.HTML(label="πŸ“’ Status")
1016
  ensemble_files = gr.File(label="πŸ“₯ Download Ensemble and Stems", interactive=False)
1017
+ with gr.Tab("πŸ“¦ Batch Processing"):
1018
+ with gr.Group(elem_classes="dubbing-theme"):
1019
+ gr.Markdown("### Batch Processing (Max 10 Files)")
1020
+ gr.Markdown("Upload multiple audio files and process them all with the same model.")
1021
+ batch_audio = gr.File(label="🎧 Upload Audio Files", file_count="multiple", file_types=['.wav', '.mp3', '.flac', '.ogg', '.opus', '.m4a', '.aiff', '.ac3', '.mp4', '.mov', '.avi', '.mkv'], interactive=True)
1022
+ with gr.Row():
1023
+ batch_category = gr.Dropdown(label="πŸ“š Category", choices=get_categories(), value="Vocals", interactive=True)
1024
+ batch_model = gr.Dropdown(label="πŸ› οΈ Model", choices=get_model_choices("Vocals"), interactive=True, allow_custom_value=True)
1025
+ with gr.Row():
1026
+ batch_seg_size = gr.Slider(32, 512, value=64, step=32, label="πŸ“ Segment Size", interactive=True)
1027
+ batch_overlap = gr.Slider(2, 10, value=8, step=1, label="πŸ”„ Overlap", interactive=True)
1028
+ batch_pitch_shift = gr.Slider(-12, 12, value=0, step=1, label="🎡 Pitch Shift", interactive=True)
1029
+ batch_override_seg = gr.Dropdown(choices=["True", "False"], value="False", label="πŸ”§ Override Segment Size", interactive=True)
1030
+ batch_exclude = gr.Textbox(label="🚫 Exclude Stems", placeholder="e.g., vocals, drums (comma-separated)", interactive=True)
1031
+ batch_button = gr.Button("πŸš€ Process Batch!", variant="primary")
1032
+ batch_status = gr.Textbox(label="πŸ“’ Batch Status", interactive=False, lines=5)
1033
+ batch_files = gr.File(label="πŸ“₯ Download All Stems", interactive=False)
1034
+ with gr.Tab("πŸ”§ Custom Models"):
1035
+ with gr.Group(elem_classes="dubbing-theme"):
1036
+ gr.Markdown("### Custom Model Management")
1037
+ gr.Markdown("Add custom models from HuggingFace or other sources by providing download URLs. The model will be automatically downloaded when used.")
1038
+ with gr.Row():
1039
+ custom_model_name = gr.Textbox(label="πŸ“ Model Display Name", placeholder="e.g., My Custom Vocal Model", interactive=True)
1040
+ with gr.Row():
1041
+ custom_checkpoint_url = gr.Textbox(label="πŸ“¦ Checkpoint URL (required)", placeholder="https://huggingface.co/.../resolve/main/model.ckpt", interactive=True)
1042
+ with gr.Row():
1043
+ custom_config_url = gr.Textbox(label="πŸ“„ Config URL (optional)", placeholder="https://huggingface.co/.../resolve/main/config.yaml", interactive=True)
1044
+ with gr.Row():
1045
+ custom_py_url = gr.Textbox(label="🐍 Custom .py URL (optional)", placeholder="https://huggingface.co/.../resolve/main/bs_roformer.py", interactive=True)
1046
+ with gr.Row():
1047
+ add_model_btn = gr.Button("βž• Add Model", variant="primary")
1048
+ del_model_name = gr.Textbox(label="πŸ—‘οΈ Model Name to Delete", placeholder="Exact model name", interactive=True)
1049
+ del_model_btn = gr.Button("πŸ—‘οΈ Delete Model", variant="stop")
1050
+ custom_model_status = gr.Textbox(label="πŸ“’ Status", interactive=False)
1051
+ custom_model_list = gr.Textbox(label="πŸ“‹ Custom Models", interactive=False, lines=8,
1052
+ value="\n".join([f"β€’ {n}: {u}" for n, u in get_custom_models_list()]) or "No custom models")
1053
  gr.HTML("<div class='footer'>Powered by Audio-Separator 🌟🎢 | Made with ❀️</div>")
1054
  roformer_category.change(update_roformer_models, inputs=[roformer_category], outputs=[roformer_model])
1055
  download_roformer.click(
 
1082
  ],
1083
  outputs=[ensemble_output, ensemble_status, ensemble_files, ensemble_state]
1084
  )
1085
+ # Batch processing events
1086
+ batch_category.change(update_roformer_models, inputs=[batch_category], outputs=[batch_model])
1087
+ batch_button.click(
1088
+ fn=batch_separator,
1089
+ inputs=[
1090
+ batch_audio, batch_model, batch_seg_size, batch_override_seg, batch_overlap,
1091
+ batch_pitch_shift, model_file_dir, output_dir, output_format,
1092
+ norm_threshold, amp_threshold, batch_size, batch_exclude
1093
+ ],
1094
+ outputs=[batch_status, batch_files]
1095
+ )
1096
+ # Custom model events
1097
+ add_model_btn.click(
1098
+ fn=add_custom_model_handler,
1099
+ inputs=[custom_model_name, custom_checkpoint_url, custom_config_url, custom_py_url],
1100
+ outputs=[custom_model_status, custom_model_list, roformer_category, ensemble_category]
1101
+ )
1102
+ del_model_btn.click(
1103
+ fn=delete_custom_model_handler,
1104
+ inputs=[del_model_name],
1105
+ outputs=[custom_model_status, custom_model_list, roformer_category, ensemble_category]
1106
+ )
1107
  return app
1108
 
1109
  if __name__ == "__main__":