"""Tests for the core StemSplitter class."""

import pytest

from stemsplitter.separator import (
    STEM_LABELS,
    OutputFormat,
    SeparationResult,
    StemMode,
    StemSplitter,
)


class TestStemMode:
    """StemMode enum values and lookup from their string values."""

    def test_two_stem_value(self):
        assert StemMode.TWO_STEM.value == "2stem"

    def test_four_stem_value(self):
        assert StemMode.FOUR_STEM.value == "4stem"

    def test_from_string(self):
        """Constructing StemMode from its value string yields the member."""
        assert StemMode("2stem") == StemMode.TWO_STEM
        assert StemMode("4stem") == StemMode.FOUR_STEM


class TestOutputFormat:
    """OutputFormat enum values."""

    def test_format_values(self):
        assert OutputFormat.WAV.value == "WAV"
        assert OutputFormat.MP3.value == "MP3"
        assert OutputFormat.FLAC.value == "FLAC"


class TestStemLabels:
    """STEM_LABELS maps each StemMode to its ordered list of stem names."""

    def test_two_stem_labels(self):
        assert STEM_LABELS[StemMode.TWO_STEM] == ["Vocals", "Instrumental"]

    def test_four_stem_labels(self):
        assert STEM_LABELS[StemMode.FOUR_STEM] == [
            "Vocals",
            "Drums",
            "Bass",
            "Other",
        ]


class TestStemSplitter:
    """Behaviour of StemSplitter.separate against a mocked separator.

    The fixtures used here (mock_separator, mock_separator_4stem,
    test_audio_path, env_settings) are not defined in this module;
    presumably they come from conftest.py.
    """

    def test_separate_2stem(self, mock_separator, test_audio_path, env_settings):
        """2-stem separation should return 2 output files."""
        splitter = StemSplitter()
        result = splitter.separate(
            input_path=test_audio_path,
            mode=StemMode.TWO_STEM,
        )
        assert isinstance(result, SeparationResult)
        assert len(result.output_files) == 2
        assert result.mode == StemMode.TWO_STEM
        mock_separator.load_model.assert_called_once()

    def test_separate_4stem(
        self, mock_separator_4stem, test_audio_path, env_settings
    ):
        """4-stem separation should return 4 output files."""
        splitter = StemSplitter()
        result = splitter.separate(
            input_path=test_audio_path,
            mode=StemMode.FOUR_STEM,
        )
        assert len(result.output_files) == 4
        assert result.mode == StemMode.FOUR_STEM
        # Mirror the 2-stem test: a fresh splitter loads its model exactly once.
        mock_separator_4stem.load_model.assert_called_once()

    def test_format_override(self, mock_separator, test_audio_path, env_settings):
        """Output format override should be reflected in result."""
        splitter = StemSplitter()
        result = splitter.separate(
            input_path=test_audio_path,
            mode=StemMode.TWO_STEM,
            output_format=OutputFormat.FLAC,
        )
        assert result.output_format == OutputFormat.FLAC

    def test_model_caching(self, mock_separator, test_audio_path, env_settings):
        """Same mode twice should NOT reload the model."""
        splitter = StemSplitter()
        splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
        splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
        assert mock_separator.load_model.call_count == 1

    def test_model_switch(self, mock_separator, test_audio_path, env_settings):
        """Switching modes should reload the model."""
        splitter = StemSplitter()
        splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
        splitter.separate(test_audio_path, mode=StemMode.FOUR_STEM)
        assert mock_separator.load_model.call_count == 2

    def test_file_not_found(self, tmp_path, env_settings):
        """Should raise FileNotFoundError for missing input."""
        # tmp_path is a fresh, empty directory, so this path is guaranteed
        # not to exist on any host (unlike a hard-coded absolute path).
        missing = tmp_path / "file.wav"
        splitter = StemSplitter()
        with pytest.raises(FileNotFoundError):
            splitter.separate(str(missing))

    def test_model_override(self, mock_separator, test_audio_path, env_settings):
        """Custom model_override should be passed through."""
        splitter = StemSplitter()
        splitter.separate(
            test_audio_path,
            mode=StemMode.TWO_STEM,
            model_override="UVR_MDXNET_KARA_2.onnx",
        )
        # assert_called_once_with also pins the call count, whereas
        # assert_called_with would only inspect the most recent call.
        mock_separator.load_model.assert_called_once_with(
            model_filename="UVR_MDXNET_KARA_2.onnx"
        )

    def test_result_contains_input_file(
        self, mock_separator, test_audio_path, env_settings
    ):
        """Result should reference the original input file."""
        splitter = StemSplitter()
        result = splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
        assert result.input_file == str(test_audio_path)

    def test_result_contains_model_used(
        self, mock_separator, test_audio_path, env_settings
    ):
        """Result should reference which model was used."""
        splitter = StemSplitter()
        result = splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
        assert "mel_band_roformer" in result.model_used

    def test_separation_runtime_error(
        self, mock_separator, test_audio_path, env_settings
    ):
        """RuntimeError should be raised if the underlying separator fails."""
        mock_separator.separate.side_effect = Exception("Model crashed")
        splitter = StemSplitter()
        with pytest.raises(RuntimeError, match="Separation failed"):
            splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)