# demucs-cpu / app.py
import torch
import torchaudio
import gradio as gr
from demucs import pretrained
from demucs.apply import apply_model
from pyharp import ModelCard, LabelList, build_endpoint, save_audio
from audiotools import AudioSignal

# Available Demucs models
DEMUCS_MODELS = ["mdx_extra_q", "mdx_extra", "htdemucs", "mdx_q"]
STEM_CHOICES = {
    "Vocals": 3,
    "Drums": 0,
    "Bass": 1,
    "Other": 2,
    "Instrumental (No Vocals)": "instrumental",
}
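
# The indices above follow the (drums, bass, other, vocals) source order
# shared by the four Demucs models in DEMUCS_MODELS. A minimal sketch of
# looking an index up at runtime instead of hardcoding it, using the
# `sources` attribute that demucs models expose:
#
#     model = pretrained.get_model("htdemucs")
#     vocals_index = model.sources.index("vocals")  # 3 for htdemucs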

def separate_stem(audio_file_path: str, model_name: str, stem_choice: str):
    """
    Separates an audio file into the chosen stem using a Demucs model.
    Handles mono input and resamples to the model's expected rate.
    """
    # Load the Demucs model and move it to the GPU if one is available
    model = pretrained.get_model(model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    # Load the audio file
    waveform, sr = torchaudio.load(audio_file_path)

    # Demucs expects audio at the model's sample rate (44.1 kHz for the
    # models listed above); resample if the input differs
    if sr != model.samplerate:
        waveform = torchaudio.functional.resample(waveform, sr, model.samplerate)
        sr = model.samplerate

    # Demucs expects stereo input, so duplicate the channel of mono files
    is_mono = waveform.shape[0] == 1
    if is_mono:
        waveform = waveform.repeat(2, 1)

    # Apply the Demucs model; move the mix to the model's device so the
    # GPU is actually used when available
    with torch.no_grad():
        stems_batch = apply_model(
            model,
            waveform.unsqueeze(0).to(device),
            overlap=0.2,
            shifts=1,
            split=True,
        )

    # apply_model returns (batch, stems, channels, samples); take the only batch item
    stems = stems_batch[0]
    print(f"Model '{model_name}' extracted stems shape: {stems.shape}")

    if stem_choice == "Instrumental (No Vocals)":
        stem = stems[0] + stems[1] + stems[2]  # Drums + Bass + Other
    else:
        stem_index = STEM_CHOICES[stem_choice]
        stem = stems[stem_index]

    # Convert back to mono if the input was originally mono
    if is_mono:
        stem = stem.mean(dim=0, keepdim=True)  # Stereo -> Mono

    # Convert to an AudioSignal with float32 dtype
    stem_signal = AudioSignal(stem.cpu().numpy().astype("float32"), sample_rate=sr)
    return stem_signal
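
# A minimal standalone usage sketch of separate_stem ("song.wav" is an
# illustrative file name, not part of this app):
#
#     signal = separate_stem("song.wav", "htdemucs", "Vocals")
#     signal.write("vocals.wav")  # audiotools' AudioSignal.write saves to disk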

def process_fn_stem(audio_file_path: str, demucs_model: str, stem_choice: str):
    """
    PyHARP process function:
    - Separates the chosen stem using Demucs.
    - Saves the stem as a .wav file.
    """
    stem_signal = separate_stem(audio_file_path, model_name=demucs_model, stem_choice=stem_choice)
    stem_path = save_audio(stem_signal, f"{stem_choice.lower().replace(' ', '_')}.wav")
    return stem_path, LabelList(labels=[])
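
# PyHARP calls process_fn with the input audio path followed by the values
# of the Gradio components defined below, and expects the processed audio
# path plus a LabelList of output annotations (empty here) in return.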

# Define the model card
model_card = ModelCard(
    name="Demucs Stem Separator",
    description="Uses Demucs to separate a music track into a selected stem.",
    author="Alexandre Défossez, Nicolas Usunier, Léon Bottou, Francis Bach",
    tags=["demucs", "source-separation", "pyharp", "stems"],
)

# Build the Gradio interface with dropdowns for model and stem selection
with gr.Blocks() as demo:
    components = [
        gr.Dropdown(
            label="Select Demucs Model",
            choices=DEMUCS_MODELS,
            value="mdx_extra_q",
        ),
        gr.Dropdown(
            label="Select Stem to Separate",
            choices=list(STEM_CHOICES.keys()),
            value="Vocals",
        ),
    ]
    app = build_endpoint(
        model_card=model_card,
        components=components,
        process_fn=process_fn_stem,
    )

demo.queue()
demo.launch(share=True, show_error=True)
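
# share=True makes Gradio create a temporary public URL, which lets a HARP
# client running inside a DAW reach this endpoint remotely.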