|
|
| import os
|
| import subprocess
|
| import librosa
|
| import soundfile as sf
|
| import numpy as np
|
| from helpers import clamp_percentage, sanitize_filename
|
|
|
_APOLLO_SCRIPT = "/content/Apollo/inference.py"

# Model display name -> (checkpoint path, config path).
_APOLLO_MODELS = {
    "MP3 Enhancer": (
        "/content/Apollo/model/pytorch_model.bin",
        "/content/Apollo/configs/apollo.yaml",
    ),
    "Lew Vocal Enhancer": (
        "/content/Apollo/model/apollo_model.ckpt",
        "/content/Apollo/configs/apollo.yaml",
    ),
    "Lew Vocal Enhancer v2 (beta)": (
        "/content/Apollo/model/apollo_model_v2.ckpt",
        "/content/Apollo/configs/config_apollo_vocal.yaml",
    ),
}

# Used for any model name that is not listed in _APOLLO_MODELS.
_APOLLO_DEFAULT_MODEL = (
    "/content/Apollo/model/apollo_universal_model.ckpt",
    "/content/Apollo/configs/config_apollo.yaml",
)


def _resolve_apollo_model(apollo_method, apollo_normal_model, apollo_midside_model):
    """Return the (checkpoint, config) pair for the model chosen by the active method."""
    if apollo_method == "mid_side_method":
        model_name = apollo_midside_model
    else:
        model_name = apollo_normal_model
    return _APOLLO_MODELS.get(model_name, _APOLLO_DEFAULT_MODEL)


def _run_apollo(label, in_wav, out_wav, chunk_size, overlap, ckpt, config):
    """Run the Apollo inference script once and return the CompletedProcess."""
    command = [
        "python", _APOLLO_SCRIPT,
        "--in_wav", in_wav,
        "--out_wav", out_wav,
        "--chunk_size", str(int(chunk_size)),
        "--overlap", str(int(overlap)),
        "--ckpt", ckpt,
        "--config", config,
    ]
    print(f"Running Apollo {label} command: {' '.join(command)}")
    return subprocess.run(command, capture_output=True, text=True)


def _remove_temp_files(paths):
    """Best-effort removal of intermediate files; failures are only reported."""
    for temp_file in paths:
        try:
            if os.path.exists(temp_file):
                os.remove(temp_file)
        except Exception as e:
            print(f"Could not delete temporary file {temp_file}: {str(e)}")


def _enhance_mid_side(
    input_file,
    enhanced_output,
    output_dir,
    base_name,
    output_format,
    chunk_size,
    overlap,
    ckpt,
    config,
):
    """Enhance one file by running Apollo separately on its mid and side signals.

    Returns True when ``enhanced_output`` was written, False when Apollo failed
    or did not produce its outputs. Intermediate files are always removed.
    """
    audio, sr = librosa.load(input_file, mono=False, sr=None)
    if audio.ndim == 1:
        # Mono input: duplicate the channel so mid/side encoding is defined.
        audio = np.array([audio, audio])

    mid = (audio[0] + audio[1]) * 0.5
    side = (audio[0] - audio[1]) * 0.5

    mid_file = os.path.join(output_dir, f"{base_name}_mid_temp.wav")
    side_file = os.path.join(output_dir, f"{base_name}_side_temp.wav")
    mid_output = os.path.join(output_dir, f"{base_name}_mid_enhanced.{output_format}")
    side_output = os.path.join(output_dir, f"{base_name}_side_enhanced.{output_format}")

    try:
        sf.write(mid_file, mid, sr)
        sf.write(side_file, side, sr)

        result_mid = _run_apollo(
            "Mid", mid_file, mid_output, chunk_size, overlap, ckpt, config
        )
        if result_mid.returncode != 0:
            print(f"Apollo Mid processing failed: {result_mid.stderr}")
            return False

        result_side = _run_apollo(
            "Side", side_file, side_output, chunk_size, overlap, ckpt, config
        )
        if result_side.returncode != 0:
            print(f"Apollo Side processing failed: {result_side.stderr}")
            return False

        if not (os.path.exists(mid_output) and os.path.exists(side_output)):
            print(f"Apollo outputs missing: mid={mid_output}, side={side_output}")
            return False

        mid_audio, _ = librosa.load(mid_output, sr=sr, mono=True)
        side_audio, _ = librosa.load(side_output, sr=sr, mono=True)

        # The two Apollo runs are independent; trim to the shorter result so a
        # small length difference cannot break the element-wise recombination.
        length = min(len(mid_audio), len(side_audio))
        mid_audio = mid_audio[:length]
        side_audio = side_audio[:length]

        # Decode mid/side back to left/right.
        left = mid_audio + side_audio
        right = mid_audio - side_audio
        combined = np.array([left, right])

        target_dir = os.path.dirname(enhanced_output)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        sf.write(enhanced_output, combined.T, sr)
        return True
    finally:
        # Runs on success, on early return and on exceptions, so failed runs
        # no longer leave temp WAVs or partial Apollo outputs behind.
        _remove_temp_files([mid_file, side_file, mid_output, side_output])


def _enhance_normal(input_file, enhanced_output, chunk_size, overlap, ckpt, config):
    """Enhance one file with a single Apollo run; return True on success."""
    apollo_process = _run_apollo(
        "Normal", input_file, enhanced_output, chunk_size, overlap, ckpt, config
    )
    if apollo_process.returncode != 0:
        print(f"Apollo processing failed: {input_file}: {apollo_process.stderr}")
        return False

    if not os.path.exists(enhanced_output):
        print(f"Apollo output missing: {enhanced_output}")
        return False

    return True


def process_with_apollo(
    output_files,
    output_dir,
    apollo_chunk_size,
    apollo_overlap,
    apollo_method,
    apollo_normal_model,
    apollo_midside_model,
    output_format,
    progress=None,
    total_progress_start=80,
    total_progress_end=100
):
    """
    Process audio files with Apollo enhancement.

    Each existing input file is passed to the Apollo inference script, either
    directly (normal method) or as separate mid and side signals that are
    recombined afterwards (mid-side method). A file that cannot be enhanced is
    represented in the result by its original path, so the returned list always
    has one entry per input entry, in the same order.

    Args:
        output_files: List of input audio file paths to process. Empty/None
            entries and paths that do not exist are passed through unchanged.
        output_dir: Directory to store enhanced output files; created if missing.
        apollo_chunk_size: Chunk size for Apollo processing (converted to int).
        apollo_overlap: Overlap for Apollo processing (converted to int).
        apollo_method: Apollo processing method ('normal_method' or
            'mid_side_method'). Any value other than 'mid_side_method' selects
            the normal method.
        apollo_normal_model: Apollo model name for the normal method.
        apollo_midside_model: Apollo model name for the mid-side method.
            Unrecognised model names fall back to the universal model.
        output_format: Output audio format (e.g., 'wav'), used as file extension.
        progress: Optional callable invoked as ``progress(value, desc=...)``
            (e.g. a Gradio progress object). Ignored when not callable.
        total_progress_start: Starting progress percentage (default: 80).
        total_progress_end: Ending progress percentage (default: 100).

    Returns:
        List of enhanced file paths, with the original path substituted for any
        file that failed. If setup fails (for example the checkpoint or config
        file is missing), a copy of ``output_files`` is returned.
    """
    try:
        print(f"Apollo parameters - chunk_size: {apollo_chunk_size}, overlap: {apollo_overlap}, method: {apollo_method}, normal_model: {apollo_normal_model}, midside_model: {apollo_midside_model}")

        use_mid_side = apollo_method == "mid_side_method"
        ckpt, config = _resolve_apollo_model(
            apollo_method, apollo_normal_model, apollo_midside_model
        )

        if not os.path.exists(ckpt):
            raise FileNotFoundError(f"Apollo checkpoint file not found: {ckpt}")
        if not os.path.exists(config):
            raise FileNotFoundError(f"Apollo configuration file not found: {config}")

        # Create the destination before anything (temp files included) is
        # written into it.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        has_progress = callable(progress)
        enhanced_files = []
        total_files = len([f for f in output_files if f and os.path.exists(f)])
        progress_span = total_progress_end - total_progress_start
        progress_per_file = progress_span / total_files if total_files > 0 else progress_span
        enhancement_suffix = "_Mid_Side_Enhanced" if use_mid_side else "_Enhanced"

        # Counts only files that are actually processed, so the "(i/total)"
        # label and the percentage stay consistent with total_files even when
        # output_files contains empty or missing entries.
        file_number = 0

        for output_file in output_files:
            if not (output_file and os.path.exists(output_file)):
                enhanced_files.append(output_file)
                continue

            file_number += 1
            original_file_name = sanitize_filename(
                os.path.splitext(os.path.basename(output_file))[0]
            )
            enhanced_output = os.path.join(
                output_dir,
                f"{original_file_name}{enhancement_suffix}.{output_format}",
            )

            try:
                if has_progress:
                    current_progress = clamp_percentage(
                        total_progress_start + (file_number - 1) * progress_per_file
                    )
                    progress(current_progress, desc=f"Enhancing with Apollo... ({file_number}/{total_files})")
                else:
                    print(f"Progress is not callable or None, skipping Apollo progress update: file {file_number}/{total_files}")

                if use_mid_side:
                    succeeded = _enhance_mid_side(
                        output_file,
                        enhanced_output,
                        output_dir,
                        original_file_name,
                        output_format,
                        apollo_chunk_size,
                        apollo_overlap,
                        ckpt,
                        config,
                    )
                else:
                    succeeded = _enhance_normal(
                        output_file,
                        enhanced_output,
                        apollo_chunk_size,
                        apollo_overlap,
                        ckpt,
                        config,
                    )

                if not succeeded:
                    # Keep the un-enhanced file so downstream steps still have audio.
                    enhanced_files.append(output_file)
                    continue

                enhanced_files.append(enhanced_output)

                if has_progress:
                    current_progress = clamp_percentage(
                        total_progress_start + file_number * progress_per_file
                    )
                    progress(current_progress, desc=f"Enhancing with Apollo... ({file_number}/{total_files})")

            except Exception as e:
                # One bad file must not abort the batch; fall back to the original.
                print(f"Error during Apollo processing: {output_file}: {str(e)}")
                enhanced_files.append(output_file)
                continue

        if has_progress:
            progress(total_progress_end, desc="Apollo enhancement complete")

        return enhanced_files

    except Exception as e:
        print(f"Apollo processing error: {str(e)}")
        return list(output_files)