"""Demucs Audio Stem Separator - Separate audio into vocals, drums, bass, other.""" import os import tempfile import gradio as gr import torch import soundfile as sf import numpy as np import librosa from demucs.pretrained import get_model from demucs.apply import apply_model device = torch.device("cpu") _model_cache = {} def load_model(model_name: str): """Load and cache Demucs model.""" if model_name not in _model_cache: model = get_model(model_name) model.to(device).eval() _model_cache[model_name] = model return _model_cache[model_name] def separate_stems(audio_file: str, model_name: str = "htdemucs", progress=gr.Progress()) -> tuple[str, str, str, str]: """ Separate audio into stems using Demucs. Args: audio_file: Path to the input audio file (MP3, WAV, FLAC, OGG, M4A) model_name: Demucs model - htdemucs (fast) or htdemucs_ft (better quality) Returns: Tuple of paths to separated stems: (vocals, drums, bass, other) """ if audio_file is None: raise gr.Error("Please upload an audio file") progress(0.1, desc="Loading model...") model = load_model(model_name) progress(0.2, desc="Loading audio...") audio_data, sr = sf.read(audio_file) if audio_data.ndim == 1: waveform = torch.from_numpy(audio_data.astype(np.float32)).unsqueeze(0) else: waveform = torch.from_numpy(audio_data.T.astype(np.float32)) target_sr = model.samplerate if sr != target_sr: progress(0.25, desc="Resampling...") waveform_np = waveform.numpy() resampled = [librosa.resample(waveform_np[ch], orig_sr=sr, target_sr=target_sr) for ch in range(waveform_np.shape[0])] waveform = torch.from_numpy(np.stack(resampled)) if waveform.shape[0] == 1: waveform = waveform.repeat(2, 1) elif waveform.shape[0] > 2: waveform = waveform[:2, :] waveform = waveform.unsqueeze(0).to(device) progress(0.3, desc="Separating stems...") with torch.no_grad(): sources = apply_model(model, waveform, device=device, progress=True) progress(0.9, desc="Saving...") output_dir = tempfile.mkdtemp() output_files = {} for i, name in enumerate(model.sources): stem = sources[0, i].cpu().numpy() path = os.path.join(output_dir, f"{name}.wav") sf.write(path, stem.T, target_sr) output_files[name] = path return output_files.get("vocals"), output_files.get("drums"), output_files.get("bass"), output_files.get("other") with gr.Blocks(title="Demucs Stem Separator") as demo: gr.Markdown("# 🎵 Demucs Stem Separator\nSeparate audio → **Vocals** | **Drums** | **Bass** | **Other**") with gr.Row(): with gr.Column(scale=1): audio_input = gr.Audio(label="Audio File", type="filepath", sources=["upload"]) model_choice = gr.Dropdown(["htdemucs", "htdemucs_ft"], value="htdemucs", label="Model", info="htdemucs: Fast | htdemucs_ft: Better quality") separate_btn = gr.Button("Separate Stems", variant="primary", size="lg") with gr.Column(scale=2): with gr.Row(): vocals = gr.Audio(label="🎤 Vocals", type="filepath") drums = gr.Audio(label="🥁 Drums", type="filepath") with gr.Row(): bass = gr.Audio(label="🎸 Bass", type="filepath") other = gr.Audio(label="🎹 Other", type="filepath") separate_btn.click(separate_stems, [audio_input, model_choice], [vocals, drums, bass, other], api_name="separate") gr.Examples( examples=[["example.wav", "htdemucs"]], inputs=[audio_input, model_choice], outputs=[vocals, drums, bass, other], fn=separate_stems, cache_examples=True, cache_mode="lazy", label="Example" ) if __name__ == "__main__": demo.launch(mcp_server=True, show_error=True)