ASesYusuf1 committed on
Commit
01781d2
·
verified ·
1 Parent(s): 1d87edf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -28
app.py CHANGED
@@ -485,6 +485,7 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
485
  def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, out_format="wav", use_tta="False", model_dir="/tmp/audio-separator-models/", output_dir="output", norm_thresh=0.9, amp_thresh=0.9, batch_size=1, ensemble_method="avg_wave", exclude_stems="", weights_str="", progress=gr.Progress(track_tqdm=True)):
486
  temp_audio_path = None
487
  extracted_audio_path = None
 
488
  start_time = time.time()
489
  try:
490
  if not audio:
@@ -505,7 +506,7 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
505
  extracted_audio_path = os.path.join("/tmp", f"extracted_audio_{os.path.basename(audio)}.wav")
506
  logger.info(f"Extracting audio from video file: {audio}")
507
  ffmpeg_command = [
508
- "ffmpeg", "-i", audio, "-vn", "-acodec", "pcm_s16le", "-ar", "44100", "-ac", "2",
509
  extracted_audio_path, "-y"
510
  ]
511
  try:
@@ -521,9 +522,21 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
521
  else:
522
  raise RuntimeError(f"Failed to extract audio from video: {error_message}")
523
 
 
524
  audio_data, sr = librosa.load(audio_to_process, sr=None, mono=False)
 
 
 
 
 
 
 
 
 
 
 
 
525
  duration = librosa.get_duration(y=audio_data, sr=sr)
526
- logger.info(f"Audio duration: {duration:.2f} seconds")
527
  dynamic_batch_size = max(1, min(4, 1 + int(900 / (duration + 1)) - len(model_keys) // 2))
528
  logger.info(f"Using batch size: {dynamic_batch_size} for {len(model_keys)} models, duration {duration:.2f}s")
529
 
@@ -555,13 +568,17 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
555
  permanent_output_dir = os.path.join(output_dir, "permanent_stems")
556
  os.makedirs(permanent_output_dir, exist_ok=True)
557
 
558
- # Check if all models have been processed
559
- if state["current_model_idx"] >= len(model_keys):
 
 
 
 
 
560
  logger.info("All models processed, running ensemble...")
561
  progress(0.9, desc="Running ensemble...")
562
 
563
  excluded_stems_list = [s.strip().lower() for s in exclude_stems.split(',')] if exclude_stems.strip() else []
564
- all_stems = []
565
  for model_key, stems_dict in state["model_outputs"].items():
566
  for stem_type in ["vocals", "other"]:
567
  if stems_dict[stem_type]:
@@ -590,7 +607,6 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
590
  if result is None or not os.path.exists(output_file):
591
  raise RuntimeError(f"Ensemble failed, output file not created: {output_file}")
592
 
593
- # Reset state after ensemble
594
  state["current_model_idx"] = 0
595
  state["current_audio"] = None
596
  state["processed_stems"] = []
@@ -607,12 +623,10 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
607
  status += "</ul>"
608
  return output_file, status, file_list, state
609
 
610
- # Process the next model
611
- model_key = model_keys[state["current_model_idx"]]
612
- logger.info(f"Processing model {state['current_model_idx'] + 1}/{len(model_keys)}: {model_key}")
613
  progress(0.1, desc=f"Processing model {model_key}...")
614
 
615
- model_cache = {}
616
  with torch.no_grad():
617
  for attempt in range(max_retries + 1):
618
  try:
@@ -691,13 +705,12 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
691
  elapsed = time.time() - start_time
692
  logger.info(f"Model {model_key} completed in {elapsed:.2f}s")
693
 
694
- file_list = state["processed_stems"]
695
- status = f"Model {model_key} (Model {state['current_model_idx']}/{len(model_keys)}) completed in {elapsed:.2f}s<br>"
696
  if state["current_model_idx"] >= len(model_keys):
697
- status += "All models processed. Click 'Run Ensemble!' to combine the stems.<br>"
698
- else:
699
- status += "Click 'Run Ensemble!' to process the next model.<br>"
700
- status += "Processed stems:<ul>"
 
701
  for file in file_list:
702
  file_name = os.path.basename(file)
703
  status += f"<li><a href='file={file}' download>{file_name}</a></li>"
@@ -710,18 +723,13 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
710
  raise RuntimeError(error_msg)
711
 
712
  finally:
713
- if temp_audio_path and os.path.exists(temp_audio_path):
714
- try:
715
- os.remove(temp_audio_path)
716
- logger.info(f"Temporary file deleted: {temp_audio_path}")
717
- except Exception as e:
718
- logger.warning(f"Failed to delete temporary file {temp_audio_path}: {e}")
719
- if extracted_audio_path and os.path.exists(extracted_audio_path):
720
- try:
721
- os.remove(extracted_audio_path)
722
- logger.info(f"Extracted audio file deleted: {extracted_audio_path}")
723
- except Exception as e:
724
- logger.warning(f"Failed to delete extracted audio file {extracted_audio_path}: {e}")
725
  if torch.cuda.is_available():
726
  torch.cuda.empty_cache()
727
  logger.info("GPU memory cleared")
 
485
  def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, out_format="wav", use_tta="False", model_dir="/tmp/audio-separator-models/", output_dir="output", norm_thresh=0.9, amp_thresh=0.9, batch_size=1, ensemble_method="avg_wave", exclude_stems="", weights_str="", progress=gr.Progress(track_tqdm=True)):
486
  temp_audio_path = None
487
  extracted_audio_path = None
488
+ resampled_audio_path = None
489
  start_time = time.time()
490
  try:
491
  if not audio:
 
506
  extracted_audio_path = os.path.join("/tmp", f"extracted_audio_{os.path.basename(audio)}.wav")
507
  logger.info(f"Extracting audio from video file: {audio}")
508
  ffmpeg_command = [
509
+ "ffmpeg", "-i", audio, "-vn", "-acodec", "pcm_s16le", "-ar", "48000", "-ac", "2",
510
  extracted_audio_path, "-y"
511
  ]
512
  try:
 
522
  else:
523
  raise RuntimeError(f"Failed to extract audio from video: {error_message}")
524
 
525
+ # Load audio and resample to 48 kHz
526
  audio_data, sr = librosa.load(audio_to_process, sr=None, mono=False)
527
+ logger.info(f"Original sample rate: {sr} Hz, Audio duration: {librosa.get_duration(y=audio_data, sr=sr):.2f} seconds")
528
+ if sr != 48000:
529
+ logger.info(f"Resampling audio from {sr} Hz to 48000 Hz")
530
+ resampled_audio_path = os.path.join("/tmp", f"resampled_audio_{os.path.basename(audio)}.wav")
531
+ waveform, _ = torchaudio.load(audio_to_process)
532
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=48000)
533
+ resampled_waveform = resampler(waveform)
534
+ torchaudio.save(resampled_audio_path, resampled_waveform, 48000)
535
+ audio_to_process = resampled_audio_path
536
+ audio_data, sr = librosa.load(audio_to_process, sr=None, mono=False)
537
+ logger.info(f"Resampled audio saved to: {resampled_audio_path}, new sample rate: {sr} Hz")
538
+
539
  duration = librosa.get_duration(y=audio_data, sr=sr)
 
540
  dynamic_batch_size = max(1, min(4, 1 + int(900 / (duration + 1)) - len(model_keys) // 2))
541
  logger.info(f"Using batch size: {dynamic_batch_size} for {len(model_keys)} models, duration {duration:.2f}s")
542
 
 
568
  permanent_output_dir = os.path.join(output_dir, "permanent_stems")
569
  os.makedirs(permanent_output_dir, exist_ok=True)
570
 
571
+ model_cache = {}
572
+ all_stems = []
573
+ total_tasks = len(model_keys)
574
+ current_idx = state["current_model_idx"]
575
+ logger.info(f"Current model index: {current_idx}, total models: {len(model_keys)}")
576
+
577
+ if current_idx >= len(model_keys):
578
  logger.info("All models processed, running ensemble...")
579
  progress(0.9, desc="Running ensemble...")
580
 
581
  excluded_stems_list = [s.strip().lower() for s in exclude_stems.split(',')] if exclude_stems.strip() else []
 
582
  for model_key, stems_dict in state["model_outputs"].items():
583
  for stem_type in ["vocals", "other"]:
584
  if stems_dict[stem_type]:
 
607
  if result is None or not os.path.exists(output_file):
608
  raise RuntimeError(f"Ensemble failed, output file not created: {output_file}")
609
 
 
610
  state["current_model_idx"] = 0
611
  state["current_audio"] = None
612
  state["processed_stems"] = []
 
623
  status += "</ul>"
624
  return output_file, status, file_list, state
625
 
626
+ model_key = model_keys[current_idx]
627
+ logger.info(f"Processing model {current_idx + 1}/{len(model_keys)}: {model_key}")
 
628
  progress(0.1, desc=f"Processing model {model_key}...")
629
 
 
630
  with torch.no_grad():
631
  for attempt in range(max_retries + 1):
632
  try:
 
705
  elapsed = time.time() - start_time
706
  logger.info(f"Model {model_key} completed in {elapsed:.2f}s")
707
 
 
 
708
  if state["current_model_idx"] >= len(model_keys):
709
+ logger.info("Last model processed, running ensemble immediately...")
710
+ return auto_ensemble_process(audio, model_keys, state, seg_size, overlap, out_format, use_tta, model_dir, output_dir, norm_thresh, amp_thresh, batch_size, ensemble_method, exclude_stems, weights_str, progress)
711
+
712
+ file_list = state["processed_stems"]
713
+ status = f"Model {model_key} (Model {current_idx + 1}/{len(model_keys)}) completed in {elapsed:.2f}s<br>Click 'Run Ensemble!' to process the next model.<br>Processed stems:<ul>"
714
  for file in file_list:
715
  file_name = os.path.basename(file)
716
  status += f"<li><a href='file={file}' download>{file_name}</a></li>"
 
723
  raise RuntimeError(error_msg)
724
 
725
  finally:
726
+ for temp_file in [temp_audio_path, extracted_audio_path, resampled_audio_path]:
727
+ if temp_file and os.path.exists(temp_file):
728
+ try:
729
+ os.remove(temp_file)
730
+ logger.info(f"Temporary file deleted: {temp_file}")
731
+ except Exception as e:
732
+ logger.warning(f"Failed to delete temporary file {temp_file}: {e}")
 
 
 
 
 
733
  if torch.cuda.is_available():
734
  torch.cuda.empty_cache()
735
  logger.info("GPU memory cleared")