ymcnabb commited on
Commit
1824ea0
·
verified ·
1 Parent(s): cad5a52

Upload folder using huggingface_hub

Browse files
.coverage ADDED
Binary file (53.2 kB). View file
 
.env.example ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # StemSplitter Configuration
2
+ STEMSPLITTER_OUTPUT_DIR=./output
3
+ STEMSPLITTER_MODEL_DIR=/tmp/audio-separator-models/
4
+ STEMSPLITTER_2STEM_MODEL=model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt
5
+ STEMSPLITTER_4STEM_MODEL=htdemucs_ft.yaml
6
+ STEMSPLITTER_OUTPUT_FORMAT=WAV
7
+ STEMSPLITTER_OUTPUT_BITRATE=320k
8
+ STEMSPLITTER_SAMPLE_RATE=44100
9
+ STEMSPLITTER_NORMALIZATION=0.9
10
+ STEMSPLITTER_LOG_LEVEL=WARNING
11
+ STEMSPLITTER_WEB_HOST=127.0.0.1
12
+ STEMSPLITTER_WEB_PORT=7860
.github/workflows/update_space.yml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Run Python script
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+
8
+ jobs:
9
+ build:
10
+ runs-on: ubuntu-latest
11
+
12
+ steps:
13
+ - name: Checkout
14
+ uses: actions/checkout@v2
15
+
16
+ - name: Set up Python
17
+ uses: actions/setup-python@v2
18
+ with:
19
+ python-version: '3.9'
20
+
21
+ - name: Install Gradio
22
+ run: python -m pip install gradio
23
+
24
+ - name: Log in to Hugging Face
25
+ run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
26
+
27
+ - name: Deploy to Spaces
28
+ run: gradio deploy
.gitignore ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ *.egg-info/
6
+ dist/
7
+ build/
8
+
9
+ # Environment
10
+ .env
11
+ .venv/
12
+
13
+ # Output
14
+ output/
15
+
16
+ # Models (large, auto-downloaded)
17
+ /tmp/audio-separator-models/
18
+
19
+ # IDE
20
+ .vscode/
21
+ .idea/
22
+
23
+ # OS
24
+ .DS_Store
25
+
26
+ # Claude
27
+ .claude/
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
.pytest_cache/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Created by pytest automatically.
2
+ *
.pytest_cache/CACHEDIR.TAG ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ Signature: 8a477f597d28d172789f06886806bc55
2
+ # This file is a cache directory tag created by pytest.
3
+ # For information about cache directory tags, see:
4
+ # https://bford.info/cachedir/spec.html
.pytest_cache/README.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # pytest cache directory #
2
+
3
+ This directory contains data from the pytest's cache plugin,
4
+ which provides the `--lf` and `--ff` options, as well as the `cache` fixture.
5
+
6
+ **Do not** commit this to version control.
7
+
8
+ See [the docs](https://docs.pytest.org/en/stable/how-to/cache.html) for more information.
.pytest_cache/v/cache/nodeids ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ "tests/test_cli.py::TestCLI::test_2stem_default",
3
+ "tests/test_cli.py::TestCLI::test_4stem_flac",
4
+ "tests/test_cli.py::TestCLI::test_custom_output_dir",
5
+ "tests/test_cli.py::TestCLI::test_help",
6
+ "tests/test_cli.py::TestCLI::test_missing_file",
7
+ "tests/test_cli.py::TestCLI::test_mode_shown_in_output",
8
+ "tests/test_cli.py::TestCLI::test_mp3_format",
9
+ "tests/test_cli.py::TestCLI::test_output_lists_files",
10
+ "tests/test_config.py::TestSettings::test_defaults",
11
+ "tests/test_config.py::TestSettings::test_env_override",
12
+ "tests/test_config.py::TestSettings::test_get_settings_returns_fresh_instance",
13
+ "tests/test_config.py::TestSettings::test_immutability",
14
+ "tests/test_config.py::TestSettings::test_model_defaults",
15
+ "tests/test_config.py::TestSettings::test_output_dir_default",
16
+ "tests/test_separator.py::TestOutputFormat::test_format_values",
17
+ "tests/test_separator.py::TestStemLabels::test_four_stem_labels",
18
+ "tests/test_separator.py::TestStemLabels::test_two_stem_labels",
19
+ "tests/test_separator.py::TestStemMode::test_four_stem_value",
20
+ "tests/test_separator.py::TestStemMode::test_from_string",
21
+ "tests/test_separator.py::TestStemMode::test_two_stem_value",
22
+ "tests/test_separator.py::TestStemSplitter::test_file_not_found",
23
+ "tests/test_separator.py::TestStemSplitter::test_format_override",
24
+ "tests/test_separator.py::TestStemSplitter::test_model_caching",
25
+ "tests/test_separator.py::TestStemSplitter::test_model_override",
26
+ "tests/test_separator.py::TestStemSplitter::test_model_switch",
27
+ "tests/test_separator.py::TestStemSplitter::test_result_contains_input_file",
28
+ "tests/test_separator.py::TestStemSplitter::test_result_contains_model_used",
29
+ "tests/test_separator.py::TestStemSplitter::test_separate_2stem",
30
+ "tests/test_separator.py::TestStemSplitter::test_separate_4stem",
31
+ "tests/test_separator.py::TestStemSplitter::test_separation_runtime_error",
32
+ "tests/test_web.py::TestWebApp::test_app_creation",
33
+ "tests/test_web.py::TestWebApp::test_separate_audio_2stem",
34
+ "tests/test_web.py::TestWebApp::test_separate_audio_4stem",
35
+ "tests/test_web.py::TestWebApp::test_separate_audio_format_passed",
36
+ "tests/test_web.py::TestWebApp::test_separate_audio_no_file"
37
+ ]
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12
README.md CHANGED
@@ -1,12 +1,161 @@
1
  ---
2
  title: StemSplitter
3
- emoji: 🚀
4
- colorFrom: blue
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 6.6.0
8
- app_file: app.py
9
- pinned: false
10
  ---
 
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: StemSplitter
3
+ app_file: /Users/YaronMcNabb_1/Documents/StemSplitter/src/stemsplitter/web.py
 
 
4
  sdk: gradio
5
  sdk_version: 6.6.0
 
 
6
  ---
7
+ # StemSplitter
8
 
9
+ Audio stem separation tool that splits songs into individual components (vocals, drums, bass, instruments). Provides both a command-line interface and a Gradio web UI.
10
+
11
+ Powered by open-source models via [audio-separator](https://github.com/nomadkaraoke/python-audio-separator):
12
+
13
+ | Mode | Stems | Default Model |
14
+ |------|-------|---------------|
15
+ | 2-stem | Vocals, Instrumental | MelBand-RoFormer |
16
+ | 4-stem | Vocals, Drums, Bass, Other | Demucs htdemucs_ft |
17
+
18
+ ## Prerequisites
19
+
20
+ - Python 3.10+
21
+ - [uv](https://docs.astral.sh/uv/getting-started/installation/) for dependency management
22
+ - FFmpeg (required by audio-separator for reading various audio formats)
23
+
24
+ Install FFmpeg if you don't have it:
25
+
26
+ ```bash
27
+ # macOS
28
+ brew install ffmpeg
29
+
30
+ # Ubuntu/Debian
31
+ sudo apt install ffmpeg
32
+
33
+ # Windows (via chocolatey)
34
+ choco install ffmpeg
35
+ ```
36
+
37
+ ## Installation
38
+
39
+ ```bash
40
+ git clone <repo-url>
41
+ cd StemSplitter
42
+
43
+ # Copy the example env file and adjust as needed
44
+ cp .env.example .env
45
+
46
+ # Install dependencies (CPU inference)
47
+ uv sync --extra dev
48
+
49
+ # Or, for GPU-accelerated inference (NVIDIA CUDA)
50
+ uv sync --extra dev --extra gpu
51
+ ```
52
+
53
+ Models are downloaded automatically on first use (~200 MB for 2-stem, ~800 MB for 4-stem).
54
+
55
+ ## Usage
56
+
57
+ ### CLI
58
+
59
+ ```bash
60
+ # Basic 2-stem separation (vocals + instrumental), outputs WAV
61
+ uv run stemsplitter song.mp3
62
+
63
+ # 4-stem separation with FLAC output
64
+ uv run stemsplitter song.wav -m 4stem -f FLAC
65
+
66
+ # MP3 output to a custom directory
67
+ uv run stemsplitter song.flac -m 2stem -f MP3 -o ./my_stems/
68
+
69
+ # Override the model
70
+ uv run stemsplitter song.mp3 --model htdemucs.yaml
71
+
72
+ # Show all options
73
+ uv run stemsplitter --help
74
+ ```
75
+
76
+ **Supported input formats:** MP3, WAV, FLAC, OGG, M4A, and anything FFmpeg can decode.
77
+
78
+ **Supported output formats:** WAV, MP3, FLAC (set via `-f` flag or `STEMSPLITTER_OUTPUT_FORMAT` in `.env`).
79
+
80
+ ### Web UI
81
+
82
+ ```bash
83
+ uv run stemsplitter-web
84
+ ```
85
+
86
+ Opens a Gradio interface (default: `http://127.0.0.1:7860`) where you can:
87
+
88
+ 1. Upload an audio file
89
+ 2. Choose separation mode (2-stem or 4-stem)
90
+ 3. Choose output format (WAV, MP3, FLAC)
91
+ 4. Click **Separate** and download individual stems
92
+
93
+ A public share link is also generated automatically.
94
+
95
+ ## Project Structure
96
+
97
+ ```
98
+ src/stemsplitter/
99
+ __init__.py # Package version
100
+ config.py # Settings loaded from .env with sensible defaults
101
+ separator.py # Core StemSplitter class wrapping audio-separator
102
+ cli.py # Click-based CLI entry point
103
+ web.py # Gradio web UI
104
+
105
+ tests/
106
+ conftest.py # Shared fixtures (mock separator, synthetic audio)
107
+ test_config.py # Configuration loading tests
108
+ test_separator.py# Core separation logic tests
109
+ test_cli.py # CLI invocation tests
110
+ test_web.py # Web UI handler tests
111
+ ```
112
+
113
+ ### Components
114
+
115
+ - **config.py** -- Loads settings from a `.env` file using `python-dotenv`. All values are exposed as a frozen `Settings` dataclass. See `.env.example` for the full list of options.
116
+
117
+ - **separator.py** -- Wraps `audio-separator` with a `StemSplitter` class that handles model selection per mode, lazy initialization (so imports are fast), and model caching (the model stays loaded between calls).
118
+
119
+ - **cli.py** -- A Click command that accepts an input file and flags for mode, format, output directory, and model override.
120
+
121
+ - **web.py** -- A Gradio Blocks app with audio upload, mode/format radio buttons, and per-stem audio outputs. The 4-stem outputs (drums, bass) are hidden in 2-stem mode.
122
+
123
+ ## Configuration
124
+
125
+ All settings are configurable via environment variables in `.env`:
126
+
127
+ | Variable | Default | Description |
128
+ |----------|---------|-------------|
129
+ | `STEMSPLITTER_OUTPUT_DIR` | `./output` | Directory for separated stems |
130
+ | `STEMSPLITTER_MODEL_DIR` | `/tmp/audio-separator-models/` | Where downloaded models are cached |
131
+ | `STEMSPLITTER_2STEM_MODEL` | `model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt` | Model for 2-stem separation |
132
+ | `STEMSPLITTER_4STEM_MODEL` | `htdemucs_ft.yaml` | Model for 4-stem separation |
133
+ | `STEMSPLITTER_OUTPUT_FORMAT` | `WAV` | Default output format (WAV, MP3, FLAC) |
134
+ | `STEMSPLITTER_OUTPUT_BITRATE` | `320k` | Bitrate for MP3 output |
135
+ | `STEMSPLITTER_SAMPLE_RATE` | `44100` | Output sample rate |
136
+ | `STEMSPLITTER_NORMALIZATION` | `0.9` | Peak normalization threshold |
137
+ | `STEMSPLITTER_LOG_LEVEL` | `WARNING` | Logging verbosity (DEBUG, INFO, WARNING, ERROR) |
138
+ | `STEMSPLITTER_WEB_HOST` | `127.0.0.1` | Web UI bind address |
139
+ | `STEMSPLITTER_WEB_PORT` | `7860` | Web UI port |
140
+
141
+ ## Running Tests
142
+
143
+ ```bash
144
+ # Run all tests
145
+ uv run pytest
146
+
147
+ # Verbose output
148
+ uv run pytest -v
149
+
150
+ # With coverage report
151
+ uv run pytest -v --cov=stemsplitter --cov-report=term-missing
152
+
153
+ # Run a specific test file
154
+ uv run pytest tests/test_separator.py
155
+ ```
156
+
157
+ Tests use mocked models so no GPU or model downloads are required.
158
+
159
+ ## License
160
+
161
+ MIT
pyproject.toml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "stemsplitter"
3
+ version = "0.1.0"
4
+ description = "Audio stem splitter with CLI and web UI"
5
+ requires-python = ">=3.10"
6
+ dependencies = [
7
+ "audio-separator[cpu]>=0.41.0",
8
+ "click>=8.1",
9
+ "gradio>=5.0",
10
+ "python-dotenv>=1.0",
11
+ "soundfile>=0.12",
12
+ ]
13
+
14
+ [project.optional-dependencies]
15
+ gpu = ["audio-separator[gpu]>=0.41.0"]
16
+ dev = [
17
+ "pytest>=8.0",
18
+ "pytest-cov>=5.0",
19
+ "pytest-mock>=3.14",
20
+ "numpy>=1.26",
21
+ ]
22
+
23
+ [project.scripts]
24
+ stemsplitter = "stemsplitter.cli:main"
25
+ stemsplitter-web = "stemsplitter.web:launch"
26
+
27
+ [build-system]
28
+ requires = ["hatchling"]
29
+ build-backend = "hatchling.build"
30
+
31
+ [tool.hatch.build.targets.wheel]
32
+ packages = ["src/stemsplitter"]
33
+
34
+ [tool.pytest.ini_options]
35
+ testpaths = ["tests"]
36
+ pythonpath = ["src"]
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ audio-seperator click gradio python-dotenv soundfile
2
+ audio-separator
3
+ click
4
+ gradio
5
+ python-dotenv
6
+ soundfile
src/stemsplitter/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """StemSplitter - Audio stem separation with CLI and web UI."""
2
+
3
+ __version__ = "0.1.0"
src/stemsplitter/cli.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Command-line interface for StemSplitter."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import sys
6
+ from dataclasses import replace
7
+ from pathlib import Path
8
+
9
+ import click
10
+
11
+ from stemsplitter.config import get_settings
12
+ from stemsplitter.separator import OutputFormat, StemMode, StemSplitter
13
+
14
+
15
+ @click.command()
16
+ @click.argument("input_file", type=click.Path(exists=True, dir_okay=False))
17
+ @click.option(
18
+ "-m",
19
+ "--mode",
20
+ type=click.Choice(["2stem", "4stem"], case_sensitive=False),
21
+ default="2stem",
22
+ show_default=True,
23
+ help="Separation mode: 2-stem (vocals/instrumental) or 4-stem.",
24
+ )
25
+ @click.option(
26
+ "-f",
27
+ "--format",
28
+ "output_format",
29
+ type=click.Choice(["WAV", "MP3", "FLAC"], case_sensitive=False),
30
+ default=None,
31
+ help="Output audio format. Defaults to value in .env or WAV.",
32
+ )
33
+ @click.option(
34
+ "-o",
35
+ "--output-dir",
36
+ type=click.Path(file_okay=False),
37
+ default=None,
38
+ help="Output directory. Defaults to value in .env or ./output.",
39
+ )
40
+ @click.option(
41
+ "--model",
42
+ default=None,
43
+ help="Override the model filename (e.g., htdemucs.yaml).",
44
+ )
45
+ def main(
46
+ input_file: str,
47
+ mode: str,
48
+ output_format: str | None,
49
+ output_dir: str | None,
50
+ model: str | None,
51
+ ) -> None:
52
+ """Separate audio stems from INPUT_FILE.
53
+
54
+ Splits an audio file into vocal and instrumental stems (2-stem mode)
55
+ or into vocals, drums, bass, and other (4-stem mode).
56
+
57
+ \b
58
+ Examples:
59
+ stemsplitter song.mp3
60
+ stemsplitter song.wav -m 4stem -f FLAC
61
+ stemsplitter song.flac -m 2stem -f MP3 -o ./stems/
62
+ """
63
+ settings = get_settings()
64
+
65
+ if output_dir:
66
+ settings = replace(settings, output_dir=output_dir)
67
+
68
+ Path(settings.output_dir).mkdir(parents=True, exist_ok=True)
69
+
70
+ splitter = StemSplitter(settings=settings)
71
+
72
+ stem_mode = StemMode(mode)
73
+ fmt = OutputFormat(output_format.upper()) if output_format else None
74
+
75
+ click.echo(
76
+ f"Processing: {input_file}\n"
77
+ f"Mode: {stem_mode.value} | "
78
+ f"Format: {(fmt or OutputFormat(settings.output_format)).value}"
79
+ )
80
+
81
+ try:
82
+ result = splitter.separate(
83
+ input_path=input_file,
84
+ mode=stem_mode,
85
+ output_format=fmt,
86
+ model_override=model,
87
+ )
88
+ except FileNotFoundError as exc:
89
+ click.secho(str(exc), fg="red", err=True)
90
+ sys.exit(1)
91
+ except RuntimeError as exc:
92
+ click.secho(f"Separation failed: {exc}", fg="red", err=True)
93
+ sys.exit(1)
94
+
95
+ click.secho("Separation complete!", fg="green")
96
+ for f in result.output_files:
97
+ click.echo(f" -> {f}")
src/stemsplitter/config.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Application configuration loaded from .env with defaults."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from dataclasses import dataclass, field
7
+ from pathlib import Path
8
+
9
+ from dotenv import load_dotenv
10
+
11
+
12
+ def _load_env() -> None:
13
+ """Load .env from project root (or CWD) if it exists."""
14
+ for candidate in (Path(__file__).resolve().parents[2], Path.cwd()):
15
+ env_path = candidate / ".env"
16
+ if env_path.is_file():
17
+ load_dotenv(env_path)
18
+ return
19
+ load_dotenv()
20
+
21
+
22
+ _load_env()
23
+
24
+
25
+ @dataclass(frozen=True)
26
+ class Settings:
27
+ """Immutable application settings."""
28
+
29
+ output_dir: str = field(
30
+ default_factory=lambda: os.getenv("STEMSPLITTER_OUTPUT_DIR", "./output")
31
+ )
32
+ model_file_dir: str = field(
33
+ default_factory=lambda: os.getenv(
34
+ "STEMSPLITTER_MODEL_DIR", "/tmp/audio-separator-models/"
35
+ )
36
+ )
37
+ default_2stem_model: str = field(
38
+ default_factory=lambda: os.getenv(
39
+ "STEMSPLITTER_2STEM_MODEL",
40
+ "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt",
41
+ )
42
+ )
43
+ default_4stem_model: str = field(
44
+ default_factory=lambda: os.getenv(
45
+ "STEMSPLITTER_4STEM_MODEL", "htdemucs_ft.yaml"
46
+ )
47
+ )
48
+ output_format: str = field(
49
+ default_factory=lambda: os.getenv("STEMSPLITTER_OUTPUT_FORMAT", "WAV")
50
+ )
51
+ output_bitrate: str = field(
52
+ default_factory=lambda: os.getenv("STEMSPLITTER_OUTPUT_BITRATE", "320k")
53
+ )
54
+ sample_rate: int = field(
55
+ default_factory=lambda: int(os.getenv("STEMSPLITTER_SAMPLE_RATE", "44100"))
56
+ )
57
+ normalization: float = field(
58
+ default_factory=lambda: float(
59
+ os.getenv("STEMSPLITTER_NORMALIZATION", "0.9")
60
+ )
61
+ )
62
+ log_level: str = field(
63
+ default_factory=lambda: os.getenv("STEMSPLITTER_LOG_LEVEL", "WARNING")
64
+ )
65
+ web_host: str = field(
66
+ default_factory=lambda: os.getenv("STEMSPLITTER_WEB_HOST", "127.0.0.1")
67
+ )
68
+ web_port: int = field(
69
+ default_factory=lambda: int(os.getenv("STEMSPLITTER_WEB_PORT", "7860"))
70
+ )
71
+
72
+
73
+ def get_settings() -> Settings:
74
+ """Return a fresh Settings instance (re-reads env vars)."""
75
+ return Settings()
src/stemsplitter/separator.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Core audio stem separation logic."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from dataclasses import dataclass
7
+ from enum import Enum
8
+ from pathlib import Path
9
+ from typing import Optional
10
+
11
+ from stemsplitter.config import Settings, get_settings
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class StemMode(str, Enum):
17
+ """Separation mode."""
18
+
19
+ TWO_STEM = "2stem"
20
+ FOUR_STEM = "4stem"
21
+
22
+
23
+ class OutputFormat(str, Enum):
24
+ """Supported output audio formats."""
25
+
26
+ WAV = "WAV"
27
+ MP3 = "MP3"
28
+ FLAC = "FLAC"
29
+
30
+
31
+ STEM_LABELS: dict[StemMode, list[str]] = {
32
+ StemMode.TWO_STEM: ["Vocals", "Instrumental"],
33
+ StemMode.FOUR_STEM: ["Vocals", "Drums", "Bass", "Other"],
34
+ }
35
+
36
+
37
+ @dataclass
38
+ class SeparationResult:
39
+ """Result of a stem separation operation."""
40
+
41
+ input_file: str
42
+ output_files: list[str]
43
+ mode: StemMode
44
+ output_format: OutputFormat
45
+ model_used: str
46
+
47
+
48
+ class StemSplitter:
49
+ """High-level wrapper around audio-separator's Separator."""
50
+
51
+ def __init__(self, settings: Optional[Settings] = None) -> None:
52
+ self._settings = settings or get_settings()
53
+ self._separator = None
54
+ self._loaded_model: str | None = None
55
+
56
+ def _ensure_separator(self) -> None:
57
+ """Lazily create the underlying Separator instance."""
58
+ if self._separator is not None:
59
+ return
60
+
61
+ from audio_separator.separator import Separator
62
+
63
+ self._separator = Separator(
64
+ output_dir=self._settings.output_dir,
65
+ model_file_dir=self._settings.model_file_dir,
66
+ output_format=self._settings.output_format,
67
+ normalization_threshold=self._settings.normalization,
68
+ sample_rate=self._settings.sample_rate,
69
+ log_level=logging.getLevelName(self._settings.log_level),
70
+ )
71
+
72
+ def _load_model_for_mode(
73
+ self, mode: StemMode, model_override: str | None = None
74
+ ) -> str:
75
+ """Load the appropriate model, returning the model filename used."""
76
+ self._ensure_separator()
77
+
78
+ if model_override:
79
+ model_filename = model_override
80
+ elif mode == StemMode.TWO_STEM:
81
+ model_filename = self._settings.default_2stem_model
82
+ else:
83
+ model_filename = self._settings.default_4stem_model
84
+
85
+ if self._loaded_model != model_filename:
86
+ logger.info("Loading model: %s", model_filename)
87
+ self._separator.load_model(model_filename=model_filename)
88
+ self._loaded_model = model_filename
89
+
90
+ return model_filename
91
+
92
+ def separate(
93
+ self,
94
+ input_path: str | Path,
95
+ mode: StemMode = StemMode.TWO_STEM,
96
+ output_format: OutputFormat | None = None,
97
+ model_override: str | None = None,
98
+ ) -> SeparationResult:
99
+ """Separate an audio file into stems.
100
+
101
+ Args:
102
+ input_path: Path to the input audio file.
103
+ mode: TWO_STEM or FOUR_STEM separation.
104
+ output_format: Override the configured output format.
105
+ model_override: Use a specific model filename instead of the
106
+ default for the chosen mode.
107
+
108
+ Returns:
109
+ SeparationResult with paths to all output stem files.
110
+
111
+ Raises:
112
+ FileNotFoundError: If input_path does not exist.
113
+ RuntimeError: If separation fails.
114
+ """
115
+ input_path = Path(input_path)
116
+ if not input_path.is_file():
117
+ raise FileNotFoundError(f"Input file not found: {input_path}")
118
+
119
+ fmt = output_format or OutputFormat(self._settings.output_format)
120
+ if output_format:
121
+ self._ensure_separator()
122
+ self._separator.output_format = fmt.value
123
+
124
+ model_used = self._load_model_for_mode(mode, model_override)
125
+
126
+ logger.info(
127
+ "Separating '%s' (mode=%s, format=%s, model=%s)",
128
+ input_path.name,
129
+ mode.value,
130
+ fmt.value,
131
+ model_used,
132
+ )
133
+
134
+ try:
135
+ output_files = self._separator.separate(str(input_path))
136
+ except Exception as exc:
137
+ raise RuntimeError(
138
+ f"Separation failed for '{input_path}': {exc}"
139
+ ) from exc
140
+
141
+ return SeparationResult(
142
+ input_file=str(input_path),
143
+ output_files=list(output_files),
144
+ mode=mode,
145
+ output_format=fmt,
146
+ model_used=model_used,
147
+ )
src/stemsplitter/web.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gradio web interface for StemSplitter."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from pathlib import Path
7
+
8
+ import gradio as gr
9
+
10
+ from stemsplitter.config import get_settings
11
+ from stemsplitter.separator import OutputFormat, StemMode, StemSplitter
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ _splitter: StemSplitter | None = None
16
+
17
+
18
+ def _get_splitter() -> StemSplitter:
19
+ """Get or create the module-level StemSplitter singleton."""
20
+ global _splitter
21
+ if _splitter is None:
22
+ settings = get_settings()
23
+ Path(settings.output_dir).mkdir(parents=True, exist_ok=True)
24
+ _splitter = StemSplitter(settings=settings)
25
+ return _splitter
26
+
27
+
28
+ def separate_audio(
29
+ audio_path: str,
30
+ mode: str,
31
+ output_format: str,
32
+ progress: gr.Progress = gr.Progress(),
33
+ ) -> list[str | None]:
34
+ """Gradio handler: separate audio and return stem file paths.
35
+
36
+ Returns a list of 4 file paths (padding with None for 2-stem mode).
37
+ """
38
+ if not audio_path:
39
+ raise gr.Error("Please upload an audio file.")
40
+
41
+ progress(0.1, desc="Initializing model...")
42
+ splitter = _get_splitter()
43
+
44
+ stem_mode = StemMode(mode)
45
+ fmt = OutputFormat(output_format)
46
+
47
+ progress(0.3, desc=f"Separating stems ({stem_mode.value})...")
48
+ result = splitter.separate(
49
+ input_path=audio_path,
50
+ mode=stem_mode,
51
+ output_format=fmt,
52
+ )
53
+
54
+ progress(1.0, desc="Done!")
55
+
56
+ outputs = list(result.output_files)
57
+ while len(outputs) < 4:
58
+ outputs.append(None)
59
+
60
+ return outputs[:4]
61
+
62
+
63
+ def create_app() -> gr.Blocks:
64
+ """Build and return the Gradio Blocks application."""
65
+ with gr.Blocks(title="StemSplitter") as app:
66
+ gr.Markdown("# StemSplitter\nSeparate audio into individual stems.")
67
+
68
+ with gr.Row():
69
+ with gr.Column(scale=1):
70
+ audio_input = gr.Audio(
71
+ label="Upload Audio",
72
+ type="filepath",
73
+ sources=["upload"],
74
+ )
75
+ mode_radio = gr.Radio(
76
+ choices=["2stem", "4stem"],
77
+ value="2stem",
78
+ label="Separation Mode",
79
+ info="2-stem: Vocals + Instrumental | 4-stem: Vocals + Drums + Bass + Other",
80
+ )
81
+ format_radio = gr.Radio(
82
+ choices=["WAV", "MP3", "FLAC"],
83
+ value="WAV",
84
+ label="Output Format",
85
+ )
86
+ separate_btn = gr.Button("Separate", variant="primary")
87
+
88
+ with gr.Column(scale=2):
89
+ vocals_output = gr.Audio(label="Vocals", type="filepath")
90
+ instrumental_output = gr.Audio(
91
+ label="Instrumental", type="filepath"
92
+ )
93
+ drums_output = gr.Audio(
94
+ label="Drums",
95
+ type="filepath",
96
+ visible=False,
97
+ )
98
+ bass_output = gr.Audio(
99
+ label="Bass",
100
+ type="filepath",
101
+ visible=False,
102
+ )
103
+
104
+ def update_outputs_visibility(mode: str):
105
+ is_4stem = mode == "4stem"
106
+ return (
107
+ gr.update(label="Instrumental" if not is_4stem else "Other"),
108
+ gr.update(visible=is_4stem),
109
+ gr.update(visible=is_4stem),
110
+ )
111
+
112
+ mode_radio.change(
113
+ fn=update_outputs_visibility,
114
+ inputs=[mode_radio],
115
+ outputs=[instrumental_output, drums_output, bass_output],
116
+ )
117
+
118
+ separate_btn.click(
119
+ fn=separate_audio,
120
+ inputs=[audio_input, mode_radio, format_radio],
121
+ outputs=[
122
+ vocals_output,
123
+ instrumental_output,
124
+ drums_output,
125
+ bass_output,
126
+ ],
127
+ )
128
+
129
+ return app
130
+
131
+
132
+ def launch() -> None:
133
+ """Entry point for `stemsplitter-web` console script."""
134
+ settings = get_settings()
135
+ app = create_app()
136
+ app.launch(
137
+ server_name=settings.web_host,
138
+ server_port=settings.web_port,
139
+ theme=gr.themes.Soft(),
140
+ share=True,
141
+ )
tests/__init__.py ADDED
File without changes
tests/conftest.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared test fixtures."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from unittest.mock import MagicMock
7
+
8
+ import numpy as np
9
+ import pytest
10
+ import soundfile as sf
11
+
12
+
13
+ @pytest.fixture
14
+ def tmp_output_dir(tmp_path: Path) -> Path:
15
+ """Provide a temporary output directory."""
16
+ d = tmp_path / "output"
17
+ d.mkdir()
18
+ return d
19
+
20
+
21
+ @pytest.fixture
22
+ def test_audio_path(tmp_path: Path) -> Path:
23
+ """Generate a small synthetic WAV file (~1 second, 44100 Hz, mono)."""
24
+ path = tmp_path / "test_tone.wav"
25
+ sr = 44100
26
+ duration = 1.0
27
+ t = np.linspace(0, duration, int(sr * duration), endpoint=False)
28
+ audio = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)
29
+ sf.write(str(path), audio, sr)
30
+ return path
31
+
32
+
33
+ @pytest.fixture
34
+ def mock_separator(mocker, tmp_output_dir: Path):
35
+ """Mock audio_separator.separator.Separator for 2-stem output."""
36
+ mock_cls = mocker.patch("audio_separator.separator.Separator")
37
+ instance = MagicMock()
38
+ mock_cls.return_value = instance
39
+
40
+ def fake_separate(input_path):
41
+ stem = Path(input_path).stem
42
+ files = []
43
+ for label in ["Vocals", "Instrumental"]:
44
+ out = tmp_output_dir / f"{stem}_{label}.wav"
45
+ out.touch()
46
+ files.append(str(out))
47
+ return files
48
+
49
+ instance.separate.side_effect = fake_separate
50
+ instance.load_model.return_value = None
51
+
52
+ return instance
53
+
54
+
55
+ @pytest.fixture
56
+ def mock_separator_4stem(mocker, tmp_output_dir: Path):
57
+ """Mock separator producing 4-stem outputs."""
58
+ mock_cls = mocker.patch("audio_separator.separator.Separator")
59
+ instance = MagicMock()
60
+ mock_cls.return_value = instance
61
+
62
+ def fake_separate(input_path):
63
+ stem = Path(input_path).stem
64
+ files = []
65
+ for label in ["Vocals", "Drums", "Bass", "Other"]:
66
+ out = tmp_output_dir / f"{stem}_{label}.wav"
67
+ out.touch()
68
+ files.append(str(out))
69
+ return files
70
+
71
+ instance.separate.side_effect = fake_separate
72
+ instance.load_model.return_value = None
73
+
74
+ return instance
75
+
76
+
77
+ @pytest.fixture
78
+ def env_settings(monkeypatch, tmp_output_dir: Path):
79
+ """Set environment variables for testing config."""
80
+ monkeypatch.setenv("STEMSPLITTER_OUTPUT_DIR", str(tmp_output_dir))
81
+ monkeypatch.setenv("STEMSPLITTER_LOG_LEVEL", "DEBUG")
82
+ monkeypatch.setenv("STEMSPLITTER_OUTPUT_FORMAT", "WAV")
tests/test_cli.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for the CLI interface."""
2
+
3
+ from click.testing import CliRunner
4
+
5
+ from stemsplitter.cli import main
6
+
7
+
8
+ class TestCLI:
9
+ def test_help(self):
10
+ """--help should succeed and show usage."""
11
+ runner = CliRunner()
12
+ result = runner.invoke(main, ["--help"])
13
+ assert result.exit_code == 0
14
+ assert "Separate audio stems" in result.output
15
+
16
+ def test_missing_file(self):
17
+ """Non-existent file should produce error."""
18
+ runner = CliRunner()
19
+ result = runner.invoke(main, ["/nonexistent.wav"])
20
+ assert result.exit_code != 0
21
+
22
+ def test_2stem_default(self, mock_separator, test_audio_path, env_settings):
23
+ """Default invocation should use 2-stem mode."""
24
+ runner = CliRunner()
25
+ result = runner.invoke(main, [str(test_audio_path)])
26
+ assert result.exit_code == 0
27
+ assert "Separation complete" in result.output
28
+
29
+ def test_4stem_flac(
30
+ self, mock_separator_4stem, test_audio_path, env_settings
31
+ ):
32
+ """4-stem + FLAC flags should be accepted."""
33
+ runner = CliRunner()
34
+ result = runner.invoke(
35
+ main,
36
+ [str(test_audio_path), "-m", "4stem", "-f", "FLAC"],
37
+ )
38
+ assert result.exit_code == 0
39
+ assert "Separation complete" in result.output
40
+
41
+ def test_custom_output_dir(
42
+ self, mock_separator, test_audio_path, tmp_path, env_settings
43
+ ):
44
+ """--output-dir flag should override the default."""
45
+ custom_dir = tmp_path / "custom_out"
46
+ runner = CliRunner()
47
+ result = runner.invoke(
48
+ main,
49
+ [str(test_audio_path), "-o", str(custom_dir)],
50
+ )
51
+ assert result.exit_code == 0
52
+ assert custom_dir.is_dir()
53
+
54
+ def test_output_lists_files(
55
+ self, mock_separator, test_audio_path, env_settings
56
+ ):
57
+ """Output should list the generated stem files."""
58
+ runner = CliRunner()
59
+ result = runner.invoke(main, [str(test_audio_path)])
60
+ assert "->" in result.output
61
+
62
+ def test_mode_shown_in_output(
63
+ self, mock_separator, test_audio_path, env_settings
64
+ ):
65
+ """The mode should be shown in processing output."""
66
+ runner = CliRunner()
67
+ result = runner.invoke(main, [str(test_audio_path), "-m", "2stem"])
68
+ assert "2stem" in result.output
69
+
70
+ def test_mp3_format(self, mock_separator, test_audio_path, env_settings):
71
+ """MP3 format flag should be accepted."""
72
+ runner = CliRunner()
73
+ result = runner.invoke(
74
+ main,
75
+ [str(test_audio_path), "-f", "MP3"],
76
+ )
77
+ assert result.exit_code == 0
78
+ assert "MP3" in result.output
tests/test_config.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for configuration loading."""
2
+
3
+ import pytest
4
+
5
+ from stemsplitter.config import Settings, get_settings
6
+
7
+
8
+ class TestSettings:
9
+ def test_defaults(self):
10
+ """Settings should provide sensible defaults."""
11
+ s = Settings()
12
+ assert s.output_format in ("WAV", "MP3", "FLAC")
13
+ assert s.sample_rate == 44100
14
+ assert s.normalization == 0.9
15
+ assert s.web_port == 7860
16
+ assert s.log_level == "WARNING"
17
+
18
+ def test_env_override(self, monkeypatch):
19
+ """Settings should pick up environment variable overrides."""
20
+ monkeypatch.setenv("STEMSPLITTER_OUTPUT_FORMAT", "FLAC")
21
+ monkeypatch.setenv("STEMSPLITTER_SAMPLE_RATE", "48000")
22
+ monkeypatch.setenv("STEMSPLITTER_WEB_PORT", "9090")
23
+ s = get_settings()
24
+ assert s.output_format == "FLAC"
25
+ assert s.sample_rate == 48000
26
+ assert s.web_port == 9090
27
+
28
+ def test_immutability(self):
29
+ """Settings instances should be frozen."""
30
+ s = Settings()
31
+ with pytest.raises(AttributeError):
32
+ s.output_dir = "/some/other/path" # type: ignore
33
+
34
+ def test_output_dir_default(self):
35
+ """Default output_dir should be ./output."""
36
+ s = Settings()
37
+ assert s.output_dir == "./output"
38
+
39
+ def test_model_defaults(self):
40
+ """Default models should be set for both modes."""
41
+ s = Settings()
42
+ assert "mel_band_roformer" in s.default_2stem_model
43
+ assert "htdemucs_ft" in s.default_4stem_model
44
+
45
+ def test_get_settings_returns_fresh_instance(self, monkeypatch):
46
+ """get_settings should create a new instance each time."""
47
+ s1 = get_settings()
48
+ monkeypatch.setenv("STEMSPLITTER_OUTPUT_FORMAT", "MP3")
49
+ s2 = get_settings()
50
+ assert s2.output_format == "MP3"
51
+ assert s1 is not s2
tests/test_separator.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for the core StemSplitter class."""
2
+
3
+ import pytest
4
+
5
+ from stemsplitter.separator import (
6
+ STEM_LABELS,
7
+ OutputFormat,
8
+ SeparationResult,
9
+ StemMode,
10
+ StemSplitter,
11
+ )
12
+
13
+
14
+ class TestStemMode:
15
+ def test_two_stem_value(self):
16
+ assert StemMode.TWO_STEM.value == "2stem"
17
+
18
+ def test_four_stem_value(self):
19
+ assert StemMode.FOUR_STEM.value == "4stem"
20
+
21
+ def test_from_string(self):
22
+ assert StemMode("2stem") == StemMode.TWO_STEM
23
+ assert StemMode("4stem") == StemMode.FOUR_STEM
24
+
25
+
26
+ class TestOutputFormat:
27
+ def test_format_values(self):
28
+ assert OutputFormat.WAV.value == "WAV"
29
+ assert OutputFormat.MP3.value == "MP3"
30
+ assert OutputFormat.FLAC.value == "FLAC"
31
+
32
+
33
+ class TestStemLabels:
34
+ def test_two_stem_labels(self):
35
+ assert STEM_LABELS[StemMode.TWO_STEM] == ["Vocals", "Instrumental"]
36
+
37
+ def test_four_stem_labels(self):
38
+ assert STEM_LABELS[StemMode.FOUR_STEM] == [
39
+ "Vocals",
40
+ "Drums",
41
+ "Bass",
42
+ "Other",
43
+ ]
44
+
45
+
46
+ class TestStemSplitter:
47
+ def test_separate_2stem(self, mock_separator, test_audio_path, env_settings):
48
+ """2-stem separation should return 2 output files."""
49
+ splitter = StemSplitter()
50
+ result = splitter.separate(
51
+ input_path=test_audio_path,
52
+ mode=StemMode.TWO_STEM,
53
+ )
54
+ assert isinstance(result, SeparationResult)
55
+ assert len(result.output_files) == 2
56
+ assert result.mode == StemMode.TWO_STEM
57
+ mock_separator.load_model.assert_called_once()
58
+
59
+ def test_separate_4stem(
60
+ self, mock_separator_4stem, test_audio_path, env_settings
61
+ ):
62
+ """4-stem separation should return 4 output files."""
63
+ splitter = StemSplitter()
64
+ result = splitter.separate(
65
+ input_path=test_audio_path,
66
+ mode=StemMode.FOUR_STEM,
67
+ )
68
+ assert len(result.output_files) == 4
69
+ assert result.mode == StemMode.FOUR_STEM
70
+
71
+ def test_format_override(self, mock_separator, test_audio_path, env_settings):
72
+ """Output format override should be reflected in result."""
73
+ splitter = StemSplitter()
74
+ result = splitter.separate(
75
+ input_path=test_audio_path,
76
+ mode=StemMode.TWO_STEM,
77
+ output_format=OutputFormat.FLAC,
78
+ )
79
+ assert result.output_format == OutputFormat.FLAC
80
+
81
+ def test_model_caching(self, mock_separator, test_audio_path, env_settings):
82
+ """Same mode twice should NOT reload the model."""
83
+ splitter = StemSplitter()
84
+ splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
85
+ splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
86
+ assert mock_separator.load_model.call_count == 1
87
+
88
+ def test_model_switch(self, mock_separator, test_audio_path, env_settings):
89
+ """Switching modes should reload the model."""
90
+ splitter = StemSplitter()
91
+ splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
92
+ splitter.separate(test_audio_path, mode=StemMode.FOUR_STEM)
93
+ assert mock_separator.load_model.call_count == 2
94
+
95
+ def test_file_not_found(self, env_settings):
96
+ """Should raise FileNotFoundError for missing input."""
97
+ splitter = StemSplitter()
98
+ with pytest.raises(FileNotFoundError):
99
+ splitter.separate("/nonexistent/file.wav")
100
+
101
+ def test_model_override(self, mock_separator, test_audio_path, env_settings):
102
+ """Custom model_override should be passed through."""
103
+ splitter = StemSplitter()
104
+ splitter.separate(
105
+ test_audio_path,
106
+ mode=StemMode.TWO_STEM,
107
+ model_override="UVR_MDXNET_KARA_2.onnx",
108
+ )
109
+ mock_separator.load_model.assert_called_with(
110
+ model_filename="UVR_MDXNET_KARA_2.onnx"
111
+ )
112
+
113
+ def test_result_contains_input_file(
114
+ self, mock_separator, test_audio_path, env_settings
115
+ ):
116
+ """Result should reference the original input file."""
117
+ splitter = StemSplitter()
118
+ result = splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
119
+ assert result.input_file == str(test_audio_path)
120
+
121
+ def test_result_contains_model_used(
122
+ self, mock_separator, test_audio_path, env_settings
123
+ ):
124
+ """Result should reference which model was used."""
125
+ splitter = StemSplitter()
126
+ result = splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
127
+ assert "mel_band_roformer" in result.model_used
128
+
129
+ def test_separation_runtime_error(
130
+ self, mock_separator, test_audio_path, env_settings
131
+ ):
132
+ """RuntimeError should be raised if the underlying separator fails."""
133
+ mock_separator.separate.side_effect = Exception("Model crashed")
134
+ splitter = StemSplitter()
135
+ with pytest.raises(RuntimeError, match="Separation failed"):
136
+ splitter.separate(test_audio_path, mode=StemMode.TWO_STEM)
tests/test_web.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for the Gradio web interface."""
2
+
3
+ import pytest
4
+
5
+ import stemsplitter.web as web_mod
6
+ from stemsplitter.web import create_app, separate_audio
7
+
8
+
9
+ @pytest.fixture(autouse=True)
10
+ def _reset_splitter_singleton():
11
+ """Reset the module-level splitter before each test."""
12
+ web_mod._splitter = None
13
+ yield
14
+ web_mod._splitter = None
15
+
16
+
17
+ class TestWebApp:
18
+ def test_app_creation(self):
19
+ """create_app() should return a Gradio Blocks instance."""
20
+ import gradio as gr
21
+
22
+ app = create_app()
23
+ assert isinstance(app, gr.Blocks)
24
+
25
+ def test_separate_audio_2stem(
26
+ self, mock_separator, test_audio_path, env_settings
27
+ ):
28
+ """Handler should return 4 values (2 real + 2 None) for 2-stem."""
29
+ outputs = separate_audio(
30
+ audio_path=str(test_audio_path),
31
+ mode="2stem",
32
+ output_format="WAV",
33
+ )
34
+ assert len(outputs) == 4
35
+ assert outputs[0] is not None
36
+ assert outputs[1] is not None
37
+ assert outputs[2] is None
38
+ assert outputs[3] is None
39
+
40
+ def test_separate_audio_4stem(
41
+ self, mock_separator_4stem, test_audio_path, env_settings
42
+ ):
43
+ """Handler should return 4 non-None values for 4-stem."""
44
+ outputs = separate_audio(
45
+ audio_path=str(test_audio_path),
46
+ mode="4stem",
47
+ output_format="WAV",
48
+ )
49
+ assert len(outputs) == 4
50
+ assert all(o is not None for o in outputs)
51
+
52
+ def test_separate_audio_no_file(self, env_settings):
53
+ """Handler should raise gr.Error when no file uploaded."""
54
+ import gradio as gr
55
+
56
+ with pytest.raises(gr.Error):
57
+ separate_audio(
58
+ audio_path="",
59
+ mode="2stem",
60
+ output_format="WAV",
61
+ )
62
+
63
+ def test_separate_audio_format_passed(
64
+ self, mock_separator, test_audio_path, env_settings
65
+ ):
66
+ """The output format choice should be forwarded to the splitter."""
67
+ outputs = separate_audio(
68
+ audio_path=str(test_audio_path),
69
+ mode="2stem",
70
+ output_format="FLAC",
71
+ )
72
+ assert len(outputs) == 4
uv.lock ADDED
The diff for this file is too large to render. See raw diff