ASesYusuf1 committed on
Commit
05553c4
·
verified ·
1 Parent(s): 4f779d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -145
app.py CHANGED
@@ -437,10 +437,11 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
437
 
438
  @spaces.GPU
439
  def auto_ensemble_process(audio, model_keys, seg_size=64, overlap=0.1, out_format="wav", use_tta="False", model_dir="/tmp/audio-separator-models/", output_dir="output", norm_thresh=0.9, amp_thresh=0.9, batch_size=1, ensemble_method="avg_wave", exclude_stems="", weights_str="", progress=gr.Progress(track_tqdm=True)):
 
440
  temp_audio_path = None
441
  max_retries = 2
442
  start_time = time.time()
443
- time_budget = 100 # seconds
444
  max_models = 6
445
  gpu_lock = Lock()
446
 
@@ -453,7 +454,7 @@ def auto_ensemble_process(audio, model_keys, seg_size=64, overlap=0.1, out_forma
453
  logger.warning(f"Selected {len(model_keys)} models, limiting to {max_models}.")
454
  model_keys = model_keys[:max_models]
455
 
456
- # Dynamic batch size based on audio duration and model count
457
  audio_data, sr = librosa.load(audio, sr=None, mono=False)
458
  duration = librosa.get_duration(y=audio_data, sr=sr)
459
  logger.info(f"Audio duration: {duration:.2f} seconds")
@@ -466,173 +467,191 @@ def auto_ensemble_process(audio, model_keys, seg_size=64, overlap=0.1, out_forma
466
  scipy.io.wavfile.write(temp_audio_path, sample_rate, data)
467
  audio = temp_audio_path
468
 
 
 
 
 
 
 
 
 
469
  use_tta = use_tta == "True"
470
- if os.path.exists(output_dir):
471
- shutil.rmtree(output_dir)
472
- os.makedirs(output_dir, exist_ok=True)
473
  base_name = os.path.splitext(os.path.basename(audio))[0]
474
  logger.info(f"Ensemble for {base_name} with {model_keys} on {device}")
475
 
 
 
 
 
476
  # Model cache
477
  model_cache = {}
478
  all_stems = []
479
- model_stems = {model_key: {"vocals": [], "other": []} for model_key in model_keys}
480
  total_tasks = len(model_keys)
481
 
482
- def process_model(model_key, model_idx):
483
- with torch.no_grad():
484
- for attempt in range(max_retries + 1):
485
- try:
486
- # Find model
487
- for category, models in ROFORMER_MODELS.items():
488
- if model_key in models:
489
- model = models[model_key]
490
- break
491
- else:
492
- logger.warning(f"Model {model_key} not found, skipping")
493
- return []
494
-
495
- # Check time budget
496
- elapsed = time.time() - start_time
497
- if elapsed > time_budget:
498
- logger.error(f"Time budget ({time_budget}s) exceeded")
499
- raise TimeoutError("Processing took too long")
500
-
501
- # Initialize separator
502
- model_path = os.path.join(model_dir, model)
503
- if model_key not in model_cache:
504
- logger.info(f"Loading {model_key} into cache")
505
- separator = Separator(
506
- log_level=logging.INFO,
507
- model_file_dir=model_dir,
508
- output_dir=output_dir,
509
- output_format=out_format,
510
- normalization_threshold=norm_thresh,
511
- amplification_threshold=amp_thresh,
512
- use_autocast=use_autocast,
513
- mdxc_params={
514
- "segment_size": seg_size,
515
- "overlap": overlap,
516
- "use_tta": use_tta,
517
- "batch_size": dynamic_batch_size
518
- }
519
- )
520
- separator.load_model(model_filename=model)
521
- model_cache[model_key] = separator
522
- else:
523
- separator = model_cache[model_key]
524
-
525
- # Process with GPU lock
526
- with gpu_lock:
527
- progress(0.3 + (model_idx / total_tasks) * 0.5, desc=f"Separating with {model_key}")
528
- logger.info(f"Separating with {model_key}")
529
- separation = separator.separate(audio)
530
- stems = [os.path.join(output_dir, file_name) for file_name in separation]
531
- result = []
532
- for stem in stems:
533
- if "vocals" in os.path.basename(stem).lower():
534
- model_stems[model_key]["vocals"].append(stem)
535
- elif "other" in os.path.basename(stem).lower() or "instrumental" in os.path.basename(stem).lower():
536
- model_stems[model_key]["other"].append(stem)
537
- result.append(stem)
538
- return result
539
- except Exception as e:
540
- logger.error(f"Error processing {model_key}, attempt {attempt + 1}/{max_retries + 1}: {e}")
541
- if attempt == max_retries:
542
- logger.error(f"Max retries reached for {model_key}, skipping")
543
- return []
544
- time.sleep(1)
545
- finally:
546
- if torch.cuda.is_available():
547
- torch.cuda.empty_cache()
548
- logger.info(f"Cleared CUDA cache after {model_key}")
549
-
550
- # Parallel processing
551
- progress(0.1, desc="Starting model separations...")
552
- with ThreadPoolExecutor(max_workers=min(4, len(model_keys))) as executor:
553
- future_to_task = {executor.submit(process_model, model_key, idx): model_key for idx, model_key in enumerate(model_keys)}
554
- for future in as_completed(future_to_task):
555
- model_key = future_to_task[future]
556
- try:
557
- stems = future.result()
558
- if stems:
559
- logger.info(f"Completed {model_key}")
560
- else:
561
- logger.warning(f"No stems produced for {model_key}")
562
- except Exception as e:
563
- logger.error(f"Task {model_key} failed: {e}")
564
 
565
- # Clear model cache
566
- model_cache.clear()
567
- gc.collect()
568
- if torch.cuda.is_available():
569
- torch.cuda.empty_cache()
570
- logger.info("Cleared model cache and GPU memory")
571
 
572
- # Combine stems
573
- progress(0.8, desc="Combining stems...")
574
- for model_key, stems_dict in model_stems.items():
575
- for stem_type in ["vocals", "other"]:
576
- if stems_dict[stem_type]:
577
- combined_path = os.path.join(output_dir, f"{base_name}_{stem_type}_{model_key.replace(' | ', '_').replace(' ', '_')}.wav")
578
- try:
579
- data, _ = librosa.load(stems_dict[stem_type][0], sr=sr, mono=False)
580
- with sf.SoundFile(combined_path, 'w', sr, channels=2 if data.ndim == 2 else 1) as f:
581
- f.write(data.T if data.ndim == 2 else data)
582
- logger.info(f"Combined {stem_type} for {model_key}: {combined_path}")
583
- if exclude_stems.strip() and stem_type.lower() in [s.strip().lower() for s in exclude_stems.split(',')]:
584
- logger.info(f"Excluding {stem_type} for {model_key}")
585
  continue
586
- all_stems.append(combined_path)
587
- except Exception as e:
588
- logger.error(f"Error combining {stem_type} for {model_key}: {e}")
589
-
590
- all_stems = [stem for stem in all_stems if os.path.exists(stem)]
591
- if not all_stems:
592
- raise ValueError("No valid stems found for ensemble. Try uploading a local WAV file.")
593
-
594
- # Ensemble
595
- weights = [float(w.strip()) for w in weights_str.split(',')] if weights_str.strip() else [1.0] * len(all_stems)
596
- if len(weights) != len(all_stems):
597
- weights = [1.0] * len(all_stems)
598
- logger.info("Weights mismatched, defaulting to 1.0")
599
- output_file = os.path.join(output_dir, f"{base_name}_ensemble_{ensemble_method}.{out_format}")
600
- ensemble_args = [
601
- "--files", *all_stems,
602
- "--type", ensemble_method,
603
- "--weights", *[str(w) for w in weights],
604
- "--output", output_file
605
- ]
606
- progress(0.9, desc="Running ensemble...")
607
- logger.info(f"Running ensemble with args: {ensemble_args}")
608
- try:
609
  result = ensemble_files(ensemble_args)
610
  if result is None or not os.path.exists(output_file):
611
  raise RuntimeError(f"Ensemble failed, output file not created: {output_file}")
612
- logger.info(f"Ensemble completed, output: {output_file}")
613
- progress(1.0, desc="Ensemble completed")
 
 
 
 
 
614
  elapsed = time.time() - start_time
615
- logger.info(f"Total processing time: {elapsed:.2f}s")
616
- # Prepare file list for download
617
- file_list = [output_file] + all_stems
618
- # Create status message with download links
619
  status = f"Ensemble completed with {ensemble_method}, excluded: {exclude_stems if exclude_stems else 'None'}, {len(model_keys)} models in {elapsed:.2f}s<br>Download files:<ul>"
 
620
  for file in file_list:
621
  file_name = os.path.basename(file)
622
  status += f"<li><a href='file={file}' download>{file_name}</a></li>"
623
  status += "</ul>"
624
  return output_file, status, file_list
625
- except Exception as e:
626
- logger.error(f"Ensemble processing error: {e}")
627
- if "numpy" in str(e).lower() or "copy" in str(e).lower():
628
- error_msg = f"NumPy compatibility error: {e}. Try installing numpy<2.0.0 or contact support."
629
- else:
630
- error_msg = f"Ensemble processing error: {e}"
631
- raise RuntimeError(error_msg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
632
  except Exception as e:
633
  logger.error(f"Ensemble error: {e}")
634
- error_msg = f"Processing failed. Try fewer models (max {max_models}), shorter audio, or uploading a local WAV file."
635
  raise RuntimeError(error_msg)
 
636
  finally:
637
  if temp_audio_path and os.path.exists(temp_audio_path):
638
  try:
@@ -643,7 +662,7 @@ def auto_ensemble_process(audio, model_keys, seg_size=64, overlap=0.1, out_forma
643
  if torch.cuda.is_available():
644
  torch.cuda.empty_cache()
645
  logger.info("GPU memory cleared")
646
-
647
  def update_roformer_models(category):
648
  """Update Roformer model dropdown based on selected category."""
649
  choices = list(ROFORMER_MODELS.get(category, {}).keys()) or []
 
437
 
438
  @spaces.GPU
439
  def auto_ensemble_process(audio, model_keys, seg_size=64, overlap=0.1, out_format="wav", use_tta="False", model_dir="/tmp/audio-separator-models/", output_dir="output", norm_thresh=0.9, amp_thresh=0.9, batch_size=1, ensemble_method="avg_wave", exclude_stems="", weights_str="", progress=gr.Progress(track_tqdm=True)):
440
+ global ensemble_state
441
  temp_audio_path = None
442
  max_retries = 2
443
  start_time = time.time()
444
+ time_budget = 300 # ZeroGPU için işlem sınırı
445
  max_models = 6
446
  gpu_lock = Lock()
447
 
 
454
  logger.warning(f"Selected {len(model_keys)} models, limiting to {max_models}.")
455
  model_keys = model_keys[:max_models]
456
 
457
+ # Audio süresine göre dinamik batch size
458
  audio_data, sr = librosa.load(audio, sr=None, mono=False)
459
  duration = librosa.get_duration(y=audio_data, sr=sr)
460
  logger.info(f"Audio duration: {duration:.2f} seconds")
 
467
  scipy.io.wavfile.write(temp_audio_path, sample_rate, data)
468
  audio = temp_audio_path
469
 
470
+ # Aynı ses dosyası mı kontrolü
471
+ if ensemble_state["current_audio"] != audio:
472
+ ensemble_state["current_audio"] = audio
473
+ ensemble_state["current_model_idx"] = 0
474
+ ensemble_state["processed_stems"] = []
475
+ ensemble_state["model_outputs"] = {model_key: {"vocals": [], "other": []} for model_key in model_keys}
476
+ logger.info("New audio file detected, resetting ensemble state.")
477
+
478
  use_tta = use_tta == "True"
 
 
 
479
  base_name = os.path.splitext(os.path.basename(audio))[0]
480
  logger.info(f"Ensemble for {base_name} with {model_keys} on {device}")
481
 
482
+ # Kalıcı bir klasör oluştur
483
+ permanent_output_dir = os.path.join(output_dir, "permanent_stems")
484
+ os.makedirs(permanent_output_dir, exist_ok=True)
485
+
486
  # Model cache
487
  model_cache = {}
488
  all_stems = []
 
489
  total_tasks = len(model_keys)
490
 
491
+ # Şu anki modeli işle
492
+ current_idx = ensemble_state["current_model_idx"]
493
+ if current_idx >= len(model_keys):
494
+ # Tüm modeller işlendiyse ensemble işlemini yap
495
+ logger.info("All models processed, running ensemble...")
496
+ progress(0.9, desc="Running ensemble...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
497
 
498
+ # "Exclude Stems" listesindeki stem'leri belirle
499
+ excluded_stems_list = [s.strip().lower() for s in exclude_stems.split(',')] if exclude_stems.strip() else []
 
 
 
 
500
 
501
+ # Tüm stem’leri topla, ama "Exclude Stems" ile belirtilenleri hariç tut
502
+ for model_key, stems_dict in ensemble_state["model_outputs"].items():
503
+ for stem_type in ["vocals", "other"]:
504
+ if stems_dict[stem_type]:
505
+ # Stem tipini kontrol et, excluded listesinde varsa atla
506
+ if stem_type.lower() in excluded_stems_list:
507
+ logger.info(f"Excluding {stem_type} for {model_key} from ensemble")
 
 
 
 
 
 
508
  continue
509
+ all_stems.extend(stems_dict[stem_type])
510
+
511
+ all_stems = [stem for stem in all_stems if os.path.exists(stem)]
512
+ if not all_stems:
513
+ raise ValueError("No valid stems found for ensemble after excluding specified stems.")
514
+
515
+ # Ensemble işlemi
516
+ weights = [float(w.strip()) for w in weights_str.split(',')] if weights_str.strip() else [1.0] * len(all_stems)
517
+ if len(weights) != len(all_stems):
518
+ weights = [1.0] * len(all_stems)
519
+ logger.info("Weights mismatched, defaulting to 1.0")
520
+ output_file = os.path.join(output_dir, f"{base_name}_ensemble_{ensemble_method}.{out_format}")
521
+ ensemble_args = [
522
+ "--files", *all_stems,
523
+ "--type", ensemble_method,
524
+ "--weights", *[str(w) for w in weights],
525
+ "--output", output_file
526
+ ]
527
+ logger.info(f"Running ensemble with args: {ensemble_args}")
 
 
 
 
528
  result = ensemble_files(ensemble_args)
529
  if result is None or not os.path.exists(output_file):
530
  raise RuntimeError(f"Ensemble failed, output file not created: {output_file}")
531
+
532
+ # Durumu sıfırla
533
+ ensemble_state["current_model_idx"] = 0
534
+ ensemble_state["current_audio"] = None
535
+ ensemble_state["processed_stems"] = []
536
+ ensemble_state["model_outputs"] = {}
537
+
538
  elapsed = time.time() - start_time
539
+ logger.info(f"Ensemble completed, output: {output_file}, took {elapsed:.2f}s")
540
+ progress(1.0, desc="Ensemble completed")
 
 
541
  status = f"Ensemble completed with {ensemble_method}, excluded: {exclude_stems if exclude_stems else 'None'}, {len(model_keys)} models in {elapsed:.2f}s<br>Download files:<ul>"
542
+ file_list = [output_file] + all_stems
543
  for file in file_list:
544
  file_name = os.path.basename(file)
545
  status += f"<li><a href='file={file}' download>{file_name}</a></li>"
546
  status += "</ul>"
547
  return output_file, status, file_list
548
+
549
+ # Şu anki modeli işle
550
+ model_key = model_keys[current_idx]
551
+ logger.info(f"Processing model {current_idx + 1}/{len(model_keys)}: {model_key}")
552
+ progress(0.1, desc=f"Processing model {model_key}...")
553
+
554
+ with torch.no_grad():
555
+ for attempt in range(max_retries + 1):
556
+ try:
557
+ # Modeli bul
558
+ for category, models in ROFORMER_MODELS.items():
559
+ if model_key in models:
560
+ model = models[model_key]
561
+ break
562
+ else:
563
+ logger.warning(f"Model {model_key} not found, skipping")
564
+ ensemble_state["current_model_idx"] += 1
565
+ return None, f"Model {model_key} not found, proceeding to next model.", []
566
+
567
+ # Zaman kontrolü
568
+ elapsed = time.time() - start_time
569
+ if elapsed > time_budget:
570
+ logger.error(f"Time budget ({time_budget}s) exceeded")
571
+ raise TimeoutError("Processing took too long")
572
+
573
+ # Separator oluştur
574
+ if model_key not in model_cache:
575
+ logger.info(f"Loading {model_key} into cache")
576
+ separator = Separator(
577
+ log_level=logging.INFO,
578
+ model_file_dir=model_dir,
579
+ output_dir=output_dir,
580
+ output_format=out_format,
581
+ normalization_threshold=norm_thresh,
582
+ amplification_threshold=amp_thresh,
583
+ use_autocast=use_autocast,
584
+ mdxc_params={
585
+ "segment_size": seg_size,
586
+ "overlap": overlap,
587
+ "use_tta": use_tta,
588
+ "batch_size": dynamic_batch_size
589
+ }
590
+ )
591
+ separator.load_model(model_filename=model)
592
+ model_cache[model_key] = separator
593
+ else:
594
+ separator = model_cache[model_key]
595
+
596
+ # GPU ile işlem
597
+ with gpu_lock:
598
+ progress(0.3, desc=f"Separating with {model_key}")
599
+ logger.info(f"Separating with {model_key}")
600
+ separation = separator.separate(audio)
601
+ stems = [os.path.join(output_dir, file_name) for file_name in separation]
602
+ result = []
603
+
604
+ # Stem’leri kalıcı klasöre taşı
605
+ for stem in stems:
606
+ stem_type = "vocals" if "vocals" in os.path.basename(stem).lower() else "other"
607
+ permanent_stem_path = os.path.join(permanent_output_dir, f"{base_name}_{stem_type}_{model_key.replace(' | ', '_').replace(' ', '_')}.{out_format}")
608
+ shutil.copy(stem, permanent_stem_path)
609
+ ensemble_state["model_outputs"][model_key][stem_type].append(permanent_stem_path)
610
+ if stem_type not in exclude_stems.lower():
611
+ result.append(permanent_stem_path)
612
+
613
+ ensemble_state["processed_stems"].extend(result)
614
+ break
615
+
616
+ except Exception as e:
617
+ logger.error(f"Error processing {model_key}, attempt {attempt + 1}/{max_retries + 1}: {e}")
618
+ if attempt == max_retries:
619
+ logger.error(f"Max retries reached for {model_key}, skipping")
620
+ ensemble_state["current_model_idx"] += 1
621
+ return None, f"Failed to process {model_key} after {max_retries} attempts.", []
622
+ time.sleep(1)
623
+
624
+ finally:
625
+ if torch.cuda.is_available():
626
+ torch.cuda.empty_cache()
627
+ logger.info(f"Cleared CUDA cache after {model_key}")
628
+
629
+ # Model cache temizliği
630
+ model_cache.clear()
631
+ gc.collect()
632
+ if torch.cuda.is_available():
633
+ torch.cuda.empty_cache()
634
+ logger.info("Cleared model cache and GPU memory")
635
+
636
+ # Bir sonraki modele geç
637
+ ensemble_state["current_model_idx"] += 1
638
+ elapsed = time.time() - start_time
639
+ logger.info(f"Model {model_key} completed in {elapsed:.2f}s")
640
+
641
+ # Çıktılar
642
+ file_list = ensemble_state["processed_stems"]
643
+ status = f"Model {model_key} (Model {current_idx + 1}/{len(model_keys)}) completed in {elapsed:.2f}s<br>Click 'Run Ensemble!' to process the next model.<br>Processed stems:<ul>"
644
+ for file in file_list:
645
+ file_name = os.path.basename(file)
646
+ status += f"<li><a href='file={file}' download>{file_name}</a></li>"
647
+ status += "</ul>"
648
+ return file_list[0] if file_list else None, status, file_list
649
+
650
  except Exception as e:
651
  logger.error(f"Ensemble error: {e}")
652
+ error_msg = f"Processing failed: {e}. Try fewer models (max {max_models}) or uploading a local WAV file."
653
  raise RuntimeError(error_msg)
654
+
655
  finally:
656
  if temp_audio_path and os.path.exists(temp_audio_path):
657
  try:
 
662
  if torch.cuda.is_available():
663
  torch.cuda.empty_cache()
664
  logger.info("GPU memory cleared")
665
+
666
  def update_roformer_models(category):
667
  """Update Roformer model dropdown based on selected category."""
668
  choices = list(ROFORMER_MODELS.get(category, {}).keys()) or []