File size: 1,683 Bytes
44158a3
 
4df5c73
12090e7
44158a3
 
 
4df5c73
 
 
 
 
 
 
12090e7
44158a3
4df5c73
 
 
 
 
 
 
 
 
 
 
 
 
44158a3
4df5c73
 
 
44158a3
4df5c73
 
 
 
44158a3
3f04617
4df5c73
44158a3
 
 
 
 
12090e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44158a3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import tempfile

import gradio as gr
import spaces
import torch
import torchaudio


# --- 1. Load your pre-trained model ---
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pre-trained Hybrid Demucs bundle: it provides both the network and the
# sample rate that input audio must be converted to before separation.
bundle = torchaudio.pipelines.HDEMUCS_HIGH_MUSDB_PLUS
sample_rate = bundle.sample_rate
model = bundle.get_model().to(device)  # nn.Module.to() returns the module itself


@spaces.GPU
def separate_drums(audio_path):
    """Isolate the drum stem of an audio file with the Hybrid Demucs model.

    The audio is resampled to the model's sample rate and coerced to two
    channels, then normalized and separated. The drum stem is de-normalized
    and written to a freshly created temporary MP3 file. Using a unique file
    per call keeps concurrent requests from overwriting each other's results.

    Args:
        audio_path: Path of the uploaded file as delivered by
            ``gr.Audio(type="filepath")``; ``None`` when nothing was uploaded.

    Returns:
        Path of the MP3 file holding the isolated drum track.

    Raises:
        gr.Error: If no audio file was supplied.
    """
    if audio_path is None:
        # Surface a readable message in the UI instead of a load traceback.
        raise gr.Error("Please upload an audio file first.")

    # a. Load audio as (channels, frames)
    waveform, sr = torchaudio.load(audio_path)

    # b. Resample to the rate the model expects
    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)

    # c. The model takes exactly two channels: duplicate mono input and keep
    #    only the first two channels of multichannel input.
    if waveform.shape[0] == 1:
        waveform = waveform.repeat(2, 1)
    elif waveform.shape[0] > 2:
        waveform = waveform[:2]

    # d. Normalize with statistics of the mono mix (as in the torchaudio
    #    Hybrid Demucs tutorial); they are re-applied to the output below.
    ref = waveform.mean(0)
    ref_mean = ref.mean()
    ref_std = ref.std()
    if not torch.isfinite(ref_std) or ref_std <= 0:
        # Silent/constant or single-sample input: skip scaling, avoid NaNs.
        ref_std = torch.ones_like(ref_std)
    waveform = (waveform - ref_mean) / ref_std

    # Add a batch dimension -> (1, 2, frames).
    # NOTE(review): the whole track goes through the model in one pass; very
    # long uploads may exhaust GPU memory. Consider chunked overlap-add
    # processing as shown in the torchaudio tutorial.
    mix = waveform.unsqueeze(0).to(device)

    # e. Perform source separation; output is indexed [batch, source, ...]
    with torch.no_grad():
        separated_sources = model(mix)

    # f. Pick the drum stem by name from the model's own source list rather
    #    than a hard-coded ordering, then undo the normalization.
    drums_index = model.sources.index("drums")
    drums_waveform = separated_sources[0, drums_index].cpu()
    drums_waveform = drums_waveform * ref_std + ref_mean

    # g. Write to a unique temporary file so concurrent users never clobber
    #    each other's output; the ".mp3" suffix keeps the original format.
    fd, output_path = tempfile.mkstemp(prefix="drums_", suffix=".mp3")
    os.close(fd)
    torchaudio.save(output_path, drums_waveform, sample_rate)

    return output_path


# --- 2. Create the Gradio Interface ---
# Shown in the page header so users can tell whether a GPU is in use.
cuda_status = "YES" if torch.cuda.is_available() else "NO"

with gr.Blocks() as iface:
    gr.Markdown(
        f"""
        # 🥁 Drum Separator

        Upload an audio file to isolate the drum track.

        CUDA Enabled: **{cuda_status}**
        """
    )
    # Uploaded song arrives as a file path; the result is played back as audio.
    song_input = gr.Audio(type="filepath", label="Upload Your Song")
    drums_output = gr.Audio(label="Isolated Drum Track")
    gr.Interface(
        fn=separate_drums,
        inputs=song_input,
        outputs=drums_output,
    )

# --- 3. Launch the App ---
iface.launch()