Spaces:
No application file
No application file
File size: 4,402 Bytes
"""Core audio stem separation logic."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Optional
from stemsplitter.config import Settings, get_settings
logger = logging.getLogger(__name__)
class StemMode(str, Enum):
    """How many stems a separation run is asked to produce.

    Members are also ``str`` instances, so they compare equal to their
    raw string values.
    """

    # Two-way split.
    TWO_STEM = "2stem"

    # Four-way split.
    FOUR_STEM = "4stem"
class OutputFormat(str, Enum):
    """Supported output audio formats.

    Members are also ``str`` instances whose values are the upper-case
    format names. Lookup by value ignores case, so ``OutputFormat("wav")``
    and ``OutputFormat("WAV")`` both yield ``OutputFormat.WAV``; an
    unsupported name still raises ``ValueError``.
    """

    WAV = "WAV"
    MP3 = "MP3"
    FLAC = "FLAC"

    @classmethod
    def _missing_(cls, value: object) -> "OutputFormat | None":
        """Resolve a value that did not match any member exactly.

        Called by ``Enum`` only after the exact-value lookup failed.
        Returning ``None`` lets ``Enum`` raise its usual ``ValueError``.
        """
        if isinstance(value, str):
            wanted = value.upper()
            for member in cls:
                if member.value == wanted:
                    return member
        return None
# Stem label names keyed by separation mode.
STEM_LABELS: dict[StemMode, list[str]] = {
    StemMode.TWO_STEM: [
        "Vocals",
        "Instrumental",
    ],
    StemMode.FOUR_STEM: [
        "Vocals",
        "Drums",
        "Bass",
        "Other",
    ],
}
@dataclass
class SeparationResult:
    """Record describing one completed stem separation."""

    # The input audio file that was separated, as a string path.
    input_file: str

    # Output stem files reported for this run.
    output_files: list[str]

    # Separation mode that was requested.
    mode: StemMode

    # Audio format used for the output stems.
    output_format: OutputFormat

    # Filename of the model that performed the separation.
    model_used: str
class StemSplitter:
    """High-level wrapper around audio-separator's Separator.

    The underlying ``Separator`` is created on first use and then reused.
    The filename of the most recently loaded model is remembered so that
    consecutive separations with the same model and output format do not
    reload it.
    """

    def __init__(self, settings: Optional[Settings] = None) -> None:
        """Store the configuration; the separator itself is created lazily.

        Args:
            settings: Configuration to use. When omitted (or falsy),
                ``get_settings()`` provides it.
        """
        self._settings = settings or get_settings()
        self._separator = None
        self._loaded_model: str | None = None
        # Output format most recently pushed to the separator by separate().
        self._active_format: str | None = None

    def _resolve_log_level(self) -> int:
        """Return the configured log level as a numeric logging level.

        Accepts either a numeric level or a level name in any letter case.

        Raises:
            ValueError: If a level name is not known to ``logging``.
        """
        level = self._settings.log_level
        if not isinstance(level, str):
            return level
        # getLevelName() maps a registered (upper-case) name to its number;
        # for an unknown name it returns a "Level X" string instead.
        resolved = logging.getLevelName(level.strip().upper())
        if not isinstance(resolved, int):
            raise ValueError(f"Unknown log level: {level!r}")
        return resolved

    def _ensure_separator(self) -> None:
        """Lazily create the underlying Separator instance."""
        if self._separator is not None:
            return

        # Function-scope import: audio_separator is only imported once a
        # separator is actually needed.
        from audio_separator.separator import Separator

        self._separator = Separator(
            output_dir=self._settings.output_dir,
            model_file_dir=self._settings.model_file_dir,
            output_format=self._settings.output_format,
            normalization_threshold=self._settings.normalization,
            sample_rate=self._settings.sample_rate,
            log_level=self._resolve_log_level(),
        )

    def _load_model_for_mode(
        self, mode: StemMode, model_override: str | None = None
    ) -> str:
        """Load the appropriate model, returning the model filename used.

        An explicit ``model_override`` wins; otherwise the configured default
        for ``mode`` is chosen. The model is loaded only when it differs from
        the one recorded as currently loaded.
        """
        self._ensure_separator()

        if model_override:
            model_filename = model_override
        elif mode == StemMode.TWO_STEM:
            model_filename = self._settings.default_2stem_model
        else:
            model_filename = self._settings.default_4stem_model

        if self._loaded_model != model_filename:
            logger.info("Loading model: %s", model_filename)
            # Forget the previous model first so that a failed load is not
            # mistaken for a cached one on the next call.
            self._loaded_model = None
            self._separator.load_model(model_filename=model_filename)
            self._loaded_model = model_filename
        return model_filename

    def separate(
        self,
        input_path: str | Path,
        mode: StemMode = StemMode.TWO_STEM,
        output_format: OutputFormat | None = None,
        model_override: str | None = None,
    ) -> SeparationResult:
        """Separate an audio file into stems.

        Args:
            input_path: Path to the input audio file.
            mode: TWO_STEM or FOUR_STEM separation (the enum member or its
                string value).
            output_format: Override the configured output format for this
                call only (the enum member or its string value).
            model_override: Use a specific model filename instead of the
                default for the chosen mode.

        Returns:
            SeparationResult listing the output files reported by the
            separator.

        Raises:
            FileNotFoundError: If input_path does not exist.
            ValueError: If mode or the effective output format is not a
                supported value.
            RuntimeError: If separation fails.
        """
        input_path = Path(input_path)
        if not input_path.is_file():
            raise FileNotFoundError(f"Input file not found: {input_path}")

        # Coerce through the enums so plain string values work as well.
        mode = StemMode(mode)
        fmt = OutputFormat(output_format or self._settings.output_format)

        self._ensure_separator()
        if fmt.value != self._active_format:
            # Always sync the separator to this call's effective format;
            # otherwise an override from an earlier call would silently
            # stick for later calls that do not pass one.
            self._separator.output_format = fmt.value
            self._active_format = fmt.value
            # NOTE(review): audio-separator appears to hand output_format to
            # the model when load_model() runs, so a format change forces a
            # reload here - confirm against the installed version.
            self._loaded_model = None

        model_used = self._load_model_for_mode(mode, model_override)

        logger.info(
            "Separating '%s' (mode=%s, format=%s, model=%s)",
            input_path.name,
            mode.value,
            fmt.value,
            model_used,
        )

        try:
            output_files = self._separator.separate(str(input_path))
        except Exception as exc:
            raise RuntimeError(
                f"Separation failed for '{input_path}': {exc}"
            ) from exc

        return SeparationResult(
            input_file=str(input_path),
            output_files=list(output_files),
            mode=mode,
            output_format=fmt,
            model_used=model_used,
        )
|