ASesYusuf1 committed on
Commit
1d87edf
·
verified ·
1 Parent(s): 394e662

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -15
app.py CHANGED
@@ -555,17 +555,13 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
555
  permanent_output_dir = os.path.join(output_dir, "permanent_stems")
556
  os.makedirs(permanent_output_dir, exist_ok=True)
557
 
558
- model_cache = {}
559
- all_stems = []
560
- total_tasks = len(model_keys)
561
- current_idx = state["current_model_idx"]
562
- logger.info(f"Current model index: {current_idx}, total models: {len(model_keys)}")
563
-
564
- if current_idx >= len(model_keys):
565
  logger.info("All models processed, running ensemble...")
566
  progress(0.9, desc="Running ensemble...")
567
 
568
  excluded_stems_list = [s.strip().lower() for s in exclude_stems.split(',')] if exclude_stems.strip() else []
 
569
  for model_key, stems_dict in state["model_outputs"].items():
570
  for stem_type in ["vocals", "other"]:
571
  if stems_dict[stem_type]:
@@ -594,6 +590,7 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
594
  if result is None or not os.path.exists(output_file):
595
  raise RuntimeError(f"Ensemble failed, output file not created: {output_file}")
596
 
 
597
  state["current_model_idx"] = 0
598
  state["current_audio"] = None
599
  state["processed_stems"] = []
@@ -610,10 +607,12 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
610
  status += "</ul>"
611
  return output_file, status, file_list, state
612
 
613
- model_key = model_keys[current_idx]
614
- logger.info(f"Processing model {current_idx + 1}/{len(model_keys)}: {model_key}")
 
615
  progress(0.1, desc=f"Processing model {model_key}...")
616
 
 
617
  with torch.no_grad():
618
  for attempt in range(max_retries + 1):
619
  try:
@@ -692,12 +691,13 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
692
  elapsed = time.time() - start_time
693
  logger.info(f"Model {model_key} completed in {elapsed:.2f}s")
694
 
695
- if state["current_model_idx"] >= len(model_keys):
696
- logger.info("Last model processed, running ensemble immediately...")
697
- return auto_ensemble_process(audio, model_keys, state, seg_size, overlap, out_format, use_tta, model_dir, output_dir, norm_thresh, amp_thresh, batch_size, ensemble_method, exclude_stems, weights_str, progress)
698
-
699
  file_list = state["processed_stems"]
700
- status = f"Model {model_key} (Model {current_idx + 1}/{len(model_keys)}) completed in {elapsed:.2f}s<br>Click 'Run Ensemble!' to process the next model.<br>Processed stems:<ul>"
 
 
 
 
 
701
  for file in file_list:
702
  file_name = os.path.basename(file)
703
  status += f"<li><a href='file={file}' download>{file_name}</a></li>"
@@ -725,7 +725,7 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
725
  if torch.cuda.is_available():
726
  torch.cuda.empty_cache()
727
  logger.info("GPU memory cleared")
728
-
729
  def update_roformer_models(category):
730
  choices = list(ROFORMER_MODELS.get(category, {}).keys()) or []
731
  logger.debug(f"Updating roformer models for category {category}: {choices}")
 
555
  permanent_output_dir = os.path.join(output_dir, "permanent_stems")
556
  os.makedirs(permanent_output_dir, exist_ok=True)
557
 
558
+ # Check if all models have been processed
559
+ if state["current_model_idx"] >= len(model_keys):
 
 
 
 
 
560
  logger.info("All models processed, running ensemble...")
561
  progress(0.9, desc="Running ensemble...")
562
 
563
  excluded_stems_list = [s.strip().lower() for s in exclude_stems.split(',')] if exclude_stems.strip() else []
564
+ all_stems = []
565
  for model_key, stems_dict in state["model_outputs"].items():
566
  for stem_type in ["vocals", "other"]:
567
  if stems_dict[stem_type]:
 
590
  if result is None or not os.path.exists(output_file):
591
  raise RuntimeError(f"Ensemble failed, output file not created: {output_file}")
592
 
593
+ # Reset state after ensemble
594
  state["current_model_idx"] = 0
595
  state["current_audio"] = None
596
  state["processed_stems"] = []
 
607
  status += "</ul>"
608
  return output_file, status, file_list, state
609
 
610
+ # Process the next model
611
+ model_key = model_keys[state["current_model_idx"]]
612
+ logger.info(f"Processing model {state['current_model_idx'] + 1}/{len(model_keys)}: {model_key}")
613
  progress(0.1, desc=f"Processing model {model_key}...")
614
 
615
+ model_cache = {}
616
  with torch.no_grad():
617
  for attempt in range(max_retries + 1):
618
  try:
 
691
  elapsed = time.time() - start_time
692
  logger.info(f"Model {model_key} completed in {elapsed:.2f}s")
693
 
 
 
 
 
694
  file_list = state["processed_stems"]
695
+ status = f"Model {model_key} (Model {state['current_model_idx']}/{len(model_keys)}) completed in {elapsed:.2f}s<br>"
696
+ if state["current_model_idx"] >= len(model_keys):
697
+ status += "All models processed. Click 'Run Ensemble!' to combine the stems.<br>"
698
+ else:
699
+ status += "Click 'Run Ensemble!' to process the next model.<br>"
700
+ status += "Processed stems:<ul>"
701
  for file in file_list:
702
  file_name = os.path.basename(file)
703
  status += f"<li><a href='file={file}' download>{file_name}</a></li>"
 
725
  if torch.cuda.is_available():
726
  torch.cuda.empty_cache()
727
  logger.info("GPU memory cleared")
728
+
729
  def update_roformer_models(category):
730
  choices = list(ROFORMER_MODELS.get(category, {}).keys()) or []
731
  logger.debug(f"Updating roformer models for category {category}: {choices}")