ASesYusuf1 committed on
Commit
7b79193
·
verified ·
1 Parent(s): aecae1e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -96
app.py CHANGED
@@ -369,29 +369,30 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
369
  if not audio:
370
  raise ValueError("No audio file provided.")
371
 
372
- # If audio is a tuple (sample_rate, data), save it as a temporary file
373
- if isinstance(audio, tuple):
374
- sample_rate, data = audio
375
- temp_audio_path = os.path.join("/tmp", "temp_audio.wav")
376
- scipy.io.wavfile.write(temp_audio_path, sample_rate, data)
377
- audio = temp_audio_path
378
-
379
- override_seg_size = override_seg_size == "True"
380
-
381
- if os.path.exists(output_dir):
382
- shutil.rmtree(output_dir)
383
- os.makedirs(output_dir, exist_ok=True)
384
-
385
- base_name = os.path.splitext(os.path.basename(audio))[0]
386
- for category, models in ROFORMER_MODELS.items():
387
- if model_key in models:
388
- model = models[model_key]
389
- break
390
- else:
391
- raise ValueError(f"Model '{model_key}' not found.")
392
-
393
- logger.info(f"Separating {base_name} with {model_key} on {device}")
394
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  separator = Separator(
396
  log_level=logging.INFO,
397
  model_file_dir=model_dir,
@@ -417,9 +418,13 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
417
  logger.error(f"Separation failed: {e}")
418
  raise RuntimeError(f"Separation failed: {e}")
419
  finally:
420
- # Clean up temporary file if created
421
- if isinstance(audio, tuple) and os.path.exists(temp_audio_path):
422
- os.remove(temp_audio_path)
 
 
 
 
423
 
424
  @spaces.GPU
425
  def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_tta, model_dir, output_dir, norm_thresh, amp_thresh, batch_size, ensemble_method, exclude_stems="", weights_str="", progress=gr.Progress()):
@@ -427,80 +432,89 @@ def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_
427
  if not audio or not model_keys:
428
  raise ValueError("Audio or models missing.")
429
 
430
- # If audio is a tuple (sample_rate, data), save it as a temporary file
431
- if isinstance(audio, tuple):
432
- sample_rate, data = audio
433
- temp_audio_path = os.path.join("/tmp", "temp_audio.wav")
434
- scipy.io.wavfile.write(temp_audio_path, sample_rate, data)
435
- audio = temp_audio_path
436
-
437
- use_tta = use_tta == "True"
438
-
439
- if os.path.exists(output_dir):
440
- shutil.rmtree(output_dir)
441
- os.makedirs(output_dir, exist_ok=True)
442
-
443
- base_name = os.path.splitext(os.path.basename(audio))[0]
444
- logger.info(f"Ensemble for {base_name} with {model_keys} on {device}")
445
-
446
- all_stems = []
447
- total_models = len(model_keys)
448
-
449
- for i, model_key in enumerate(model_keys):
450
- for category, models in ROFORMER_MODELS.items():
451
- if model_key in models:
452
- model = models[model_key]
453
- break
454
- else:
455
- continue
456
 
457
- separator = Separator(
458
- log_level=logging.INFO,
459
- model_file_dir=model_dir,
460
- output_dir=output_dir,
461
- output_format=out_format,
462
- normalization_threshold=norm_thresh,
463
- amplification_threshold=amp_thresh,
464
- use_autocast=use_autocast,
465
- mdxc_params={"segment_size": seg_size, "overlap": overlap, "use_tta": use_tta, "batch_size": batch_size}
466
- )
467
- progress(0.1 + (0.4 / total_models) * i, desc=f"Loading {model_key}")
468
- separator.load_model(model_filename=model)
469
- progress(0.5 + (0.4 / total_models) * i, desc=f"Separating with {model_key}")
470
- separation = separator.separate(audio)
471
- stems = [os.path.join(output_dir, file_name) for file_name in separation]
472
 
473
- if exclude_stems.strip():
474
- excluded = [s.strip().lower() for s in exclude_stems.split(',')]
475
- filtered_stems = [stem for stem in stems if not any(ex in os.path.basename(stem).lower() for ex in excluded)]
476
- all_stems.extend(filtered_stems)
477
- else:
478
- all_stems.extend(stems)
479
-
480
- if not all_stems:
481
- raise ValueError("No valid stems for ensemble after exclusion.")
482
-
483
- weights = [float(w.strip()) for w in weights_str.split(',')] if weights_str.strip() else [1.0] * len(all_stems)
484
- if len(weights) != len(all_stems):
485
- weights = [1.0] * len(all_stems)
486
-
487
- output_file = os.path.join(output_dir, f"{base_name}_ensemble_{ensemble_method}.{out_format}")
488
- ensemble_args = [
489
- "--files", *all_stems,
490
- "--type", ensemble_method,
491
- "--weights", *[str(w) for w in weights],
492
- "--output", output_file
493
- ]
494
- progress(0.9, desc="Running ensemble...")
495
- ensemble_files(ensemble_args)
496
-
497
- progress(1.0, desc="Ensemble complete")
498
- return output_file, f"Ensemble completed with {ensemble_method}, excluded: {exclude_stems if exclude_stems else 'None'}"
499
- finally:
500
- # Clean up temporary file if created
501
- if isinstance(audio, tuple) and os.path.exists(temp_audio_path):
502
- os.remove(temp_audio_path)
 
 
 
503
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  def update_roformer_models(category):
505
  """Update Roformer model dropdown based on selected category."""
506
  choices = list(ROFORMER_MODELS.get(category, {}).keys()) or []
 
369
  if not audio:
370
  raise ValueError("No audio file provided.")
371
 
372
+ temp_audio_path = None # Initialize to None to avoid undefined variable in finally
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  try:
374
+ # If audio is a tuple (sample_rate, data), save it as a temporary file
375
+ if isinstance(audio, tuple):
376
+ sample_rate, data = audio
377
+ temp_audio_path = os.path.join("/tmp", "temp_audio.wav")
378
+ scipy.io.wavfile.write(temp_audio_path, sample_rate, data)
379
+ audio = temp_audio_path
380
+
381
+ override_seg_size = override_seg_size == "True"
382
+
383
+ if os.path.exists(output_dir):
384
+ shutil.rmtree(output_dir)
385
+ os.makedirs(output_dir, exist_ok=True)
386
+
387
+ base_name = os.path.splitext(os.path.basename(audio))[0]
388
+ for category, models in ROFORMER_MODELS.items():
389
+ if model_key in models:
390
+ model = models[model_key]
391
+ break
392
+ else:
393
+ raise ValueError(f"Model '{model_key}' not found.")
394
+
395
+ logger.info(f"Separating {base_name} with {model_key} on {device}")
396
  separator = Separator(
397
  log_level=logging.INFO,
398
  model_file_dir=model_dir,
 
418
  logger.error(f"Separation failed: {e}")
419
  raise RuntimeError(f"Separation failed: {e}")
420
  finally:
421
+ # Clean up temporary file if it was created
422
+ if temp_audio_path and os.path.exists(temp_audio_path):
423
+ try:
424
+ os.remove(temp_audio_path)
425
+ logger.info(f"Cleaned up temporary file: {temp_audio_path}")
426
+ except Exception as e:
427
+ logger.warning(f"Failed to clean up temporary file {temp_audio_path}: {e}")
428
 
429
  @spaces.GPU
430
  def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_tta, model_dir, output_dir, norm_thresh, amp_thresh, batch_size, ensemble_method, exclude_stems="", weights_str="", progress=gr.Progress()):
 
432
  if not audio or not model_keys:
433
  raise ValueError("Audio or models missing.")
434
 
435
+ temp_audio_path = None # Initialize to None to avoid undefined variable in finally
436
+ try:
437
+ # If audio is a tuple (sample_rate, data), save it as a temporary file
438
+ if isinstance(audio, tuple):
439
+ sample_rate, data = audio
440
+ temp_audio_path = os.path.join("/tmp", "temp_audio.wav")
441
+ scipy.io.wavfile.write(temp_audio_path, sample_rate, data)
442
+ audio = temp_audio_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
 
444
+ use_tta = use_tta == "True"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
 
446
+ if os.path.exists(output_dir):
447
+ shutil.rmtree(output_dir)
448
+ os.makedirs(output_dir, exist_ok=True)
449
+
450
+ base_name = os.path.splitext(os.path.basename(audio))[0]
451
+ logger.info(f"Ensemble for {base_name} with {model_keys} on {device}")
452
+
453
+ all_stems = []
454
+ total_models = len(model_keys)
455
+
456
+ for i, model_key in enumerate(model_keys):
457
+ for category, models in ROFORMER_MODELS.items():
458
+ if model_key in models:
459
+ model = models[model_key]
460
+ break
461
+ else:
462
+ continue
463
+
464
+ separator = Separator(
465
+ log_level=logging.INFO,
466
+ model_file_dir=model_dir,
467
+ output_dir=output_dir,
468
+ output_format=out_format,
469
+ normalization_threshold=norm_thresh,
470
+ amplification_threshold=amp_thresh,
471
+ use_autocast=use_autocast,
472
+ mdxc_params={"segment_size": seg_size, "overlap": overlap, "use_tta": use_tta, "batch_size": batch_size}
473
+ )
474
+ progress(0.1 + (0.4 / total_models) * i, desc=f"Loading {model_key}")
475
+ separator.load_model(model_filename=model)
476
+ progress(0.5 + (0.4 / total_models) * i, desc=f"Separating with {model_key}")
477
+ separation = separator.separate(audio)
478
+ stems = [os.path.join(output_dir, file_name) for file_name in separation]
479
 
480
+ if exclude_stems.strip():
481
+ excluded = [s.strip().lower() for s in exclude_stems.split(',')]
482
+ filtered_stems = [stem for stem in stems if not any(ex in os.path.basename(stem).lower() for ex in excluded)]
483
+ all_stems.extend(filtered_stems)
484
+ else:
485
+ all_stems.extend(stems)
486
+
487
+ if not all_stems:
488
+ raise ValueError("No valid stems for ensemble after exclusion.")
489
+
490
+ weights = [float(w.strip()) for w in weights_str.split(',')] if weights_str.strip() else [1.0] * len(all_stems)
491
+ if len(weights) != len(all_stems):
492
+ weights = [1.0] * len(all_stems)
493
+
494
+ output_file = os.path.join(output_dir, f"{base_name}_ensemble_{ensemble_method}.{out_format}")
495
+ ensemble_args = [
496
+ "--files", *all_stems,
497
+ "--type", ensemble_method,
498
+ "--weights", *[str(w) for w in weights],
499
+ "--output", output_file
500
+ ]
501
+ progress(0.9, desc="Running ensemble...")
502
+ ensemble_files(ensemble_args)
503
+
504
+ progress(1.0, desc="Ensemble complete")
505
+ return output_file, f"Ensemble completed with {ensemble_method}, excluded: {exclude_stems if exclude_stems else 'None'}"
506
+ except Exception as e:
507
+ logger.error(f"Ensemble failed: {e}")
508
+ raise RuntimeError(f"Ensemble failed: {e}")
509
+ finally:
510
+ # Clean up temporary file if it was created
511
+ if temp_audio_path and os.path.exists(temp_audio_path):
512
+ try:
513
+ os.remove(temp_audio_path)
514
+ logger.info(f"Cleaned up temporary file: {temp_audio_path}")
515
+ except Exception as e:
516
+ logger.warning(f"Failed to clean up temporary file {temp_audio_path}: {e}")
517
+
518
  def update_roformer_models(category):
519
  """Update Roformer model dropdown based on selected category."""
520
  choices = list(ROFORMER_MODELS.get(category, {}).keys()) or []