File size: 2,690 Bytes
6027dac
ee4b2a5
31fd6da
 
 
 
 
 
cb6aae7
 
 
f7ddc5b
cb6aae7
f7ddc5b
ee4b2a5
 
 
 
f7ddc5b
 
cb6aae7
8ea9073
ee4b2a5
bf186e6
f7ddc5b
bf186e6
f7ddc5b
 
 
 
 
ee4b2a5
f7ddc5b
0116557
 
f7ddc5b
 
 
 
 
 
 
 
 
ee4b2a5
f7ddc5b
 
ee4b2a5
f7ddc5b
 
 
cb6aae7
f7ddc5b
ee4b2a5
cb6aae7
ee4b2a5
 
 
cb6aae7
f7ddc5b
 
 
 
 
 
 
 
 
cb6aae7
 
ee4b2a5
f7ddc5b
 
 
 
 
 
cb6aae7
ee4b2a5
cb6aae7
ee4b2a5
b6d192d
 
 
 
 
f7ddc5b
b6d192d
 
 
 
 
cb6aae7
 
 
b6d192d
f7ddc5b
 
cb6aae7
f7ddc5b
b6d192d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import spaces 
import torch
import torchaudio
import gradio as gr
from demucs import pretrained
from demucs.apply import apply_model
from audiotools import AudioSignal
from typing import Dict
from pyharp import *


# Pretrained Demucs model variants offered in the UI dropdown.
DEMUX_MODELS = ["mdx_extra_q", "mdx_extra", "htdemucs", "mdx_q"]

# Maps UI labels to Demucs output source indices
# (Demucs source order: drums=0, bass=1, other=2, vocals=3).
# "Instrumental (No Vocals)" is a sentinel handled specially in
# separate_stem by summing the three non-vocal stems.
STEM_CHOICES = {
    "Vocals": 3,
    "Drums": 0,
    "Bass": 1,
    "Other": 2,
    "Instrumental (No Vocals)": "instrumental"
}

@spaces.GPU(duration = 180)
def separate_stem(audio_file_path: str, model_name: str, stem_choice: str) -> AudioSignal:
    """Separate a single stem from an audio file with a pretrained Demucs model.

    Args:
        audio_file_path: Path to the input audio file.
        model_name: Name of the pretrained Demucs model (one of DEMUX_MODELS).
        stem_choice: Key of STEM_CHOICES selecting which stem to return.

    Returns:
        AudioSignal containing the chosen stem (mono if the input was mono),
        at the model's sample rate.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = pretrained.get_model(model_name)
    model.to(device)
    model.eval()

    waveform, sr = torchaudio.load(audio_file_path)

    # Demucs models are trained at a fixed sample rate (44.1 kHz). Feeding
    # audio at any other rate silently degrades separation, so resample first.
    if sr != model.samplerate:
        waveform = torchaudio.functional.resample(waveform, sr, model.samplerate)
        sr = model.samplerate

    # Demucs expects stereo input; duplicate a mono channel and downmix the
    # result again at the end.
    is_mono = waveform.shape[0] == 1
    if is_mono:
        waveform = waveform.repeat(2, 1)

    waveform = waveform.to(device)

    with torch.no_grad():
        stems_batch = apply_model(
            model,
            waveform.unsqueeze(0),  # add batch dimension
            overlap=0.2,
            shifts=1,
            split=True  # chunked processing bounds GPU memory on long tracks
        )

    # Drop batch dimension -> (sources, channels, samples).
    stems = stems_batch[0]

    if stem_choice == "Instrumental (No Vocals)":
        # Everything except vocals: drums (0) + bass (1) + other (2).
        stem = stems[0] + stems[1] + stems[2]
    else:
        stem_index = STEM_CHOICES[stem_choice]
        stem = stems[stem_index]

    if is_mono:
        stem = stem.mean(dim=0, keepdim=True)

    return AudioSignal(stem.cpu().numpy().astype('float32'), sample_rate=sr)

# Gradio Callback Function

def process_fn_stem(audio_file_path: str, demucs_model: str, stem_choice: str):
    """
    PyHARP process function:
      - Runs Demucs separation for the selected stem.
      - Writes the stem to a .wav file and returns its path.
    """
    signal = separate_stem(audio_file_path, model_name=demucs_model, stem_choice=stem_choice)
    out_name = stem_choice.lower().replace(' ', '_') + ".wav"
    out_path = save_audio(signal, out_name)
    return out_path, LabelList(labels=[])


# Model card metadata displayed by the PyHARP client.
model_card = ModelCard(
    name="Demucs Stem Separator",
    description="Uses Demucs to separate a music track into a selected stem.",
    author="Alexandre Défossez, Nicolas Usunier, Léon Bottou, Francis Bach",
    tags=["demucs", "source-separation", "pyharp", "stems"]
)

# Gradio UI
with gr.Blocks() as demo:

    # Which pretrained Demucs variant to run.
    dropdown_model = gr.Dropdown(
        label="Select Demucs Model",
        choices=DEMUX_MODELS,
        value="mdx_extra_q"
    )

    # Which stem to extract from the input track.
    dropdown_stem = gr.Dropdown(
        label="Select Stem to Separate",
        choices=list(STEM_CHOICES.keys()),
        value="Vocals"
    )

    # Wire the controls and the process function into a PyHARP endpoint.
    app = build_endpoint(
        model_card=model_card,
        components=[dropdown_model, dropdown_stem],
        process_fn=process_fn_stem
    )

demo.queue()
demo.launch(show_error=True)