Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -434,14 +434,18 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
|
|
| 434 |
logger.warning(f"Failed to clean up temporary file {temp_audio_path}: {e}")
|
| 435 |
|
| 436 |
@spaces.GPU
|
| 437 |
-
def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_tta, model_dir, output_dir, norm_thresh, amp_thresh, batch_size, ensemble_method, exclude_stems="", weights_str=""
|
| 438 |
-
|
|
|
|
|
|
|
| 439 |
if not audio or not model_keys:
|
| 440 |
raise ValueError("Audio or models missing.")
|
| 441 |
|
| 442 |
-
temp_audio_path = None
|
| 443 |
try:
|
| 444 |
-
#
|
|
|
|
|
|
|
| 445 |
if isinstance(audio, tuple):
|
| 446 |
sample_rate, data = audio
|
| 447 |
temp_audio_path = os.path.join("/tmp", "temp_audio.wav")
|
|
@@ -478,9 +482,9 @@ def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_
|
|
| 478 |
use_autocast=use_autocast,
|
| 479 |
mdxc_params={"segment_size": seg_size, "overlap": overlap, "use_tta": use_tta, "batch_size": batch_size}
|
| 480 |
)
|
| 481 |
-
|
| 482 |
separator.load_model(model_filename=model)
|
| 483 |
-
|
| 484 |
separation = separator.separate(audio)
|
| 485 |
stems = [os.path.join(output_dir, file_name) for file_name in separation]
|
| 486 |
|
|
@@ -490,6 +494,12 @@ def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_
|
|
| 490 |
all_stems.extend(filtered_stems)
|
| 491 |
else:
|
| 492 |
all_stems.extend(stems)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
|
| 494 |
if not all_stems:
|
| 495 |
raise ValueError("No valid stems for ensemble after exclusion.")
|
|
@@ -505,22 +515,21 @@ def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_
|
|
| 505 |
"--weights", *[str(w) for w in weights],
|
| 506 |
"--output", output_file
|
| 507 |
]
|
| 508 |
-
|
| 509 |
ensemble_files(ensemble_args)
|
| 510 |
|
| 511 |
-
|
| 512 |
return output_file, f"Ensemble completed with {ensemble_method}, excluded: {exclude_stems if exclude_stems else 'None'}"
|
| 513 |
except Exception as e:
|
| 514 |
logger.error(f"Ensemble failed: {e}")
|
| 515 |
raise RuntimeError(f"Ensemble failed: {e}")
|
| 516 |
finally:
|
| 517 |
-
# Clean up temporary file if it was created
|
| 518 |
if temp_audio_path and os.path.exists(temp_audio_path):
|
| 519 |
try:
|
| 520 |
os.remove(temp_audio_path)
|
| 521 |
-
logger.info(f"
|
| 522 |
except Exception as e:
|
| 523 |
-
logger.
|
| 524 |
|
| 525 |
def update_roformer_models(category):
|
| 526 |
"""Update Roformer model dropdown based on selected category."""
|
|
|
|
| 434 |
logger.warning(f"Failed to clean up temporary file {temp_audio_path}: {e}")
|
| 435 |
|
| 436 |
@spaces.GPU
|
| 437 |
+
def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_tta, model_dir, output_dir, norm_thresh, amp_thresh, batch_size, ensemble_method, exclude_stems="", weights_str=""):
|
| 438 |
+
import gc
|
| 439 |
+
import torch
|
| 440 |
+
|
| 441 |
if not audio or not model_keys:
|
| 442 |
raise ValueError("Audio or models missing.")
|
| 443 |
|
| 444 |
+
temp_audio_path = None
|
| 445 |
try:
|
| 446 |
+
# Limit to 2 models for testing
|
| 447 |
+
model_keys = model_keys[:2]
|
| 448 |
+
|
| 449 |
if isinstance(audio, tuple):
|
| 450 |
sample_rate, data = audio
|
| 451 |
temp_audio_path = os.path.join("/tmp", "temp_audio.wav")
|
|
|
|
| 482 |
use_autocast=use_autocast,
|
| 483 |
mdxc_params={"segment_size": seg_size, "overlap": overlap, "use_tta": use_tta, "batch_size": batch_size}
|
| 484 |
)
|
| 485 |
+
logger.info(f"Loading {model_key}")
|
| 486 |
separator.load_model(model_filename=model)
|
| 487 |
+
logger.info(f"Separating with {model_key}")
|
| 488 |
separation = separator.separate(audio)
|
| 489 |
stems = [os.path.join(output_dir, file_name) for file_name in separation]
|
| 490 |
|
|
|
|
| 494 |
all_stems.extend(filtered_stems)
|
| 495 |
else:
|
| 496 |
all_stems.extend(stems)
|
| 497 |
+
|
| 498 |
+
# Clean up model to free memory
|
| 499 |
+
separator = None
|
| 500 |
+
gc.collect()
|
| 501 |
+
if torch.cuda.is_available():
|
| 502 |
+
torch.cuda.empty_cache()
|
| 503 |
|
| 504 |
if not all_stems:
|
| 505 |
raise ValueError("No valid stems for ensemble after exclusion.")
|
|
|
|
| 515 |
"--weights", *[str(w) for w in weights],
|
| 516 |
"--output", output_file
|
| 517 |
]
|
| 518 |
+
logger.info("Running ensemble...")
|
| 519 |
ensemble_files(ensemble_args)
|
| 520 |
|
| 521 |
+
logger.info("Ensemble complete")
|
| 522 |
return output_file, f"Ensemble completed with {ensemble_method}, excluded: {exclude_stems if exclude_stems else 'None'}"
|
| 523 |
except Exception as e:
|
| 524 |
logger.error(f"Ensemble failed: {e}")
|
| 525 |
raise RuntimeError(f"Ensemble failed: {e}")
|
| 526 |
finally:
|
|
|
|
| 527 |
if temp_audio_path and os.path.exists(temp_audio_path):
|
| 528 |
try:
|
| 529 |
os.remove(temp_audio_path)
|
| 530 |
+
logger.info(f"Successfully cleaned up {temp_audio_path}")
|
| 531 |
except Exception as e:
|
| 532 |
+
logger.error(f"Failed to clean up {temp_audio_path}: {e}")
|
| 533 |
|
| 534 |
def update_roformer_models(category):
|
| 535 |
"""Update Roformer model dropdown based on selected category."""
|