Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torchaudio | |
| import spaces | |
# --- 1. Load your pre-trained model ---
# The torchaudio pipeline bundle supplies both the pre-trained separation
# model and the sample rate that the rest of this file resamples to.
bundle = torchaudio.pipelines.HDEMUCS_HIGH_MUSDB_PLUS
sample_rate = bundle.sample_rate

# Use the GPU when one is visible to torch; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# nn.Module.to() moves the parameters in place and returns the same module.
model = bundle.get_model().to(device)
def separate_drums(audio_path):
    """Isolate the drum stem of an audio file and return the path to the result.

    Args:
        audio_path: Filesystem path of the audio to process, as delivered by
            the ``gr.Audio(type="filepath")`` input. ``None`` or empty when
            the user submits without providing a file.

    Returns:
        Path to a newly created MP3 file that holds only the drum stem. Each
        call writes to its own temporary file, so concurrent requests cannot
        overwrite one another's output.

    Raises:
        gr.Error: If no audio file was provided.
    """
    # NOTE(review): `spaces` is imported at file level but `@spaces.GPU` is
    # never applied to this function. If this Space runs on ZeroGPU hardware
    # the decorator is presumably required for GPU access -- confirm.
    import tempfile  # local import keeps this block self-contained

    if not audio_path:
        # Surface a readable message in the UI instead of an opaque loader error.
        raise gr.Error("Please upload an audio file first.")

    # a. Load audio
    waveform, sr = torchaudio.load(audio_path)

    # b. Resample if necessary
    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)

    # c. Ensure stereo
    num_channels = waveform.shape[0]
    if num_channels == 1:
        waveform = waveform.repeat(2, 1)
    elif num_channels > 2:
        # Multichannel input: mix down to mono, then duplicate to two
        # channels, mirroring how mono input is handled above.
        waveform = waveform.mean(dim=0, keepdim=True).repeat(2, 1)
    # Add a leading batch dimension and move to the model's device.
    waveform = waveform.unsqueeze(0).to(device)

    # d. Perform source separation
    with torch.no_grad():
        separated_sources = model(waveform)

    # e. Extract and save the drum track
    # Prefer the stem order the model itself reports; fall back to the
    # previously hard-coded order when the attribute is missing.
    source_names = list(
        getattr(model, "sources", ["drums", "bass", "other", "vocals"])
    )
    drums_index = source_names.index("drums")
    drums_waveform = separated_sources[0, drums_index].cpu()

    # Reserve a unique file name; a fixed name in the working directory would
    # be clobbered when two requests run at the same time.
    with tempfile.NamedTemporaryFile(
        prefix="drums_", suffix=".mp3", delete=False
    ) as tmp:
        output_path = tmp.name
    torchaudio.save(output_path, drums_waveform, sample_rate)
    return output_path
# --- 2. Create the Gradio Interface ---
# Layout: a Markdown header that reports whether torch can see CUDA, followed
# by an upload -> playback interface wired to separate_drums.
with gr.Blocks() as iface:
    gr.Markdown(
        f"""
        # 🥁 Drum Separator
        Upload an audio file to isolate the drum track.
        CUDA Enabled: **{"YES" if torch.cuda.is_available() else "NO"}**
        """
    )
    # The components are created inside the Blocks context in the same order
    # as before (input first, then output) and handed to the Interface.
    song_input = gr.Audio(type="filepath", label="Upload Your Song")
    drums_output = gr.Audio(label="Isolated Drum Track")
    gr.Interface(fn=separate_drums, inputs=song_input, outputs=drums_output)

# --- 3. Launch the App ---
iface.launch()