ymcnabb's picture
Upload folder using huggingface_hub
1824ea0 verified
"""Core audio stem separation logic."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Optional
from stemsplitter.config import Settings, get_settings
# Module-scoped logger; handlers and levels are left to the application.
logger = logging.getLogger(name=__name__)
class StemMode(str, Enum):
    """How many stems a separation produces.

    Members subclass ``str``, so each compares equal to its raw value
    (``StemMode.TWO_STEM == "2stem"``).
    """

    # Vocals plus an instrumental remainder.
    TWO_STEM = "2stem"
    # Vocals, drums, bass and everything else.
    FOUR_STEM = "4stem"
class OutputFormat(str, Enum):
    """Supported output audio formats.

    Lookup by value is case- and whitespace-insensitive, so ``"wav"`` and
    ``" Flac "`` resolve to ``WAV`` and ``FLAC``. Unknown formats still
    raise ``ValueError``.
    """

    WAV = "WAV"
    MP3 = "MP3"
    FLAC = "FLAC"

    @classmethod
    def _missing_(cls, value: object) -> "OutputFormat | None":
        """Resolve non-canonical spellings of a format name.

        Called by ``Enum`` only after the exact-value lookup fails.
        Returning ``None`` makes ``Enum`` raise its usual ``ValueError``.
        """
        if not isinstance(value, str):
            return None
        wanted = value.strip().upper()
        for member in cls:
            if member.value == wanted:
                return member
        return None
# Human-readable stem names for each separation mode, in display order.
STEM_LABELS: dict[StemMode, list[str]] = dict(
    [
        (StemMode.TWO_STEM, ["Vocals", "Instrumental"]),
        (StemMode.FOUR_STEM, ["Vocals", "Drums", "Bass", "Other"]),
    ]
)
@dataclass
class SeparationResult:
    """Outcome of a single stem-separation run.

    Attributes:
        input_file: The audio file that was split, as a string path.
        output_files: Stem files reported by the separation backend.
        mode: Separation mode that was requested.
        output_format: Output format requested for the stems.
        model_used: Filename of the model that performed the separation.
    """

    input_file: str
    output_files: list[str]
    mode: StemMode
    output_format: OutputFormat
    model_used: str
class StemSplitter:
    """High-level wrapper around audio-separator's ``Separator``.

    The ``Separator`` is created lazily on first use. The model filename and
    output format it was last loaded with are remembered, so consecutive
    calls using the same pair skip reloading the model.
    """

    def __init__(self, settings: Optional[Settings] = None) -> None:
        """Create a splitter.

        Args:
            settings: Configuration to use. When ``None``, the result of
                ``get_settings()`` is used.
        """
        self._settings = settings if settings is not None else get_settings()
        self._separator = None
        # Model filename and output format the separator's current model was
        # loaded with. A change in either forces a reload.
        self._loaded_model: str | None = None
        self._loaded_format: str | None = None

    @staticmethod
    def _log_level_as_int(level: int | str) -> int:
        """Return ``level`` as a numeric logging level.

        ``logging.getLevelName`` maps names to numbers but numbers to names,
        and yields ``"Level <x>"`` for unknown names. It is therefore only
        applied to (upper-cased) strings, and its result is validated.

        Raises:
            ValueError: If ``level`` is not a known level name.
        """
        if isinstance(level, int):
            return level
        numeric = logging.getLevelName(str(level).strip().upper())
        if not isinstance(numeric, int):
            raise ValueError(f"Unknown log level: {level!r}")
        return numeric

    def _ensure_separator(self) -> None:
        """Create the underlying ``Separator`` on first use; no-op afterwards."""
        if self._separator is not None:
            return
        # Deferred import: importing this module does not load audio_separator.
        from audio_separator.separator import Separator

        self._separator = Separator(
            output_dir=self._settings.output_dir,
            model_file_dir=self._settings.model_file_dir,
            output_format=self._settings.output_format,
            normalization_threshold=self._settings.normalization,
            sample_rate=self._settings.sample_rate,
            # Separator takes a numeric level (its default is logging.INFO).
            log_level=self._log_level_as_int(self._settings.log_level),
        )

    def _load_model_for_mode(
        self, mode: StemMode, model_override: str | None = None
    ) -> str:
        """Make sure the right model is loaded and return its filename.

        The model is chosen by ``model_override`` if given, otherwise by the
        configured default for ``mode``. It is (re)loaded when either the
        filename or the separator's current ``output_format`` differs from
        what the model was last loaded with.
        """
        self._ensure_separator()
        if model_override:
            model_filename = model_override
        elif mode == StemMode.TWO_STEM:
            model_filename = self._settings.default_2stem_model
        else:
            model_filename = self._settings.default_4stem_model

        # audio-separator hands output_format to the model instance when
        # load_model() runs, so a new format only applies after a reload.
        wanted = (model_filename, self._separator.output_format)
        if (self._loaded_model, self._loaded_format) != wanted:
            logger.info("Loading model: %s", model_filename)
            # Invalidate first so a failed load is retried on the next call.
            self._loaded_model = self._loaded_format = None
            self._separator.load_model(model_filename=model_filename)
            self._loaded_model, self._loaded_format = wanted
        return model_filename

    def separate(
        self,
        input_path: str | Path,
        mode: StemMode | str = StemMode.TWO_STEM,
        output_format: OutputFormat | str | None = None,
        model_override: str | None = None,
    ) -> SeparationResult:
        """Separate an audio file into stems.

        Args:
            input_path: Path to the input audio file.
            mode: TWO_STEM or FOUR_STEM separation. The raw values
                ``"2stem"`` / ``"4stem"`` are accepted too.
            output_format: Override the configured output format for this
                call only. May be an ``OutputFormat`` member or its string
                value.
            model_override: Use a specific model filename instead of the
                default for the chosen mode.

        Returns:
            SeparationResult listing the stem files reported by the backend.
            NOTE(review): audio-separator may report these relative to the
            configured output directory. Confirm before opening them directly.

        Raises:
            FileNotFoundError: If input_path does not exist.
            ValueError: If mode or the resolved output format is not a
                recognised value.
            RuntimeError: If separation fails.
        """
        input_path = Path(input_path)
        if not input_path.is_file():
            raise FileNotFoundError(f"Input file not found: {input_path}")

        # Coerce so plain strings work as well as enum members.
        mode = StemMode(mode)
        fmt = (
            OutputFormat(output_format)
            if output_format
            else OutputFormat(self._settings.output_format)
        )

        self._ensure_separator()
        # Set on every call, not only when overridden. The attribute persists
        # on the shared Separator, so a previous override would otherwise leak
        # into later calls that expect the configured default.
        self._separator.output_format = fmt.value
        model_used = self._load_model_for_mode(mode, model_override)

        logger.info(
            "Separating '%s' (mode=%s, format=%s, model=%s)",
            input_path.name,
            mode.value,
            fmt.value,
            model_used,
        )
        try:
            output_files = self._separator.separate(str(input_path))
        except Exception as exc:
            raise RuntimeError(
                f"Separation failed for '{input_path}': {exc}"
            ) from exc

        return SeparationResult(
            input_file=str(input_path),
            output_files=list(output_files),
            mode=mode,
            output_format=fmt,
            model_used=model_used,
        )