"""Core audio stem separation logic.""" from __future__ import annotations import logging from dataclasses import dataclass from enum import Enum from pathlib import Path from typing import Optional from stemsplitter.config import Settings, get_settings logger = logging.getLogger(__name__) class StemMode(str, Enum): """Separation mode.""" TWO_STEM = "2stem" FOUR_STEM = "4stem" class OutputFormat(str, Enum): """Supported output audio formats.""" WAV = "WAV" MP3 = "MP3" FLAC = "FLAC" STEM_LABELS: dict[StemMode, list[str]] = { StemMode.TWO_STEM: ["Vocals", "Instrumental"], StemMode.FOUR_STEM: ["Vocals", "Drums", "Bass", "Other"], } @dataclass class SeparationResult: """Result of a stem separation operation.""" input_file: str output_files: list[str] mode: StemMode output_format: OutputFormat model_used: str class StemSplitter: """High-level wrapper around audio-separator's Separator.""" def __init__(self, settings: Optional[Settings] = None) -> None: self._settings = settings or get_settings() self._separator = None self._loaded_model: str | None = None def _ensure_separator(self) -> None: """Lazily create the underlying Separator instance.""" if self._separator is not None: return from audio_separator.separator import Separator self._separator = Separator( output_dir=self._settings.output_dir, model_file_dir=self._settings.model_file_dir, output_format=self._settings.output_format, normalization_threshold=self._settings.normalization, sample_rate=self._settings.sample_rate, log_level=logging.getLevelName(self._settings.log_level), ) def _load_model_for_mode( self, mode: StemMode, model_override: str | None = None ) -> str: """Load the appropriate model, returning the model filename used.""" self._ensure_separator() if model_override: model_filename = model_override elif mode == StemMode.TWO_STEM: model_filename = self._settings.default_2stem_model else: model_filename = self._settings.default_4stem_model if self._loaded_model != model_filename: logger.info("Loading model: %s", model_filename) self._separator.load_model(model_filename=model_filename) self._loaded_model = model_filename return model_filename def separate( self, input_path: str | Path, mode: StemMode = StemMode.TWO_STEM, output_format: OutputFormat | None = None, model_override: str | None = None, ) -> SeparationResult: """Separate an audio file into stems. Args: input_path: Path to the input audio file. mode: TWO_STEM or FOUR_STEM separation. output_format: Override the configured output format. model_override: Use a specific model filename instead of the default for the chosen mode. Returns: SeparationResult with paths to all output stem files. Raises: FileNotFoundError: If input_path does not exist. RuntimeError: If separation fails. """ input_path = Path(input_path) if not input_path.is_file(): raise FileNotFoundError(f"Input file not found: {input_path}") fmt = output_format or OutputFormat(self._settings.output_format) if output_format: self._ensure_separator() self._separator.output_format = fmt.value model_used = self._load_model_for_mode(mode, model_override) logger.info( "Separating '%s' (mode=%s, format=%s, model=%s)", input_path.name, mode.value, fmt.value, model_used, ) try: output_files = self._separator.separate(str(input_path)) except Exception as exc: raise RuntimeError( f"Separation failed for '{input_path}': {exc}" ) from exc return SeparationResult( input_file=str(input_path), output_files=list(output_files), mode=mode, output_format=fmt, model_used=model_used, )