Spaces:
No application file
No application file
| """Core audio stem separation logic.""" | |
| from __future__ import annotations | |
| import logging | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from pathlib import Path | |
| from typing import Optional | |
| from stemsplitter.config import Settings, get_settings | |
# Module logger, named after this module so its verbosity can be tuned
# through the standard logging hierarchy.
logger: logging.Logger = logging.getLogger(name=__name__)
class StemMode(str, Enum):
    """Separation mode: which stem split to perform.

    Members subclass ``str``, so each compares equal to its raw value
    (for example ``StemMode.TWO_STEM == "2stem"``), and ``StemMode(value)``
    looks a member up from that raw value.
    """

    # Split into two stems.
    TWO_STEM = "2stem"
    # Split into four stems.
    FOUR_STEM = "4stem"
class OutputFormat(str, Enum):
    """Audio container formats the splitter can write.

    Values are upper-case format names; members subclass ``str`` and
    compare equal to those values.
    """

    WAV = "WAV"  # uncompressed
    MP3 = "MP3"  # lossy
    FLAC = "FLAC"  # lossless compressed
# Display labels for the stems that belong to each separation mode.
STEM_LABELS: dict[StemMode, list[str]] = {
    StemMode.TWO_STEM: [
        "Vocals",
        "Instrumental",
    ],
    StemMode.FOUR_STEM: [
        "Vocals",
        "Drums",
        "Bass",
        "Other",
    ],
}
@dataclass
class SeparationResult:
    """Result of a stem separation operation.

    Declared as a dataclass so it can be constructed with keyword (or
    positional) arguments and gets value-based ``__eq__`` and a readable
    ``__repr__``.
    """

    # Path of the audio file that was separated, as a string.
    input_file: str
    # Files reported by the separator for this run.
    output_files: list[str]
    # Separation mode that was requested.
    mode: StemMode
    # Audio format the stems were written in.
    output_format: OutputFormat
    # Filename of the model that performed the separation.
    model_used: str
class StemSplitter:
    """High-level wrapper around audio-separator's ``Separator``.

    The underlying ``Separator`` is created lazily on first use and then
    reused. The splitter remembers which model was last loaded and with
    which output format, so consecutive calls that need the same
    combination do not reload the model.
    """

    def __init__(self, settings: Optional[Settings] = None) -> None:
        """Create a splitter.

        Args:
            settings: Configuration to use. When omitted (or falsy), the
                result of ``get_settings()`` is used instead.
        """
        self._settings = settings or get_settings()
        # Created on demand by _ensure_separator(), so constructing a
        # StemSplitter does not import audio_separator.
        self._separator = None
        # Model filename and output format in effect at the last
        # load_model() call; used to decide whether a reload is needed.
        self._loaded_model: str | None = None
        self._loaded_format: str | None = None

    def _ensure_separator(self) -> None:
        """Lazily create the underlying ``Separator`` instance from settings."""
        if self._separator is not None:
            return
        # Imported here rather than at module level, so the dependency is
        # only loaded when a Separator is first needed.
        from audio_separator.separator import Separator

        self._separator = Separator(
            output_dir=self._settings.output_dir,
            model_file_dir=self._settings.model_file_dir,
            output_format=self._settings.output_format,
            normalization_threshold=self._settings.normalization,
            sample_rate=self._settings.sample_rate,
            # NOTE(review): getLevelName maps a level *name* (e.g. "INFO")
            # to its number; this assumes settings.log_level is a name.
            log_level=logging.getLevelName(self._settings.log_level),
        )

    def _load_model_for_mode(
        self, mode: StemMode, model_override: str | None = None
    ) -> str:
        """Load the model for *mode*, returning the model filename used.

        A truthy ``model_override`` takes precedence over the per-mode
        defaults from settings. The model is (re)loaded only when either the
        filename or the separator's current ``output_format`` differs from
        the previous load.
        """
        self._ensure_separator()
        if model_override:
            model_filename = model_override
        elif mode == StemMode.TWO_STEM:
            model_filename = self._settings.default_2stem_model
        else:
            model_filename = self._settings.default_4stem_model

        # audio-separator copies Separator.output_format into the model
        # instance inside load_model(), so a format change on an already
        # loaded model only takes effect after reloading it.
        current_format = self._separator.output_format
        needs_load = (
            self._loaded_model != model_filename
            or self._loaded_format != current_format
        )
        if needs_load:
            logger.info("Loading model: %s", model_filename)
            self._separator.load_model(model_filename=model_filename)
            self._loaded_model = model_filename
            self._loaded_format = current_format
        return model_filename

    def separate(
        self,
        input_path: str | Path,
        mode: StemMode = StemMode.TWO_STEM,
        output_format: OutputFormat | None = None,
        model_override: str | None = None,
    ) -> SeparationResult:
        """Separate an audio file into stems.

        Args:
            input_path: Path to the input audio file.
            mode: TWO_STEM or FOUR_STEM separation. The raw values
                ("2stem" / "4stem") are accepted as well.
            output_format: Override the configured output format for this
                call only. Raw values ("WAV", "MP3", "FLAC") are accepted
                as well.
            model_override: Use a specific model filename instead of the
                default for the chosen mode.

        Returns:
            SeparationResult holding the output files reported by the
            separator.

        Raises:
            FileNotFoundError: If input_path does not exist or is not a
                regular file.
            ValueError: If mode, or the resolved output format, is not a
                recognized value.
            RuntimeError: If separation fails.
        """
        input_path = Path(input_path)
        if not input_path.is_file():
            raise FileNotFoundError(f"Input file not found: {input_path}")

        # Normalise to enum members so plain strings work as well
        # (enum members pass through unchanged).
        mode = StemMode(mode)
        fmt = OutputFormat(output_format or self._settings.output_format)

        # Apply the format on every call, not only when overridden, so an
        # override from a previous call does not stick to later ones.
        self._ensure_separator()
        self._separator.output_format = fmt.value
        model_used = self._load_model_for_mode(mode, model_override)

        logger.info(
            "Separating '%s' (mode=%s, format=%s, model=%s)",
            input_path.name,
            mode.value,
            fmt.value,
            model_used,
        )
        try:
            output_files = self._separator.separate(str(input_path))
        except Exception as exc:
            # Boundary with a third-party library: surface any failure as
            # the documented RuntimeError, keeping the original as cause.
            raise RuntimeError(
                f"Separation failed for '{input_path}': {exc}"
            ) from exc

        return SeparationResult(
            input_file=str(input_path),
            output_files=list(output_files),
            mode=mode,
            output_format=fmt,
            model_used=model_used,
        )