Spaces:
No application file
No application file
File size: 4,936 Bytes
"""Tests for the core StemSplitter class."""
import pytest
from stemsplitter.separator import (
STEM_LABELS,
OutputFormat,
SeparationResult,
StemMode,
StemSplitter,
)
class TestStemMode:
    """Checks on the string values attached to ``StemMode`` members."""

    def test_two_stem_value(self):
        """``TWO_STEM`` carries the value ``"2stem"``."""
        two_stem = StemMode.TWO_STEM
        assert two_stem.value == "2stem"

    def test_four_stem_value(self):
        """``FOUR_STEM`` carries the value ``"4stem"``."""
        four_stem = StemMode.FOUR_STEM
        assert four_stem.value == "4stem"

    def test_from_string(self):
        """Calling ``StemMode`` with a raw value yields the matching member."""
        expected_members = {
            "2stem": StemMode.TWO_STEM,
            "4stem": StemMode.FOUR_STEM,
        }
        # Dict order is insertion order, so 2stem is still checked first.
        for raw_value, member in expected_members.items():
            assert StemMode(raw_value) == member
class TestOutputFormat:
    """Checks on the string values attached to ``OutputFormat`` members."""

    def test_format_values(self):
        """Each checked member's value equals its own upper-case name."""
        for member_name in ("WAV", "MP3", "FLAC"):
            member = getattr(OutputFormat, member_name)
            assert member.value == member_name
class TestStemLabels:
    """Checks on the label lists stored in ``STEM_LABELS`` per stem mode."""

    def test_two_stem_labels(self):
        """The two-stem entry lists vocals, then instrumental."""
        expected = ["Vocals", "Instrumental"]
        assert STEM_LABELS[StemMode.TWO_STEM] == expected

    def test_four_stem_labels(self):
        """The four-stem entry lists vocals, drums, bass, other in order."""
        expected = ["Vocals", "Drums", "Bass", "Other"]
        assert STEM_LABELS[StemMode.FOUR_STEM] == expected
class TestStemSplitter:
    """Behavioural tests for ``StemSplitter.separate``.

    The ``mock_separator``, ``mock_separator_4stem``, ``test_audio_path`` and
    ``env_settings`` arguments are pytest fixtures defined outside this file.
    """

    def test_separate_2stem(self, mock_separator, test_audio_path, env_settings):
        """A two-stem run yields a ``SeparationResult`` holding two files."""
        stem_splitter = StemSplitter()
        outcome = stem_splitter.separate(
            input_path=test_audio_path,
            mode=StemMode.TWO_STEM,
        )

        assert isinstance(outcome, SeparationResult)
        produced = outcome.output_files
        assert len(produced) == 2
        assert outcome.mode == StemMode.TWO_STEM
        mock_separator.load_model.assert_called_once()

    def test_separate_4stem(
        self, mock_separator_4stem, test_audio_path, env_settings
    ):
        """A four-stem run yields four output files."""
        stem_splitter = StemSplitter()
        outcome = stem_splitter.separate(
            input_path=test_audio_path,
            mode=StemMode.FOUR_STEM,
        )

        produced = outcome.output_files
        assert len(produced) == 4
        assert outcome.mode == StemMode.FOUR_STEM

    def test_format_override(self, mock_separator, test_audio_path, env_settings):
        """An explicit ``output_format`` is echoed back on the result."""
        requested_format = OutputFormat.FLAC
        stem_splitter = StemSplitter()
        outcome = stem_splitter.separate(
            input_path=test_audio_path,
            mode=StemMode.TWO_STEM,
            output_format=requested_format,
        )

        assert outcome.output_format == OutputFormat.FLAC

    def test_model_caching(self, mock_separator, test_audio_path, env_settings):
        """Two runs in the same mode trigger only one ``load_model`` call."""
        stem_splitter = StemSplitter()
        for _ in range(2):
            stem_splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)

        load_calls = mock_separator.load_model.call_count
        assert load_calls == 1

    def test_model_switch(self, mock_separator, test_audio_path, env_settings):
        """Running two different modes triggers two ``load_model`` calls."""
        stem_splitter = StemSplitter()
        for stem_mode in (StemMode.TWO_STEM, StemMode.FOUR_STEM):
            stem_splitter.separate(test_audio_path, mode=stem_mode)

        load_calls = mock_separator.load_model.call_count
        assert load_calls == 2

    def test_file_not_found(self, env_settings):
        """A path that does not exist raises ``FileNotFoundError``."""
        stem_splitter = StemSplitter()
        missing_path = "/nonexistent/file.wav"
        with pytest.raises(FileNotFoundError):
            stem_splitter.separate(missing_path)

    def test_model_override(self, mock_separator, test_audio_path, env_settings):
        """A ``model_override`` name is forwarded to ``load_model``."""
        custom_model = "UVR_MDXNET_KARA_2.onnx"
        stem_splitter = StemSplitter()
        stem_splitter.separate(
            test_audio_path,
            mode=StemMode.TWO_STEM,
            model_override=custom_model,
        )

        mock_separator.load_model.assert_called_with(model_filename=custom_model)

    def test_result_contains_input_file(
        self, mock_separator, test_audio_path, env_settings
    ):
        """The result's ``input_file`` is the input path as a string."""
        outcome = StemSplitter().separate(test_audio_path, mode=StemMode.TWO_STEM)
        expected_input = str(test_audio_path)
        assert outcome.input_file == expected_input

    def test_result_contains_model_used(
        self, mock_separator, test_audio_path, env_settings
    ):
        """The result's ``model_used`` contains ``mel_band_roformer``."""
        outcome = StemSplitter().separate(test_audio_path, mode=StemMode.TWO_STEM)
        model_name = outcome.model_used
        assert "mel_band_roformer" in model_name

    def test_separation_runtime_error(
        self, mock_separator, test_audio_path, env_settings
    ):
        """A failure in the mocked separator surfaces as ``RuntimeError``."""
        # Make the mocked backend raise a generic exception when asked to run.
        backend_failure = Exception("Model crashed")
        mock_separator.separate.side_effect = backend_failure

        stem_splitter = StemSplitter()
        with pytest.raises(RuntimeError, match="Separation failed"):
            stem_splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
|