ASesYusuf1 committed on
Commit
489680b
·
verified ·
1 Parent(s): cc64221

Update gui.py

Browse files
Files changed (1) hide show
  1. gui.py +33 -27
gui.py CHANGED
@@ -11,6 +11,28 @@ import librosa
11
  import soundfile as sf
12
  from ensemble import ensemble_files
13
  import shutil
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  # Device and autocast setup
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -132,8 +154,7 @@ CSS = """
132
  overflow: hidden;
133
  }
134
  body {
135
- background: url('/content/logo.jpg') no-repeat center center fixed;
136
- background-size: cover;
137
  margin: 0;
138
  padding: 0;
139
  font-family: 'Roboto', sans-serif;
@@ -335,7 +356,6 @@ def download_audio(url, out_dir="ytdl"):
335
  if not url:
336
  raise ValueError("No URL provided.")
337
 
338
- # Clear ytdl directory
339
  if os.path.exists(out_dir):
340
  shutil.rmtree(out_dir)
341
  os.makedirs(out_dir, exist_ok=True)
@@ -358,10 +378,8 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
358
  if not audio:
359
  raise ValueError("No audio file provided.")
360
 
361
- # Convert override_seg_size to boolean
362
  override_seg_size = override_seg_size == "True"
363
 
364
- # Clear output directory
365
  if os.path.exists(output_dir):
366
  shutil.rmtree(output_dir)
367
  os.makedirs(output_dir, exist_ok=True)
@@ -392,7 +410,6 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
392
  separation = separator.separate(audio)
393
  stems = [os.path.join(output_dir, file_name) for file_name in separation]
394
 
395
- # Filter excluded stems
396
  if exclude_stems.strip():
397
  excluded = [s.strip().lower() for s in exclude_stems.split(',')]
398
  filtered_stems = [stem for stem in stems if not any(ex in os.path.basename(stem).lower() for ex in excluded)]
@@ -402,15 +419,13 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
402
  logger.error(f"Separation failed: {e}")
403
  raise RuntimeError(f"Separation failed: {e}")
404
 
405
- def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_tta, model_dir, output_dir, norm_thresh, amp_thresh, batch_size, ensemble_method, exclude_stems="", weights=None, progress=gr.Progress()):
406
  """Perform ensemble processing on audio using multiple Roformer models."""
407
  if not audio or not model_keys:
408
  raise ValueError("Audio or models missing.")
409
 
410
- # Convert use_tta to boolean
411
  use_tta = use_tta == "True"
412
 
413
- # Clear output directory
414
  if os.path.exists(output_dir):
415
  shutil.rmtree(output_dir)
416
  os.makedirs(output_dir, exist_ok=True)
@@ -421,7 +436,6 @@ def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_
421
  all_stems = []
422
  total_models = len(model_keys)
423
 
424
- # Separate audio with each model
425
  for i, model_key in enumerate(model_keys):
426
  for category, models in ROFORMER_MODELS.items():
427
  if model_key in models:
@@ -446,7 +460,6 @@ def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_
446
  separation = separator.separate(audio)
447
  stems = [os.path.join(output_dir, file_name) for file_name in separation]
448
 
449
- # Filter excluded stems
450
  if exclude_stems.strip():
451
  excluded = [s.strip().lower() for s in exclude_stems.split(',')]
452
  filtered_stems = [stem for stem in stems if not any(ex in os.path.basename(stem).lower() for ex in excluded)]
@@ -457,11 +470,10 @@ def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_
457
  if not all_stems:
458
  raise ValueError("No valid stems for ensemble after exclusion.")
459
 
460
- # Default weights if none provided
461
- if weights is None or len(weights) != len(all_stems):
462
  weights = [1.0] * len(all_stems)
463
 
464
- # Perform ensemble
465
  output_file = os.path.join(output_dir, f"{base_name}_ensemble_{ensemble_method}.{out_format}")
466
  ensemble_args = [
467
  "--files", *all_stems,
@@ -477,11 +489,11 @@ def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_
477
 
478
  def update_roformer_models(category):
479
  """Update Roformer model dropdown based on selected category."""
480
- return gr.update(choices=list(ROFORMER_MODELS.get(category, {}).keys()))
481
 
482
  def update_ensemble_models(category):
483
  """Update ensemble model dropdown based on selected category."""
484
- return gr.update(choices=list(ROFORMER_MODELS.get(category, {}).keys()))
485
 
486
  # Interface creation
487
  def create_interface():
@@ -497,8 +509,8 @@ def create_interface():
497
  model_file_dir = gr.Textbox(value="/tmp/audio-separator-models/", label="πŸ“‚ Model Cache", placeholder="Path to model directory", interactive=True)
498
  output_dir = gr.Textbox(value="output", label="πŸ“€ Output Directory", placeholder="Where to save results", interactive=True)
499
  output_format = gr.Dropdown(value="wav", choices=OUTPUT_FORMATS, label="🎢 Output Format", interactive=True)
500
- norm_threshold = gr.Slider(0.1, 1, value=0.9, step=0.1, label="πŸ”Š Normalization Threshold", interactive=True)
501
- amp_threshold = gr.Slider(0.1, 1, value=0.3, step=0.1, label="πŸ“ˆ Amplification Threshold", interactive=True)
502
  batch_size = gr.Slider(1, 16, value=4, step=1, label="⚑ Batch Size", interactive=True)
503
 
504
  # Roformer Tab
@@ -563,13 +575,7 @@ def create_interface():
563
  ensemble_category.change(update_ensemble_models, inputs=[ensemble_category], outputs=[ensemble_models])
564
  download_ensemble.click(fn=download_audio, inputs=[url_ensemble], outputs=[ensemble_audio])
565
  ensemble_button.click(
566
- fn=lambda audio, models, seg_size, overlap, out_format, use_tta, model_dir, output_dir,
567
- norm_thresh, amp_thresh, batch_size, method, exclude_stems, weights_str:
568
- auto_ensemble_process(
569
- audio, models, seg_size, overlap, out_format, use_tta, model_dir, output_dir,
570
- norm_thresh, amp_thresh, batch_size, method, exclude_stems,
571
- [float(w.strip()) for w in weights_str.split(',')] if weights_str.strip() else None
572
- ),
573
  inputs=[
574
  ensemble_audio, ensemble_models, ensemble_seg_size, ensemble_overlap,
575
  output_format, ensemble_use_tta, model_file_dir, output_dir,
@@ -588,8 +594,8 @@ if __name__ == "__main__":
588
 
589
  app = create_interface()
590
  try:
591
- # Use share=True for remote access or server_name="127.0.0.1" for local testing
592
- app.launch(server_name="0.0.0.0", server_port=args.port, share=True)
593
  except Exception as e:
594
  logger.error(f"Failed to launch app: {e}")
595
  raise
 
11
  import soundfile as sf
12
  from ensemble import ensemble_files
13
  import shutil
14
+ import gradio_client.utils as client_utils
15
+
16
# Patch gradio_client.utils.get_type so it tolerates boolean JSON schemas
# (gradio_client otherwise raises "argument of type 'bool' is not iterable"
# when a schema is the bare `true`/`false` form).
def patched_get_type(schema):
    """Map a JSON-schema fragment to a Python type string.

    Mirrors gradio_client.utils.get_type, but handles boolean schemas
    first: per the JSON Schema spec, `true` accepts any value and `false`
    accepts none, so neither names a concrete type -> report "Any".

    Args:
        schema: A JSON-schema fragment — either a bool or a dict.

    Returns:
        A Python-flavoured type string such as "str", "Any",
        "Union[...]", "List[...]", or "Dict[str, Any]".
    """
    if isinstance(schema, bool):
        # Bare boolean schema: `true` matches anything, `false` matches
        # nothing; neither constrains the value's type, so "boolean"
        # would be wrong — report "Any".
        return "Any"
    if "const" in schema:
        return repr(schema["const"])
    if "enum" in schema:
        return f"Union[{', '.join(repr(e) for e in schema['enum'])}]"
    if "type" not in schema:
        return "Any"
    type_ = schema["type"]
    if isinstance(type_, list):
        # Union of several JSON types; drop JSON's "null" member.
        return f"Union[{', '.join(t for t in type_ if t != 'null')}]"
    if type_ == "array":
        # NOTE(review): relies on the private helper
        # client_utils._json_schema_to_python_type — confirm it exists in
        # the installed gradio_client version.
        return f"List[{client_utils._json_schema_to_python_type(schema.get('items', {}), schema.get('$defs', {}))}]"
    if type_ == "object":
        return "Dict[str, Any]"
    return type_
34
+
35
+ # Install the monkeypatch: route gradio_client's schema-to-type lookup
+ # through the boolean-tolerant implementation defined just above.
+ client_utils.get_type = patched_get_type
36
 
37
  # Device and autocast setup
38
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
154
  overflow: hidden;
155
  }
156
  body {
157
+ background: none;
 
158
  margin: 0;
159
  padding: 0;
160
  font-family: 'Roboto', sans-serif;
 
356
  if not url:
357
  raise ValueError("No URL provided.")
358
 
 
359
  if os.path.exists(out_dir):
360
  shutil.rmtree(out_dir)
361
  os.makedirs(out_dir, exist_ok=True)
 
378
  if not audio:
379
  raise ValueError("No audio file provided.")
380
 
 
381
  override_seg_size = override_seg_size == "True"
382
 
 
383
  if os.path.exists(output_dir):
384
  shutil.rmtree(output_dir)
385
  os.makedirs(output_dir, exist_ok=True)
 
410
  separation = separator.separate(audio)
411
  stems = [os.path.join(output_dir, file_name) for file_name in separation]
412
 
 
413
  if exclude_stems.strip():
414
  excluded = [s.strip().lower() for s in exclude_stems.split(',')]
415
  filtered_stems = [stem for stem in stems if not any(ex in os.path.basename(stem).lower() for ex in excluded)]
 
419
  logger.error(f"Separation failed: {e}")
420
  raise RuntimeError(f"Separation failed: {e}")
421
 
422
+ def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_tta, model_dir, output_dir, norm_thresh, amp_thresh, batch_size, ensemble_method, exclude_stems="", weights_str="", progress=gr.Progress()):
423
  """Perform ensemble processing on audio using multiple Roformer models."""
424
  if not audio or not model_keys:
425
  raise ValueError("Audio or models missing.")
426
 
 
427
  use_tta = use_tta == "True"
428
 
 
429
  if os.path.exists(output_dir):
430
  shutil.rmtree(output_dir)
431
  os.makedirs(output_dir, exist_ok=True)
 
436
  all_stems = []
437
  total_models = len(model_keys)
438
 
 
439
  for i, model_key in enumerate(model_keys):
440
  for category, models in ROFORMER_MODELS.items():
441
  if model_key in models:
 
460
  separation = separator.separate(audio)
461
  stems = [os.path.join(output_dir, file_name) for file_name in separation]
462
 
 
463
  if exclude_stems.strip():
464
  excluded = [s.strip().lower() for s in exclude_stems.split(',')]
465
  filtered_stems = [stem for stem in stems if not any(ex in os.path.basename(stem).lower() for ex in excluded)]
 
470
  if not all_stems:
471
  raise ValueError("No valid stems for ensemble after exclusion.")
472
 
473
+ weights = [float(w.strip()) for w in weights_str.split(',')] if weights_str.strip() else [1.0] * len(all_stems)
474
+ if len(weights) != len(all_stems):
475
  weights = [1.0] * len(all_stems)
476
 
 
477
  output_file = os.path.join(output_dir, f"{base_name}_ensemble_{ensemble_method}.{out_format}")
478
  ensemble_args = [
479
  "--files", *all_stems,
 
489
 
490
def update_roformer_models(category):
    """Update the Roformer model dropdown for the selected category.

    Args:
        category: Key into the file-level ROFORMER_MODELS mapping.

    Returns:
        A gr.update whose choices are the model names in that category;
        an unknown category yields an empty choice list.
    """
    # list() of an empty mapping's keys is already [], so no `or []`
    # fallback is needed.
    return gr.update(choices=list(ROFORMER_MODELS.get(category, {}).keys()))
493
 
494
def update_ensemble_models(category):
    """Update the ensemble model dropdown for the selected category.

    Args:
        category: Key into the file-level ROFORMER_MODELS mapping.

    Returns:
        A gr.update whose choices are the model names in that category;
        an unknown category yields an empty choice list.
    """
    # list() of an empty mapping's keys is already [], so no `or []`
    # fallback is needed.
    return gr.update(choices=list(ROFORMER_MODELS.get(category, {}).keys()))
497
 
498
  # Interface creation
499
  def create_interface():
 
509
  model_file_dir = gr.Textbox(value="/tmp/audio-separator-models/", label="πŸ“‚ Model Cache", placeholder="Path to model directory", interactive=True)
510
  output_dir = gr.Textbox(value="output", label="πŸ“€ Output Directory", placeholder="Where to save results", interactive=True)
511
  output_format = gr.Dropdown(value="wav", choices=OUTPUT_FORMATS, label="🎢 Output Format", interactive=True)
512
+ norm_threshold = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="πŸ”Š Normalization Threshold", interactive=True)
513
+ amp_threshold = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="πŸ“ˆ Amplification Threshold", interactive=True)
514
  batch_size = gr.Slider(1, 16, value=4, step=1, label="⚑ Batch Size", interactive=True)
515
 
516
  # Roformer Tab
 
575
  ensemble_category.change(update_ensemble_models, inputs=[ensemble_category], outputs=[ensemble_models])
576
  download_ensemble.click(fn=download_audio, inputs=[url_ensemble], outputs=[ensemble_audio])
577
  ensemble_button.click(
578
+ fn=auto_ensemble_process,
 
 
 
 
 
 
579
  inputs=[
580
  ensemble_audio, ensemble_models, ensemble_seg_size, ensemble_overlap,
581
  output_format, ensemble_use_tta, model_file_dir, output_dir,
 
594
 
595
  app = create_interface()
596
  try:
597
+ # For Hugging Face Spaces or local testing
598
+ app.launch(server_name="0.0.0.0", server_port=args.port, share=False)
599
  except Exception as e:
600
  logger.error(f"Failed to launch app: {e}")
601
  raise