# ymcnabb's picture
# Upload folder using huggingface_hub
# 1824ea0 verified
"""Gradio web interface for StemSplitter."""
from __future__ import annotations
import logging
from pathlib import Path
import gradio as gr
from stemsplitter.config import get_settings
from stemsplitter.separator import OutputFormat, StemMode, StemSplitter
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Cached StemSplitter instance; remains None until _get_splitter() builds it
# on first use, after which the same object is reused.
_splitter: StemSplitter | None = None
def _get_splitter() -> StemSplitter:
    """Return the shared StemSplitter, building it on the first call.

    On first use the settings are loaded via ``get_settings()``, the
    configured output directory is created if missing (including parents),
    and the new instance is stored in the module-level ``_splitter``.
    Later calls return that stored instance unchanged.
    """
    global _splitter

    # Fast path: an instance was already built by an earlier call.
    if _splitter is not None:
        return _splitter

    config = get_settings()
    # Ensure the output directory exists before the splitter is constructed.
    Path(config.output_dir).mkdir(parents=True, exist_ok=True)
    _splitter = StemSplitter(settings=config)
    return _splitter
def separate_audio(
    audio_path: str,
    mode: str,
    output_format: str,
    progress: gr.Progress = gr.Progress(),
) -> list[str | None]:
    """Gradio handler: split the uploaded audio and return the stem files.

    Args:
        audio_path: Path of the uploaded audio file; a falsy value is
            rejected with a user-facing error.
        mode: Separation mode, converted with ``StemMode(mode)``.
        output_format: Output format, converted with
            ``OutputFormat(output_format)``.
        progress: Gradio progress tracker; updated at 0.1, 0.3 and 1.0.

    Returns:
        A list of exactly four entries taken, in order, from the
        separation result's ``output_files``: padded with ``None`` when
        fewer than four files were produced (e.g. 2-stem mode) and
        truncated when more were produced.

    Raises:
        gr.Error: If no audio file was supplied.
    """
    slot_count = 4  # the UI exposes four audio output components

    if not audio_path:
        raise gr.Error("Please upload an audio file.")

    progress(0.1, desc="Initializing model...")
    engine = _get_splitter()
    selected_mode = StemMode(mode)
    selected_format = OutputFormat(output_format)

    progress(0.3, desc=f"Separating stems ({selected_mode.value})...")
    separation = engine.separate(
        input_path=audio_path,
        mode=selected_mode,
        output_format=selected_format,
    )
    progress(1.0, desc="Done!")

    # Keep at most four paths, then fill any remaining slots with None so
    # the result always lines up with the four output components.
    stem_paths = list(separation.output_files)[:slot_count]
    stem_paths.extend([None] * (slot_count - len(stem_paths)))
    return stem_paths
def create_app() -> gr.Blocks:
    """Assemble the StemSplitter Gradio UI and wire up its events.

    The layout is one row with two columns: a narrower column holding the
    upload widget, the mode and format selectors and the trigger button,
    and a wider column holding four audio players. The drums and bass
    players start hidden; changing the separation mode toggles them and
    relabels the second player.

    Returns:
        The constructed ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks(title="StemSplitter") as demo:
        gr.Markdown("# StemSplitter\nSeparate audio into individual stems.")

        with gr.Row():
            # Input column: source file, options, and the action button.
            with gr.Column(scale=1):
                source_audio = gr.Audio(
                    label="Upload Audio",
                    type="filepath",
                    sources=["upload"],
                )
                mode_choice = gr.Radio(
                    choices=["2stem", "4stem"],
                    value="2stem",
                    label="Separation Mode",
                    info="2-stem: Vocals + Instrumental | 4-stem: Vocals + Drums + Bass + Other",
                )
                format_choice = gr.Radio(
                    choices=["WAV", "MP3", "FLAC"],
                    value="WAV",
                    label="Output Format",
                )
                run_button = gr.Button("Separate", variant="primary")

            # Output column: one player per stem slot.
            with gr.Column(scale=2):
                vocals_player = gr.Audio(label="Vocals", type="filepath")
                second_player = gr.Audio(label="Instrumental", type="filepath")
                drums_player = gr.Audio(
                    label="Drums",
                    type="filepath",
                    visible=False,
                )
                bass_player = gr.Audio(
                    label="Bass",
                    type="filepath",
                    visible=False,
                )

        def update_outputs_visibility(mode: str):
            """Relabel the second player and show/hide drums and bass."""
            four_stem = mode == "4stem"
            second_label = "Other" if four_stem else "Instrumental"
            return (
                gr.update(label=second_label),
                gr.update(visible=four_stem),
                gr.update(visible=four_stem),
            )

        # Keep the visible players in sync with the selected mode.
        mode_choice.change(
            fn=update_outputs_visibility,
            inputs=[mode_choice],
            outputs=[second_player, drums_player, bass_player],
        )

        # Run the separation; the handler returns one value per player.
        stem_players = [
            vocals_player,
            second_player,
            drums_player,
            bass_player,
        ]
        run_button.click(
            fn=separate_audio,
            inputs=[source_audio, mode_choice, format_choice],
            outputs=stem_players,
        )

    return demo
def launch() -> None:
    """Entry point for `stemsplitter-web` console script.

    Loads the settings, builds the Blocks application and launches it on
    the configured host and port with the Soft theme and ``share=True``.
    """
    config = get_settings()
    demo = create_app()

    launch_options = {
        "server_name": config.web_host,
        "server_port": config.web_port,
        "theme": gr.themes.Soft(),
        # NOTE(review): share=True is hard-coded; confirm a public share
        # link is wanted in every deployment.
        "share": True,
    }
    demo.launch(**launch_options)