ASesYusuf1 committed on
Commit
3047c6d
·
verified ·
1 Parent(s): 50d1ae7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -42
app.py CHANGED
@@ -348,7 +348,7 @@ def download_audio(url, cookie_file=None):
348
  output_path = 'ytdl/gdrive_audio.wav'
349
  audio = AudioSegment.from_file(temp_output_path)
350
  audio.export(output_path, format="wav")
351
- sample_rate, data = scipy.io.wavfile.read(output_path) # Fixed: Use scipy.io.wavfile.read
352
  return output_path, "Download successful", (sample_rate, data)
353
  else:
354
  os.makedirs('ytdl', exist_ok=True)
@@ -360,7 +360,7 @@ def download_audio(url, cookie_file=None):
360
  file_path = file_path.replace(ext, '.wav')
361
  if not os.path.exists(file_path):
362
  return None, "Downloaded file not found", None
363
- sample_rate, data = scipy.io.wavfile.read(file_path) # Fixed: Use scipy.io.wavfile.read
364
  return file_path, "Download successful", (sample_rate, data)
365
  except yt_dlp.utils.ExtractorError as e:
366
  if "Sign in to confirm you’re not a bot" in str(e):
@@ -430,9 +430,10 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
430
  logger.info("GPU memory cleared")
431
 
432
  @spaces.GPU
433
- def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_format="wav", use_tta="False", model_dir="/tmp/audio-separator-models/", output_dir="output", norm_thresh=0.9, amp_thresh=0.9, batch_size=1, ensemble_method="avg_wave", exclude_stems="", weights_str=""):
434
  temp_audio_path = None
435
  chunk_paths = []
 
436
  try:
437
  if not audio:
438
  raise ValueError("No audio file provided.")
@@ -464,13 +465,15 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
464
  chunks = [audio]
465
  use_tta = use_tta == "True"
466
  if os.path.exists(output_dir):
467
- shutil.rmtree(output_dir)
 
468
  os.makedirs(output_dir, exist_ok=True)
469
  base_name = os.path.splitext(os.path.basename(audio))[0]
470
  logger.info(f"Ensemble for {base_name} with {model_keys} on {device}")
471
  all_stems = []
472
  model_stems = {}
473
- for model_key in model_keys:
 
474
  model_stems[model_key] = {"vocals": [], "other": []}
475
  for category, models in ROFORMER_MODELS.items():
476
  if model_key in models:
@@ -480,44 +483,62 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
480
  logger.warning(f"Model {model_key} not found, skipping")
481
  continue
482
  for chunk_idx, chunk_path in enumerate(chunks):
483
- separator = Separator(
484
- log_level=logging.INFO,
485
- model_file_dir=model_dir,
486
- output_dir=output_dir,
487
- output_format=out_format,
488
- normalization_threshold=norm_thresh,
489
- amplification_threshold=amp_thresh,
490
- use_autocast=use_autocast,
491
- mdxc_params={"segment_size": seg_size, "overlap": overlap, "use_tta": use_tta, "batch_size": batch_size}
492
- )
493
- logger.info(f"Loading {model_key} for chunk {chunk_idx}")
494
- separator.load_model(model_filename=model)
495
- logger.info(f"Separating chunk {chunk_idx} with {model_key}")
496
- separation = separator.separate(chunk_path)
497
- stems = [os.path.join(output_dir, file_name) for file_name in separation]
498
- for stem in stems:
499
- if "vocals" in os.path.basename(stem).lower():
500
- model_stems[model_key]["vocals"].append(stem)
501
- elif "other" in os.path.basename(stem).lower():
502
- model_stems[model_key]["other"].append(stem)
503
- separator = None
504
- gc.collect()
505
- if torch.cuda.is_available():
506
- torch.cuda.empty_cache()
507
- logger.info(f"Cleared CUDA cache after {model_key} chunk {chunk_idx}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
  for model_key, stems_dict in model_stems.items():
509
  for stem_type in ["vocals", "other"]:
510
  if stems_dict[stem_type]:
511
  combined_path = os.path.join(output_dir, f"{base_name}_{stem_type}_{model_key.replace(' | ', '_').replace(' ', '_')}.wav")
512
- with sf.SoundFile(combined_path, 'w', sr, channels=2 if audio_data.ndim == 2 else 1) as f:
513
- for stem_path in stems_dict[stem_type]:
514
- data, _ = librosa.load(stem_path, sr=sr, mono=False)
515
- f.write(data.T if data.ndim == 2 else data)
516
- logger.info(f"Combined {stem_type} for {model_key}: {combined_path}")
517
- if exclude_stems.strip() and stem_type.lower() in [s.strip().lower() for s in exclude_stems.split(',')]:
518
- logger.info(f"Excluding {stem_type} for {model_key}")
519
- continue
520
- all_stems.append(combined_path)
 
 
 
521
  all_stems = [stem for stem in all_stems if os.path.exists(stem)]
522
  if not all_stems:
523
  raise ValueError("No valid stems found for ensemble.")
@@ -532,10 +553,16 @@ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_form
532
  "--weights", *[str(w) for w in weights],
533
  "--output", output_file
534
  ]
 
535
  logger.info(f"Running ensemble with args: {ensemble_args}")
536
- ensemble_files(ensemble_args)
537
- logger.info("Ensemble completed")
538
- return output_file, f"Ensemble completed with {ensemble_method}, excluded: {exclude_stems if exclude_stems else 'None'}"
 
 
 
 
 
539
  except Exception as e:
540
  logger.error(f"Ensemble error: {e}")
541
  raise RuntimeError(f"Ensemble error: {e}")
 
348
  output_path = 'ytdl/gdrive_audio.wav'
349
  audio = AudioSegment.from_file(temp_output_path)
350
  audio.export(output_path, format="wav")
351
+ sample_rate, data = scipy.io.wavfile.read(output_path)
352
  return output_path, "Download successful", (sample_rate, data)
353
  else:
354
  os.makedirs('ytdl', exist_ok=True)
 
360
  file_path = file_path.replace(ext, '.wav')
361
  if not os.path.exists(file_path):
362
  return None, "Downloaded file not found", None
363
+ sample_rate, data = scipy.io.wavfile.read(file_path)
364
  return file_path, "Download successful", (sample_rate, data)
365
  except yt_dlp.utils.ExtractorError as e:
366
  if "Sign in to confirm you’re not a bot" in str(e):
 
430
  logger.info("GPU memory cleared")
431
 
432
  @spaces.GPU
433
+ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_format="wav", use_tta="False", model_dir="/tmp/audio-separator-models/", output_dir="output", norm_thresh=0.9, amp_thresh=0.9, batch_size=1, ensemble_method="avg_wave", exclude_stems="", weights_str="", progress=gr.Progress(track_tqdm=True)):
434
  temp_audio_path = None
435
  chunk_paths = []
436
+ max_retries = 2 # Retry attempts for ZeroGPU session issues
437
  try:
438
  if not audio:
439
  raise ValueError("No audio file provided.")
 
465
  chunks = [audio]
466
  use_tta = use_tta == "True"
467
  if os.path.exists(output_dir):
468
+ shutil.rmtree(output_dir)
469
+ shutil.copyfile(audio, os.path.join(output_dir, os.path.basename(audio)))
470
  os.makedirs(output_dir, exist_ok=True)
471
  base_name = os.path.splitext(os.path.basename(audio))[0]
472
  logger.info(f"Ensemble for {base_name} with {model_keys} on {device}")
473
  all_stems = []
474
  model_stems = {}
475
+ total_models = len(model_keys)
476
+ for model_idx, model_key in enumerate(model_keys):
477
  model_stems[model_key] = {"vocals": [], "other": []}
478
  for category, models in ROFORMER_MODELS.items():
479
  if model_key in models:
 
483
  logger.warning(f"Model {model_key} not found, skipping")
484
  continue
485
  for chunk_idx, chunk_path in enumerate(chunks):
486
+ retry_count = 0
487
+ while retry_count <= max_retries:
488
+ try:
489
+ progress((model_idx + 0.1) / total_models, desc=f"Loading {model_key} for chunk {chunk_idx}")
490
+ separator = Separator(
491
+ log_level=logging.INFO,
492
+ model_file_dir=model_dir,
493
+ output_dir=output_dir,
494
+ output_format=out_format,
495
+ normalization_threshold=norm_thresh,
496
+ amplification_threshold=amp_thresh,
497
+ use_autocast=use_autocast,
498
+ mdxc_params={"segment_size": seg_size, "overlap": overlap, "use_tta": use_tta, "batch_size": batch_size}
499
+ )
500
+ logger.info(f"Loading {model_key} for chunk {chunk_idx}")
501
+ separator.load_model(model_filename=model)
502
+ progress((model_idx + 0.5) / total_models, desc=f"Separating chunk {chunk_idx} with {model_key}")
503
+ logger.info(f"Separating chunk {chunk_idx} with {model_key}")
504
+ separation = separator.separate(chunk_path)
505
+ stems = [os.path.join(output_dir, file_name) for file_name in separation]
506
+ for stem in stems:
507
+ if "vocals" in os.path.basename(stem).lower():
508
+ model_stems[model_key]["vocals"].append(stem)
509
+ elif "other" in os.path.basename(stem).lower() or "instrumental" in os.path.basename(stem).lower():
510
+ model_stems[model_key]["other"].append(stem)
511
+ break # Success, exit retry loop
512
+ except Exception as e:
513
+ retry_count += 1
514
+ logger.error(f"Error processing {model_key} chunk {chunk_idx}, attempt {retry_count}/{max_retries}: {e}")
515
+ if retry_count > max_retries:
516
+ logger.error(f"Max retries reached for {model_key} chunk {chunk_idx}, skipping")
517
+ break
518
+ time.sleep(2) # Wait before retrying
519
+ finally:
520
+ separator = None
521
+ gc.collect()
522
+ if torch.cuda.is_available():
523
+ torch.cuda.empty_cache()
524
+ logger.info(f"Cleared CUDA cache after {model_key} chunk {chunk_idx}")
525
+ progress(0.8, desc="Combining stems...")
526
  for model_key, stems_dict in model_stems.items():
527
  for stem_type in ["vocals", "other"]:
528
  if stems_dict[stem_type]:
529
  combined_path = os.path.join(output_dir, f"{base_name}_{stem_type}_{model_key.replace(' | ', '_').replace(' ', '_')}.wav")
530
+ try:
531
+ with sf.SoundFile(combined_path, 'w', sr, channels=2 if audio_data.ndim == 2 else 1) as f:
532
+ for stem_path in stems_dict[stem_type]:
533
+ data, _ = librosa.load(stem_path, sr=sr, mono=False)
534
+ f.write(data.T if data.ndim == 2 else data)
535
+ logger.info(f"Combined {stem_type} for {model_key}: {combined_path}")
536
+ if exclude_stems.strip() and stem_type.lower() in [s.strip().lower() for s in exclude_stems.split(',')]:
537
+ logger.info(f"Excluding {stem_type} for {model_key}")
538
+ continue
539
+ all_stems.append(combined_path)
540
+ except Exception as e:
541
+ logger.error(f"Error combining {stem_type} for {model_key}: {e}")
542
  all_stems = [stem for stem in all_stems if os.path.exists(stem)]
543
  if not all_stems:
544
  raise ValueError("No valid stems found for ensemble.")
 
553
  "--weights", *[str(w) for w in weights],
554
  "--output", output_file
555
  ]
556
+ progress(0.9, desc="Running ensemble...")
557
  logger.info(f"Running ensemble with args: {ensemble_args}")
558
+ try:
559
+ ensemble_files(ensemble_args)
560
+ logger.info("Ensemble completed")
561
+ progress(1.0, desc="Ensemble completed")
562
+ return output_file, f"Ensemble completed with {ensemble_method}, excluded: {exclude_stems if exclude_stems else 'None'}"
563
+ except Exception as e:
564
+ logger.error(f"Ensemble processing error: {e}")
565
+ raise RuntimeError(f"Ensemble processing error: {e}")
566
  except Exception as e:
567
  logger.error(f"Ensemble error: {e}")
568
  raise RuntimeError(f"Ensemble error: {e}")