File size: 4,004 Bytes
ddabbfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""Demucs Audio Stem Separator - Separate audio into vocals, drums, bass, other."""
import os
import tempfile
import gradio as gr
import torch
import soundfile as sf
import numpy as np
import librosa
from demucs.pretrained import get_model
from demucs.apply import apply_model

device = torch.device("cpu")
_model_cache = {}


def load_model(model_name: str):
    """Return the Demucs model for *model_name*, loading it on first use.

    Loaded models are cached in the module-level ``_model_cache`` so that
    repeated separations skip the slow checkpoint download/initialization.
    The model is moved to the configured device and put in eval mode.
    """
    cached = _model_cache.get(model_name)
    if cached is None:
        cached = get_model(model_name)
        cached.to(device).eval()
        _model_cache[model_name] = cached
    return cached


def separate_stems(audio_file: str, model_name: str = "htdemucs", progress=gr.Progress()) -> tuple[str, str, str, str]:
    """
    Separate audio into stems using Demucs.

    Args:
        audio_file: Path to the input audio file (MP3, WAV, FLAC, OGG, M4A)
        model_name: Demucs model - htdemucs (fast) or htdemucs_ft (better quality)
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        Tuple of paths to separated stems: (vocals, drums, bass, other)

    Raises:
        gr.Error: If no audio file was provided.
    """
    if audio_file is None:
        raise gr.Error("Please upload an audio file")

    progress(0.1, desc="Loading model...")
    model = load_model(model_name)

    progress(0.2, desc="Loading audio...")
    # soundfile (libsndfile) cannot decode every format the UI advertises —
    # notably M4A, and MP3 on older libsndfile builds. Fall back to
    # librosa.load, which tries additional decoding backends.
    try:
        audio_data, sr = sf.read(audio_file)
        if audio_data.ndim == 1:
            waveform = torch.from_numpy(audio_data.astype(np.float32)).unsqueeze(0)
        else:
            # soundfile returns (frames, channels); Demucs expects (channels, frames).
            waveform = torch.from_numpy(audio_data.T.astype(np.float32))
    except Exception:
        # librosa.load with mono=False returns (channels, frames) — or a 1-D
        # array for mono input — already as float32; sr=None keeps the native rate.
        audio_data, sr = librosa.load(audio_file, sr=None, mono=False)
        if audio_data.ndim == 1:
            audio_data = audio_data[np.newaxis, :]
        waveform = torch.from_numpy(np.ascontiguousarray(audio_data, dtype=np.float32))

    # Resample channel-by-channel to the model's native sample rate if needed.
    target_sr = model.samplerate
    if sr != target_sr:
        progress(0.25, desc="Resampling...")
        waveform_np = waveform.numpy()
        resampled = [librosa.resample(waveform_np[ch], orig_sr=sr, target_sr=target_sr)
                     for ch in range(waveform_np.shape[0])]
        waveform = torch.from_numpy(np.stack(resampled))

    # Demucs models are trained on stereo: duplicate mono, drop extra channels.
    if waveform.shape[0] == 1:
        waveform = waveform.repeat(2, 1)
    elif waveform.shape[0] > 2:
        waveform = waveform[:2, :]

    # Add the batch dimension: (batch=1, channels=2, frames).
    waveform = waveform.unsqueeze(0).to(device)

    progress(0.3, desc="Separating stems...")
    with torch.no_grad():
        sources = apply_model(model, waveform, device=device, progress=True)

    progress(0.9, desc="Saving...")
    # Write each stem to a fresh temp dir; Gradio serves the files from there.
    output_dir = tempfile.mkdtemp()
    output_files = {}
    for i, name in enumerate(model.sources):
        stem = sources[0, i].cpu().numpy()
        path = os.path.join(output_dir, f"{name}.wav")
        # Transpose back to (frames, channels) for soundfile.
        sf.write(path, stem.T, target_sr)
        output_files[name] = path

    # .get() so a model with a different stem set returns None rather than raising.
    return output_files.get("vocals"), output_files.get("drums"), output_files.get("bass"), output_files.get("other")


# UI layout: statement order inside the Blocks context defines the rendered
# layout, so the structure below is intentional — inputs on the left (scale=1),
# the four stem players in a 2x2 grid on the right (scale=2).
with gr.Blocks(title="Demucs Stem Separator") as demo:
    gr.Markdown("# 🎵 Demucs Stem Separator\nSeparate audio → **Vocals** | **Drums** | **Bass** | **Other**")

    with gr.Row():
        with gr.Column(scale=1):
            # type="filepath" so separate_stems receives a path string, not raw audio.
            audio_input = gr.Audio(label="Audio File", type="filepath", sources=["upload"])
            model_choice = gr.Dropdown(["htdemucs", "htdemucs_ft"], value="htdemucs", label="Model",
                                       info="htdemucs: Fast | htdemucs_ft: Better quality")
            separate_btn = gr.Button("Separate Stems", variant="primary", size="lg")

        with gr.Column(scale=2):
            with gr.Row():
                vocals = gr.Audio(label="🎤 Vocals", type="filepath")
                drums = gr.Audio(label="🥁 Drums", type="filepath")
            with gr.Row():
                bass = gr.Audio(label="🎸 Bass", type="filepath")
                other = gr.Audio(label="🎹 Other", type="filepath")

    # api_name exposes the endpoint as /separate for API/MCP clients.
    separate_btn.click(separate_stems, [audio_input, model_choice], [vocals, drums, bass, other], api_name="separate")

    # cache_mode="lazy": the example is separated (slow) on first click,
    # then served from cache. Assumes example.wav exists next to this script.
    gr.Examples(
        examples=[["example.wav", "htdemucs"]],
        inputs=[audio_input, model_choice],
        outputs=[vocals, drums, bass, other],
        fn=separate_stems,
        cache_examples=True,
        cache_mode="lazy",
        label="Example"
    )

# mcp_server=True additionally exposes the app as an MCP tool server.
if __name__ == "__main__":
    demo.launch(mcp_server=True, show_error=True)