File size: 4,936 Bytes
1824ea0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""Tests for the core StemSplitter class."""

import pytest

from stemsplitter.separator import (
    STEM_LABELS,
    OutputFormat,
    SeparationResult,
    StemMode,
    StemSplitter,
)


class TestStemMode:
    """StemMode enum values and construction from their string form."""

    def test_two_stem_value(self):
        expected = "2stem"
        assert StemMode.TWO_STEM.value == expected

    def test_four_stem_value(self):
        expected = "4stem"
        assert StemMode.FOUR_STEM.value == expected

    def test_from_string(self):
        # Each raw string must round-trip to the matching enum member.
        pairs = (
            ("2stem", StemMode.TWO_STEM),
            ("4stem", StemMode.FOUR_STEM),
        )
        for raw, member in pairs:
            assert StemMode(raw) == member


class TestOutputFormat:
    """OutputFormat members expose the expected format strings."""

    def test_format_values(self):
        # Checked in the same order as the member declarations under test.
        cases = (
            (OutputFormat.WAV, "WAV"),
            (OutputFormat.MP3, "MP3"),
            (OutputFormat.FLAC, "FLAC"),
        )
        for member, expected in cases:
            assert member.value == expected


class TestStemLabels:
    """STEM_LABELS maps each StemMode to its ordered list of stem names."""

    def test_two_stem_labels(self):
        expected = ["Vocals", "Instrumental"]
        assert STEM_LABELS[StemMode.TWO_STEM] == expected

    def test_four_stem_labels(self):
        expected = ["Vocals", "Drums", "Bass", "Other"]
        assert STEM_LABELS[StemMode.FOUR_STEM] == expected


class TestStemSplitter:
    """Behavioral tests for StemSplitter.separate.

    The ``mock_separator`` / ``mock_separator_4stem`` fixtures (defined
    elsewhere, presumably in conftest) stand in for the underlying audio
    separator. ``env_settings`` presumably configures the environment the
    splitter reads -- confirm against conftest.
    """

    def test_separate_2stem(self, mock_separator, test_audio_path, env_settings):
        """2-stem separation should return 2 output files."""
        splitter = StemSplitter()
        result = splitter.separate(
            input_path=test_audio_path,
            mode=StemMode.TWO_STEM,
        )
        assert isinstance(result, SeparationResult)
        assert len(result.output_files) == 2
        assert result.mode == StemMode.TWO_STEM
        mock_separator.load_model.assert_called_once()

    def test_separate_4stem(
        self, mock_separator_4stem, test_audio_path, env_settings
    ):
        """4-stem separation should return 4 output files."""
        splitter = StemSplitter()
        result = splitter.separate(
            input_path=test_audio_path,
            mode=StemMode.FOUR_STEM,
        )
        # Mirror the 2-stem test: the return type is part of the contract.
        assert isinstance(result, SeparationResult)
        assert len(result.output_files) == 4
        assert result.mode == StemMode.FOUR_STEM

    def test_format_override(self, mock_separator, test_audio_path, env_settings):
        """Output format override should be reflected in result."""
        splitter = StemSplitter()
        result = splitter.separate(
            input_path=test_audio_path,
            mode=StemMode.TWO_STEM,
            output_format=OutputFormat.FLAC,
        )
        assert result.output_format == OutputFormat.FLAC

    def test_model_caching(self, mock_separator, test_audio_path, env_settings):
        """Same mode twice should NOT reload the model."""
        splitter = StemSplitter()
        splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
        splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
        assert mock_separator.load_model.call_count == 1

    def test_model_switch(self, mock_separator, test_audio_path, env_settings):
        """Switching modes should reload the model."""
        splitter = StemSplitter()
        splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
        splitter.separate(test_audio_path, mode=StemMode.FOUR_STEM)
        assert mock_separator.load_model.call_count == 2

    def test_file_not_found(self, env_settings):
        """Should raise FileNotFoundError for missing input."""
        splitter = StemSplitter()
        with pytest.raises(FileNotFoundError):
            splitter.separate("/nonexistent/file.wav")

    def test_model_override(self, mock_separator, test_audio_path, env_settings):
        """Custom model_override should be passed through."""
        override = "UVR_MDXNET_KARA_2.onnx"
        splitter = StemSplitter()
        splitter.separate(
            test_audio_path,
            mode=StemMode.TWO_STEM,
            model_override=override,
        )
        # assert_called_with only inspects the *last* call; requiring exactly
        # one call also catches a redundant default-model load before the
        # override is applied.
        mock_separator.load_model.assert_called_once_with(
            model_filename=override
        )

    def test_result_contains_input_file(
        self, mock_separator, test_audio_path, env_settings
    ):
        """Result should reference the original input file."""
        splitter = StemSplitter()
        result = splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
        assert result.input_file == str(test_audio_path)

    def test_result_contains_model_used(
        self, mock_separator, test_audio_path, env_settings
    ):
        """Result should reference which model was used."""
        splitter = StemSplitter()
        result = splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
        assert "mel_band_roformer" in result.model_used

    def test_separation_runtime_error(
        self, mock_separator, test_audio_path, env_settings
    ):
        """RuntimeError should be raised if the underlying separator fails."""
        mock_separator.separate.side_effect = Exception("Model crashed")
        splitter = StemSplitter()
        with pytest.raises(RuntimeError, match="Separation failed"):
            splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)