# StemSplitter / tests/conftest.py
# Uploaded by ymcnabb using huggingface_hub (commit 1824ea0, verified).
"""Shared test fixtures."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import MagicMock
import numpy as np
import pytest
import soundfile as sf
@pytest.fixture
def tmp_output_dir(tmp_path: Path) -> Path:
    """Create and return an empty ``output`` directory inside pytest's ``tmp_path``."""
    output_dir = tmp_path.joinpath("output")
    output_dir.mkdir()
    return output_dir
@pytest.fixture
def test_audio_path(tmp_path: Path) -> Path:
    """Write a short synthetic sine-tone WAV file and return its path.

    The file holds 1.0 second of a 440 Hz sine wave at amplitude 0.5,
    sampled at 44100 Hz, stored as a single float32 channel.
    """
    wav_path = tmp_path / "test_tone.wav"
    sample_rate = 44100
    seconds = 1.0
    n_samples = int(sample_rate * seconds)

    # Sample times in [0, seconds); endpoint=False keeps exactly n_samples
    # evenly spaced points without repeating t == seconds.
    times = np.linspace(0, seconds, n_samples, endpoint=False)
    angular_freq = 2 * np.pi * 440
    tone = (0.5 * np.sin(angular_freq * times)).astype(np.float32)

    sf.write(str(wav_path), tone, sample_rate)
    return wav_path
# Stem labels emitted by the fake separator, in output order.
_TWO_STEM_LABELS = ("Vocals", "Instrumental")
_FOUR_STEM_LABELS = ("Vocals", "Drums", "Bass", "Other")


def _install_fake_separator(mocker, output_dir: Path, labels: tuple[str, ...]) -> MagicMock:
    """Patch ``audio_separator.separator.Separator`` with a fake producing empty stems.

    Args:
        mocker: The pytest-mock ``mocker`` fixture used to apply the patch.
        output_dir: Directory in which the fake stem files are created.
        labels: Stem labels. For each label, in order, the fake ``separate``
            creates an empty ``{input_stem}_{label}.wav`` file in *output_dir*.

    Returns:
        The ``MagicMock`` that the patched ``Separator(...)`` constructor returns.
        Its ``separate`` returns the created file paths as strings, and its
        ``load_model`` returns ``None``.
    """
    separator_cls = mocker.patch("audio_separator.separator.Separator")
    instance = MagicMock()
    separator_cls.return_value = instance

    def fake_separate(input_path):
        # Name outputs "<input file stem>_<Label>.wav" and create empty
        # placeholder files on disk, returning their paths as strings.
        stem = Path(input_path).stem
        files = []
        for label in labels:
            out = output_dir / f"{stem}_{label}.wav"
            out.touch()
            files.append(str(out))
        return files

    instance.separate.side_effect = fake_separate
    # A MagicMock method would otherwise return another MagicMock; pin it to None.
    instance.load_model.return_value = None
    return instance


@pytest.fixture
def mock_separator(mocker, tmp_output_dir: Path):
    """Mock audio_separator.separator.Separator for 2-stem output.

    The fake ``separate`` creates ``*_Vocals.wav`` and ``*_Instrumental.wav``
    in ``tmp_output_dir``.
    """
    return _install_fake_separator(mocker, tmp_output_dir, _TWO_STEM_LABELS)


@pytest.fixture
def mock_separator_4stem(mocker, tmp_output_dir: Path):
    """Mock separator producing 4-stem outputs.

    The fake ``separate`` creates ``*_Vocals.wav``, ``*_Drums.wav``,
    ``*_Bass.wav`` and ``*_Other.wav`` in ``tmp_output_dir``.
    """
    return _install_fake_separator(mocker, tmp_output_dir, _FOUR_STEM_LABELS)
@pytest.fixture
def env_settings(monkeypatch, tmp_output_dir: Path):
    """Set STEMSPLITTER_* environment variables to test-friendly values.

    Uses ``monkeypatch.setenv`` to set:
      * ``STEMSPLITTER_OUTPUT_DIR`` to ``tmp_output_dir``,
      * ``STEMSPLITTER_LOG_LEVEL`` to ``"DEBUG"``,
      * ``STEMSPLITTER_OUTPUT_FORMAT`` to ``"WAV"``.
    """
    overrides = {
        "STEMSPLITTER_OUTPUT_DIR": str(tmp_output_dir),
        "STEMSPLITTER_LOG_LEVEL": "DEBUG",
        "STEMSPLITTER_OUTPUT_FORMAT": "WAV",
    }
    for var_name, var_value in overrides.items():
        monkeypatch.setenv(var_name, var_value)