File size: 4,402 Bytes
1824ea0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""Core audio stem separation logic."""

from __future__ import annotations

import logging
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Optional

from stemsplitter.config import Settings, get_settings

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class StemMode(str, Enum):
    """Separation mode.

    Members are ``str`` subclasses, so each one compares equal to its short
    identifier value (``"2stem"`` / ``"4stem"``).
    """

    # Paired with the two labels "Vocals" / "Instrumental" in STEM_LABELS.
    TWO_STEM = "2stem"
    # Paired with the four labels "Vocals" / "Drums" / "Bass" / "Other" in
    # STEM_LABELS.
    FOUR_STEM = "4stem"


class OutputFormat(str, Enum):
    """Supported output audio formats.

    Member values are the upper-case format names. Lookup by value ignores
    letter case and surrounding whitespace, so ``OutputFormat("flac")`` and
    ``OutputFormat(" Mp3 ")`` resolve to ``FLAC`` and ``MP3``; any other
    value still raises ``ValueError``.
    """

    WAV = "WAV"
    MP3 = "MP3"
    FLAC = "FLAC"

    @classmethod
    def _missing_(cls, value: object) -> "OutputFormat | None":
        """Resolve values that differ from a member only by case or padding.

        Called by ``Enum`` when an exact value lookup fails. Returning
        ``None`` lets ``Enum`` raise its usual ``ValueError``.
        """
        if isinstance(value, str):
            normalized = value.strip().upper()
            for member in cls:
                if member.value == normalized:
                    return member
        return None


# Stem label names associated with each separation mode: two labels for
# TWO_STEM, four for FOUR_STEM.
STEM_LABELS: dict[StemMode, list[str]] = {
    StemMode.TWO_STEM: ["Vocals", "Instrumental"],
    StemMode.FOUR_STEM: ["Vocals", "Drums", "Bass", "Other"],
}


@dataclass
class SeparationResult:
    """Result of a stem separation operation.

    Attributes:
        input_file: Path of the audio file that was separated, as a string.
        output_files: Output stem file entries as returned by the separator.
        mode: Separation mode that was requested.
        output_format: Output format reported for this run.
        model_used: Filename of the model that was used for the separation.
    """

    input_file: str
    output_files: list[str]
    mode: StemMode
    output_format: OutputFormat
    model_used: str


class StemSplitter:
    """High-level wrapper around audio-separator's Separator.

    The underlying ``Separator`` is created on first use, and the most
    recently loaded model is remembered so that repeated calls with the same
    model and output format do not load it again.
    """

    def __init__(self, settings: Optional[Settings] = None) -> None:
        """Store the configuration; no separator or model is created yet.

        Args:
            settings: Configuration to use. Falls back to ``get_settings()``
                when omitted.
        """
        self._settings = settings or get_settings()
        self._separator = None
        self._loaded_model: str | None = None
        # Output format currently applied to the separator; None until the
        # separator has been created.
        self._active_format: str | None = None

    def _ensure_separator(self) -> None:
        """Lazily create the underlying Separator instance."""
        if self._separator is not None:
            return

        # Imported here so the dependency is only loaded once a separator is
        # actually needed.
        from audio_separator.separator import Separator

        self._separator = Separator(
            output_dir=self._settings.output_dir,
            model_file_dir=self._settings.model_file_dir,
            output_format=self._settings.output_format,
            normalization_threshold=self._settings.normalization,
            sample_rate=self._settings.sample_rate,
            log_level=logging.getLevelName(self._settings.log_level),
        )
        self._active_format = self._settings.output_format

    def _apply_output_format(self, fmt: OutputFormat) -> None:
        """Make the separator write ``fmt``, invalidating the model on change.

        Called on every separation so that a per-call override from an
        earlier call never leaks into a later call that relies on the
        configured default.
        """
        self._ensure_separator()
        if self._active_format == fmt.value:
            return

        self._separator.output_format = fmt.value
        self._active_format = fmt.value
        # NOTE(review): audio-separator appears to pass its output settings
        # to the model when the model is loaded, so an already-loaded model
        # would keep writing the previous format. Forget the cached model so
        # the next load picks up the new format -- confirm against the
        # installed audio-separator version.
        self._loaded_model = None

    def _load_model_for_mode(
        self, mode: StemMode, model_override: str | None = None
    ) -> str:
        """Load the appropriate model, returning the model filename used.

        Args:
            mode: Separation mode; selects the configured default model when
                no override is given.
            model_override: Model filename to use instead of the default.

        Returns:
            The filename of the model that is now loaded.
        """
        self._ensure_separator()

        if model_override:
            model_filename = model_override
        elif mode == StemMode.TWO_STEM:
            model_filename = self._settings.default_2stem_model
        else:
            model_filename = self._settings.default_4stem_model

        if self._loaded_model != model_filename:
            logger.info("Loading model: %s", model_filename)
            # Clear the cache marker first so a failed load is retried on the
            # next call instead of leaving a stale "already loaded" record.
            self._loaded_model = None
            self._separator.load_model(model_filename=model_filename)
            self._loaded_model = model_filename

        return model_filename

    def separate(
        self,
        input_path: str | Path,
        mode: StemMode = StemMode.TWO_STEM,
        output_format: OutputFormat | None = None,
        model_override: str | None = None,
    ) -> SeparationResult:
        """Separate an audio file into stems.

        Args:
            input_path: Path to the input audio file.
            mode: TWO_STEM or FOUR_STEM separation.
            output_format: Override the configured output format for this
                call only. A plain format name string is accepted as well.
            model_override: Use a specific model filename instead of the
                           default for the chosen mode.

        Returns:
            SeparationResult with the output stem files reported by the
            separator.

        Raises:
            FileNotFoundError: If input_path does not exist.
            ValueError: If the requested or configured output format is not
                a supported OutputFormat.
            RuntimeError: If separation fails.
        """
        input_path = Path(input_path)
        if not input_path.is_file():
            raise FileNotFoundError(f"Input file not found: {input_path}")

        # Coerce through the enum so both members and plain strings work, and
        # so the reported format is always the one applied to the separator.
        fmt = OutputFormat(output_format or self._settings.output_format)
        self._apply_output_format(fmt)

        model_used = self._load_model_for_mode(mode, model_override)

        logger.info(
            "Separating '%s' (mode=%s, format=%s, model=%s)",
            input_path.name,
            mode.value,
            fmt.value,
            model_used,
        )

        try:
            output_files = self._separator.separate(str(input_path))
        except Exception as exc:
            raise RuntimeError(
                f"Separation failed for '{input_path}': {exc}"
            ) from exc

        return SeparationResult(
            input_file=str(input_path),
            output_files=list(output_files),
            mode=mode,
            output_format=fmt,
            model_used=model_used,
        )