# demucs / app.py
# Nekochu's picture
# Init
# ddabbfb verified
"""Demucs Audio Stem Separator - Separate audio into vocals, drums, bass, other."""
import os
import tempfile
import gradio as gr
import torch
import soundfile as sf
import numpy as np
import librosa
from demucs.pretrained import get_model
from demucs.apply import apply_model
# All inference in this module runs on the CPU.
device = torch.device("cpu")
# Loaded models keyed by model name; filled lazily by load_model().
_model_cache = {}
def load_model(model_name: str):
    """Return the Demucs model called *model_name*, loading it on first use.

    The first request for a name fetches the model with ``get_model``, moves
    it to the module-level ``device``, switches it to eval mode and stores it
    in ``_model_cache``; later requests return the stored instance.
    """
    try:
        return _model_cache[model_name]
    except KeyError:
        pass  # not loaded yet: fall through and build it

    net = get_model(model_name)
    net.to(device).eval()
    _model_cache[model_name] = net
    return net
def _read_audio(audio_file: str) -> tuple[np.ndarray, int]:
    """Decode *audio_file* into float32 samples shaped ``(channels, frames)``.

    soundfile is tried first. If it cannot decode the file (for example a
    container libsndfile does not support), ``librosa.load`` is used as a
    fallback so that its alternative decoding backend gets a chance.

    Returns:
        ``(samples, sample_rate)``.

    Raises:
        gr.Error: If neither decoder can read the file.
    """
    try:
        frames, rate = sf.read(audio_file, dtype="float32", always_2d=True)
    except RuntimeError:
        # soundfile signals unreadable input with a RuntimeError subclass.
        try:
            samples, rate = librosa.load(audio_file, sr=None, mono=False)
        except Exception as err:  # decoder backends raise unrelated types
            raise gr.Error(f"Could not decode audio file: {err}") from err
        # librosa returns a 1-D array for mono input.
        return np.atleast_2d(samples), int(rate)
    # soundfile yields (frames, channels); the model wants channels first.
    return np.ascontiguousarray(frames.T), int(rate)


def separate_stems(audio_file: str, model_name: str = "htdemucs", progress=gr.Progress()) -> tuple[str, str, str, str]:
    """
    Separate audio into stems using Demucs.

    Args:
        audio_file: Path to the input audio file (MP3, WAV, FLAC, OGG, M4A)
        model_name: Demucs model - htdemucs (fast) or htdemucs_ft (better quality)

    Returns:
        Tuple of paths to separated stems: (vocals, drums, bass, other)

    Raises:
        gr.Error: If no file was provided, the file cannot be decoded, or it
            contains no audio samples.
    """
    if audio_file is None:
        raise gr.Error("Please upload an audio file")

    progress(0.1, desc="Loading model...")
    model = load_model(model_name)

    progress(0.2, desc="Loading audio...")
    samples, sr = _read_audio(audio_file)
    if samples.shape[-1] == 0:
        raise gr.Error("The audio file contains no samples")

    # Only the first two channels are used, so drop any others *before*
    # resampling instead of resampling data that is thrown away afterwards.
    samples = samples[:2]

    target_sr = model.samplerate
    if sr != target_sr:
        progress(0.25, desc="Resampling...")
        samples = np.stack([
            librosa.resample(channel, orig_sr=sr, target_sr=target_sr)
            for channel in samples
        ])

    waveform = torch.from_numpy(samples)
    if waveform.shape[0] == 1:
        # Mono input: duplicate the channel to obtain a stereo signal.
        waveform = waveform.repeat(2, 1)

    # Standardise the mix before inference and undo it on the outputs, as
    # Demucs' own inference code does; the epsilon guards against silence.
    reference = waveform.mean(0)
    offset = reference.mean()
    scale = reference.std() + 1e-8
    mix = ((waveform - offset) / scale).unsqueeze(0).to(device)

    progress(0.3, desc="Separating stems...")
    with torch.no_grad():
        sources = apply_model(model, mix, device=device, progress=True)
    sources = sources * scale + offset

    progress(0.9, desc="Saving...")
    # NOTE: the directory is intentionally not removed here because the
    # returned paths must still exist when the caller reads them.
    output_dir = tempfile.mkdtemp()
    output_files = {}
    for index, name in enumerate(model.sources):
        stem = sources[0, index].cpu().numpy()
        path = os.path.join(output_dir, f"{name}.wav")
        # soundfile expects (frames, channels).
        sf.write(path, stem.T, target_sr)
        output_files[name] = path

    return (
        output_files.get("vocals"),
        output_files.get("drums"),
        output_files.get("bass"),
        output_files.get("other"),
    )
# UI layout: input controls on the left, the four separated stems on the right.
with gr.Blocks(title="Demucs Stem Separator") as demo:
    gr.Markdown("# 🎵 Demucs Stem Separator\nSeparate audio → **Vocals** | **Drums** | **Bass** | **Other**")
    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(label="Audio File", type="filepath", sources=["upload"])
            model_choice = gr.Dropdown(["htdemucs", "htdemucs_ft"], value="htdemucs", label="Model",
                                       info="htdemucs: Fast | htdemucs_ft: Better quality")
            separate_btn = gr.Button("Separate Stems", variant="primary", size="lg")
        with gr.Column(scale=2):
            with gr.Row():
                vocals = gr.Audio(label="🎤 Vocals", type="filepath")
                drums = gr.Audio(label="🥁 Drums", type="filepath")
            with gr.Row():
                bass = gr.Audio(label="🎸 Bass", type="filepath")
                other = gr.Audio(label="🎹 Other", type="filepath")
    # Output order must match the tuple returned by separate_stems.
    separate_btn.click(separate_stems, [audio_input, model_choice], [vocals, drums, bass, other], api_name="separate")
    gr.Examples(
        examples=[["example.wav", "htdemucs"]],
        inputs=[audio_input, model_choice],
        outputs=[vocals, drums, bass, other],
        fn=separate_stems,
        cache_examples=True,
        cache_mode="lazy",
        label="Example"
    )

if __name__ == "__main__":
    demo.launch(mcp_server=True, show_error=True)