Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -485,6 +485,7 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
|
|
| 485 |
def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, out_format="wav", use_tta="False", model_dir="/tmp/audio-separator-models/", output_dir="output", norm_thresh=0.9, amp_thresh=0.9, batch_size=1, ensemble_method="avg_wave", exclude_stems="", weights_str="", progress=gr.Progress(track_tqdm=True)):
|
| 486 |
temp_audio_path = None
|
| 487 |
extracted_audio_path = None
|
|
|
|
| 488 |
start_time = time.time()
|
| 489 |
try:
|
| 490 |
if not audio:
|
|
@@ -505,7 +506,7 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 505 |
extracted_audio_path = os.path.join("/tmp", f"extracted_audio_{os.path.basename(audio)}.wav")
|
| 506 |
logger.info(f"Extracting audio from video file: {audio}")
|
| 507 |
ffmpeg_command = [
|
| 508 |
-
"ffmpeg", "-i", audio, "-vn", "-acodec", "pcm_s16le", "-ar", "
|
| 509 |
extracted_audio_path, "-y"
|
| 510 |
]
|
| 511 |
try:
|
|
@@ -521,9 +522,21 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 521 |
else:
|
| 522 |
raise RuntimeError(f"Failed to extract audio from video: {error_message}")
|
| 523 |
|
|
|
|
| 524 |
audio_data, sr = librosa.load(audio_to_process, sr=None, mono=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 525 |
duration = librosa.get_duration(y=audio_data, sr=sr)
|
| 526 |
-
logger.info(f"Audio duration: {duration:.2f} seconds")
|
| 527 |
dynamic_batch_size = max(1, min(4, 1 + int(900 / (duration + 1)) - len(model_keys) // 2))
|
| 528 |
logger.info(f"Using batch size: {dynamic_batch_size} for {len(model_keys)} models, duration {duration:.2f}s")
|
| 529 |
|
|
@@ -555,13 +568,17 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 555 |
permanent_output_dir = os.path.join(output_dir, "permanent_stems")
|
| 556 |
os.makedirs(permanent_output_dir, exist_ok=True)
|
| 557 |
|
| 558 |
-
|
| 559 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 560 |
logger.info("All models processed, running ensemble...")
|
| 561 |
progress(0.9, desc="Running ensemble...")
|
| 562 |
|
| 563 |
excluded_stems_list = [s.strip().lower() for s in exclude_stems.split(',')] if exclude_stems.strip() else []
|
| 564 |
-
all_stems = []
|
| 565 |
for model_key, stems_dict in state["model_outputs"].items():
|
| 566 |
for stem_type in ["vocals", "other"]:
|
| 567 |
if stems_dict[stem_type]:
|
|
@@ -590,7 +607,6 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 590 |
if result is None or not os.path.exists(output_file):
|
| 591 |
raise RuntimeError(f"Ensemble failed, output file not created: {output_file}")
|
| 592 |
|
| 593 |
-
# Reset state after ensemble
|
| 594 |
state["current_model_idx"] = 0
|
| 595 |
state["current_audio"] = None
|
| 596 |
state["processed_stems"] = []
|
|
@@ -607,12 +623,10 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 607 |
status += "</ul>"
|
| 608 |
return output_file, status, file_list, state
|
| 609 |
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
logger.info(f"Processing model {state['current_model_idx'] + 1}/{len(model_keys)}: {model_key}")
|
| 613 |
progress(0.1, desc=f"Processing model {model_key}...")
|
| 614 |
|
| 615 |
-
model_cache = {}
|
| 616 |
with torch.no_grad():
|
| 617 |
for attempt in range(max_retries + 1):
|
| 618 |
try:
|
|
@@ -691,13 +705,12 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 691 |
elapsed = time.time() - start_time
|
| 692 |
logger.info(f"Model {model_key} completed in {elapsed:.2f}s")
|
| 693 |
|
| 694 |
-
file_list = state["processed_stems"]
|
| 695 |
-
status = f"Model {model_key} (Model {state['current_model_idx']}/{len(model_keys)}) completed in {elapsed:.2f}s<br>"
|
| 696 |
if state["current_model_idx"] >= len(model_keys):
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
|
|
|
| 701 |
for file in file_list:
|
| 702 |
file_name = os.path.basename(file)
|
| 703 |
status += f"<li><a href='file={file}' download>{file_name}</a></li>"
|
|
@@ -710,18 +723,13 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 710 |
raise RuntimeError(error_msg)
|
| 711 |
|
| 712 |
finally:
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
try:
|
| 721 |
-
os.remove(extracted_audio_path)
|
| 722 |
-
logger.info(f"Extracted audio file deleted: {extracted_audio_path}")
|
| 723 |
-
except Exception as e:
|
| 724 |
-
logger.warning(f"Failed to delete extracted audio file {extracted_audio_path}: {e}")
|
| 725 |
if torch.cuda.is_available():
|
| 726 |
torch.cuda.empty_cache()
|
| 727 |
logger.info("GPU memory cleared")
|
|
|
|
| 485 |
def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, out_format="wav", use_tta="False", model_dir="/tmp/audio-separator-models/", output_dir="output", norm_thresh=0.9, amp_thresh=0.9, batch_size=1, ensemble_method="avg_wave", exclude_stems="", weights_str="", progress=gr.Progress(track_tqdm=True)):
|
| 486 |
temp_audio_path = None
|
| 487 |
extracted_audio_path = None
|
| 488 |
+
resampled_audio_path = None
|
| 489 |
start_time = time.time()
|
| 490 |
try:
|
| 491 |
if not audio:
|
|
|
|
| 506 |
extracted_audio_path = os.path.join("/tmp", f"extracted_audio_{os.path.basename(audio)}.wav")
|
| 507 |
logger.info(f"Extracting audio from video file: {audio}")
|
| 508 |
ffmpeg_command = [
|
| 509 |
+
"ffmpeg", "-i", audio, "-vn", "-acodec", "pcm_s16le", "-ar", "48000", "-ac", "2",
|
| 510 |
extracted_audio_path, "-y"
|
| 511 |
]
|
| 512 |
try:
|
|
|
|
| 522 |
else:
|
| 523 |
raise RuntimeError(f"Failed to extract audio from video: {error_message}")
|
| 524 |
|
| 525 |
+
# Load audio and resample to 48 kHz
|
| 526 |
audio_data, sr = librosa.load(audio_to_process, sr=None, mono=False)
|
| 527 |
+
logger.info(f"Original sample rate: {sr} Hz, Audio duration: {librosa.get_duration(y=audio_data, sr=sr):.2f} seconds")
|
| 528 |
+
if sr != 48000:
|
| 529 |
+
logger.info(f"Resampling audio from {sr} Hz to 48000 Hz")
|
| 530 |
+
resampled_audio_path = os.path.join("/tmp", f"resampled_audio_{os.path.basename(audio)}.wav")
|
| 531 |
+
waveform, _ = torchaudio.load(audio_to_process)
|
| 532 |
+
resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=48000)
|
| 533 |
+
resampled_waveform = resampler(waveform)
|
| 534 |
+
torchaudio.save(resampled_audio_path, resampled_waveform, 48000)
|
| 535 |
+
audio_to_process = resampled_audio_path
|
| 536 |
+
audio_data, sr = librosa.load(audio_to_process, sr=None, mono=False)
|
| 537 |
+
logger.info(f"Resampled audio saved to: {resampled_audio_path}, new sample rate: {sr} Hz")
|
| 538 |
+
|
| 539 |
duration = librosa.get_duration(y=audio_data, sr=sr)
|
|
|
|
| 540 |
dynamic_batch_size = max(1, min(4, 1 + int(900 / (duration + 1)) - len(model_keys) // 2))
|
| 541 |
logger.info(f"Using batch size: {dynamic_batch_size} for {len(model_keys)} models, duration {duration:.2f}s")
|
| 542 |
|
|
|
|
| 568 |
permanent_output_dir = os.path.join(output_dir, "permanent_stems")
|
| 569 |
os.makedirs(permanent_output_dir, exist_ok=True)
|
| 570 |
|
| 571 |
+
model_cache = {}
|
| 572 |
+
all_stems = []
|
| 573 |
+
total_tasks = len(model_keys)
|
| 574 |
+
current_idx = state["current_model_idx"]
|
| 575 |
+
logger.info(f"Current model index: {current_idx}, total models: {len(model_keys)}")
|
| 576 |
+
|
| 577 |
+
if current_idx >= len(model_keys):
|
| 578 |
logger.info("All models processed, running ensemble...")
|
| 579 |
progress(0.9, desc="Running ensemble...")
|
| 580 |
|
| 581 |
excluded_stems_list = [s.strip().lower() for s in exclude_stems.split(',')] if exclude_stems.strip() else []
|
|
|
|
| 582 |
for model_key, stems_dict in state["model_outputs"].items():
|
| 583 |
for stem_type in ["vocals", "other"]:
|
| 584 |
if stems_dict[stem_type]:
|
|
|
|
| 607 |
if result is None or not os.path.exists(output_file):
|
| 608 |
raise RuntimeError(f"Ensemble failed, output file not created: {output_file}")
|
| 609 |
|
|
|
|
| 610 |
state["current_model_idx"] = 0
|
| 611 |
state["current_audio"] = None
|
| 612 |
state["processed_stems"] = []
|
|
|
|
| 623 |
status += "</ul>"
|
| 624 |
return output_file, status, file_list, state
|
| 625 |
|
| 626 |
+
model_key = model_keys[current_idx]
|
| 627 |
+
logger.info(f"Processing model {current_idx + 1}/{len(model_keys)}: {model_key}")
|
|
|
|
| 628 |
progress(0.1, desc=f"Processing model {model_key}...")
|
| 629 |
|
|
|
|
| 630 |
with torch.no_grad():
|
| 631 |
for attempt in range(max_retries + 1):
|
| 632 |
try:
|
|
|
|
| 705 |
elapsed = time.time() - start_time
|
| 706 |
logger.info(f"Model {model_key} completed in {elapsed:.2f}s")
|
| 707 |
|
|
|
|
|
|
|
| 708 |
if state["current_model_idx"] >= len(model_keys):
|
| 709 |
+
logger.info("Last model processed, running ensemble immediately...")
|
| 710 |
+
return auto_ensemble_process(audio, model_keys, state, seg_size, overlap, out_format, use_tta, model_dir, output_dir, norm_thresh, amp_thresh, batch_size, ensemble_method, exclude_stems, weights_str, progress)
|
| 711 |
+
|
| 712 |
+
file_list = state["processed_stems"]
|
| 713 |
+
status = f"Model {model_key} (Model {current_idx + 1}/{len(model_keys)}) completed in {elapsed:.2f}s<br>Click 'Run Ensemble!' to process the next model.<br>Processed stems:<ul>"
|
| 714 |
for file in file_list:
|
| 715 |
file_name = os.path.basename(file)
|
| 716 |
status += f"<li><a href='file={file}' download>{file_name}</a></li>"
|
|
|
|
| 723 |
raise RuntimeError(error_msg)
|
| 724 |
|
| 725 |
finally:
|
| 726 |
+
for temp_file in [temp_audio_path, extracted_audio_path, resampled_audio_path]:
|
| 727 |
+
if temp_file and os.path.exists(temp_file):
|
| 728 |
+
try:
|
| 729 |
+
os.remove(temp_file)
|
| 730 |
+
logger.info(f"Temporary file deleted: {temp_file}")
|
| 731 |
+
except Exception as e:
|
| 732 |
+
logger.warning(f"Failed to delete temporary file {temp_file}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 733 |
if torch.cuda.is_available():
|
| 734 |
torch.cuda.empty_cache()
|
| 735 |
logger.info("GPU memory cleared")
|