POC_ASR_v6 / app /services /vocal_separator.py
vyluong's picture
PoC deployment
35b29f2 verified
"""
Vocal Separation Service using MDX-Net (via audio-separator).
Isolates vocals from audio files using state-of-the-art MDX-Net models.
"""
import asyncio
import logging
import os
import threading
from pathlib import Path
from typing import Optional

from app.core.config import get_settings
# Logger named after this module (e.g. "app.services.vocal_separator").
logger: logging.Logger = logging.getLogger(name=__name__)
# Application settings, resolved once at import time and read by the service below.
settings = get_settings()
class VocalSeparationError(Exception):
    """Raised when vocal separation cannot produce a result (e.g. no output files were written)."""
class VocalSeparator:
    """
    Service for separating vocals from audio using MDX-Net.

    Uses the audio-separator library, which supports UVR models. One
    ``Separator`` instance is cached on the class and shared by all callers.
    It is (re)built lazily whenever ``settings.mdx_model`` differs from the
    model that was last loaded.
    """

    # Cached audio-separator ``Separator`` (None until first use).
    _separator = None
    # Name of the model currently loaded into ``_separator``; used to detect config changes.
    _model_name: Optional[str] = None
    # Separations run on executor threads and share the single cached separator.
    # Serialize lazy loading and ``separate()`` calls: the separator is presumably
    # not safe for concurrent use, and racing first calls would load the model twice.
    # Reentrant so ``_get_separator`` can also be called while the lock is held.
    _lock = threading.RLock()
    # Stem tag embedded in output names, e.g. "song_(Vocals)_model.wav" (matched lower-cased).
    _VOCALS_TAG = "(vocals)"

    @classmethod
    def _get_separator(cls):
        """
        Return the cached Separator, loading ``settings.mdx_model`` if needed.

        The separator writes WAV output into ``settings.processed_dir``, which is
        created if missing.
        """
        with cls._lock:
            if cls._separator is None or cls._model_name != settings.mdx_model:
                # Deferred import: heavy dependency, only needed once separation runs.
                from audio_separator.separator import Separator

                logger.debug("Initializing MDX-Net separator with model: %s", settings.mdx_model)
                # audio-separator expects output_dir to exist.
                settings.processed_dir.mkdir(parents=True, exist_ok=True)
                separator = Separator(
                    output_dir=str(settings.processed_dir),
                    output_format="WAV",
                    normalization_threshold=0.9,
                )
                separator.load_model(settings.mdx_model)
                # Cache only after load_model succeeds, so a failed load is retried next call.
                cls._separator = separator
                cls._model_name = settings.mdx_model
                logger.debug("MDX-Net model loaded on %s", settings.resolved_device)
            return cls._separator

    @classmethod
    async def separate_vocals(cls, input_path: Path) -> Path:
        """
        Separate vocals from an audio file using MDX-Net.

        Best-effort: when separation is disabled in settings, or fails for any
        reason, the original input is returned so the caller can continue with
        the unprocessed audio.

        Args:
            input_path: Path to input audio file (a ``str`` path is also accepted).

        Returns:
            Path to the separated vocals WAV file, or ``input_path`` when
            separation is disabled or fails.
        """
        if not settings.enable_vocal_separation:
            logger.debug("Vocal separation disabled, skipping...")
            return input_path

        input_path = Path(input_path)
        logger.debug("Starting vocal separation for: %s", input_path.name)
        try:
            # Separation is blocking; run it in the default executor so the
            # event loop stays responsive.
            loop = asyncio.get_running_loop()
            vocals_path = await loop.run_in_executor(None, cls._run_separation, input_path)
        except Exception as e:
            # Deliberate fallback, but keep the traceback for diagnosis.
            logger.exception("Vocal separation failed: %s", e)
            logger.warning("Falling back to original audio.")
            return input_path

        logger.info("Vocal separation complete: %s", vocals_path.name)
        return vocals_path

    @classmethod
    def _run_separation(cls, input_path: Path) -> Path:
        """
        Run the separation synchronously and return the path of the vocals stem.

        Raises:
            VocalSeparationError: if the separator produced no output files.
        """
        with cls._lock:
            separator = cls._get_separator()
            # separate() returns the names of the stem files it wrote
            # (typically a Vocals and an Instrumental stem).
            output_files = separator.separate(str(input_path))

        if not output_files:
            raise VocalSeparationError("No output files generated by separator.")

        vocals_name = cls._find_vocals_stem(output_files)
        if vocals_name is None:
            # Best-effort: use the first output. NOTE(review): this may be a
            # non-vocal stem if the model uses an unexpected naming scheme.
            logger.warning("Could not identify vocals stem in output files.")
            vocals_name = output_files[0]
        # Joining with an absolute path yields that path unchanged, so this works
        # whether the separator returns bare file names or full paths.
        return settings.processed_dir / vocals_name

    @classmethod
    def _find_vocals_stem(cls, output_files) -> Optional[str]:
        """
        Pick the vocals stem from the separator's output file names.

        Prefers a base name carrying the "(Vocals)" stem tag. This way an input
        whose own name contains "Vocals" cannot cause another stem to be picked.
        Falls back to any base name containing "vocals". Matching is
        case-insensitive. Returns the matching entry unchanged, or None.
        """
        candidates = [(name, Path(name).name.lower()) for name in output_files]
        for name, base in candidates:
            if cls._VOCALS_TAG in base:
                return name
        for name, base in candidates:
            if "vocals" in base:
                return name
        return None