import os import tempfile import zipfile import gradio as gr import torch import numpy as np import wave from demucs.pretrained import get_model from demucs.audio import AudioFile from demucs.apply import apply_model # Configuration MODEL_NAME = os.environ.get("DEMUCS_MODEL", "htdemucs") TARGET_SAMPLE_RATE = 44100 TARGET_NUM_CHANNELS = 2 MAX_DURATION_SECONDS = int(os.environ.get("MAX_DURATION_SECONDS", "30")) # Global model cache _loaded_model = None _inference_device = "cuda" if torch.cuda.is_available() else "cpu" def load_demucs_model(): """Load and cache the Demucs model""" global _loaded_model if _loaded_model is None: print(f"Loading model {MODEL_NAME} on device {_inference_device}") model = get_model(MODEL_NAME) model.to(_inference_device) model.eval() _loaded_model = model return _loaded_model def separate_stems(audio_file): """ Separate audio into stems using Demucs Args: audio_file: Gradio audio input (tuple of sample_rate, audio_data) Returns: List of separated stem audio files """ if audio_file is None: return "Please upload an audio file first." try: # Load model model = load_demucs_model() # Extract audio data and sample rate from Gradio input sample_rate, audio_data = audio_file # Convert to torch tensor and normalize if len(audio_data.shape) == 1: # Mono to stereo audio_data = np.stack([audio_data, audio_data]) elif len(audio_data.shape) == 2 and audio_data.shape[0] > audio_data.shape[1]: # Transpose if needed (samples, channels) -> (channels, samples) audio_data = audio_data.T # Resample to target sample rate if needed if sample_rate != TARGET_SAMPLE_RATE: # Simple resampling - in production you'd want proper resampling ratio = TARGET_SAMPLE_RATE / sample_rate new_length = int(audio_data.shape[1] * ratio) audio_data = np.array([np.interp(np.linspace(0, len(channel), new_length), np.arange(len(channel)), channel) for channel in audio_data]) # Limit duration max_samples = TARGET_SAMPLE_RATE * MAX_DURATION_SECONDS if audio_data.shape[1] > max_samples: audio_data = audio_data[:, :max_samples] # Normalize audio reference_channel = audio_data.mean(0) audio_data = (audio_data - reference_channel.mean()) / (reference_channel.std() + 1e-8) # Convert to tensor audio_tensor = torch.tensor(audio_data, dtype=torch.float32, device=_inference_device) audio_tensor = audio_tensor.unsqueeze(0) # Add batch dimension # Separate stems with torch.no_grad(): separated_sources = apply_model( model, audio_tensor, split=True, overlap=0.10, shifts=0, )[0].to("cpu") # Get source names source_names = getattr(model, "sources", ["drums", "bass", "other", "vocals"]) # Create temporary directory for output files with tempfile.TemporaryDirectory() as tmp_dir: output_files = [] for source_index, source_name in enumerate(source_names): stem_tensor = separated_sources[source_index] # De-normalize stem_tensor = stem_tensor * (reference_channel.std() + 1e-8) + reference_channel.mean() stem_np = stem_tensor.transpose(0, 1).numpy() # [samples, channels] # Clip and convert to int16 stem_np = np.clip(stem_np, -1.0, 1.0) # Save as audio file that Gradio can handle output_path = os.path.join(tmp_dir, f"{source_name}.wav") with wave.open(output_path, "wb") as wf: wf.setnchannels(TARGET_NUM_CHANNELS) wf.setsampwidth(2) # 16-bit wf.setframerate(TARGET_SAMPLE_RATE) pcm = (stem_np * 32767.0).astype(np.int16) wf.writeframes(pcm.tobytes()) # Copy to a permanent location for Gradio import shutil permanent_path = f"/tmp/{source_name}.wav" shutil.copy2(output_path, permanent_path) output_files.append(permanent_path) return output_files except Exception as e: return f"Error during separation: {str(e)}" def create_interface(): """Create the Gradio interface""" # Custom CSS for better styling css = """ .gradio-container { max-width: 800px; margin: auto; } .header { text-align: center; margin-bottom: 2rem; } .footer { text-align: center; margin-top: 2rem; color: #666; } """ with gr.Blocks(css=css, title="RiffRaff - AI Music Stem Separation") as demo: gr.HTML("""

đŸŽĩ RiffRaff - AI Music Stem Separation

Upload a song and separate it into individual stems (drums, bass, vocals, other)

Powered by Facebook's Demucs AI model

""") with gr.Row(): with gr.Column(): audio_input = gr.Audio( label="Upload Audio File", type="numpy", format="wav" ) separate_btn = gr.Button( "đŸŽ›ī¸ Separate Stems", variant="primary", size="lg" ) gr.HTML(f"""
â„šī¸ Info:
""") with gr.Row(): with gr.Column(): gr.HTML("

đŸŽĩ Separated Stems

") drums_output = gr.Audio(label="đŸĨ Drums", interactive=False) bass_output = gr.Audio(label="🎸 Bass", interactive=False) vocals_output = gr.Audio(label="🎤 Vocals", interactive=False) other_output = gr.Audio(label="🎹 Other", interactive=False) def process_and_display(audio_file): """Process audio and return individual stems""" if audio_file is None: return None, None, None, None result = separate_stems(audio_file) if isinstance(result, str): # Error message gr.Warning(result) return None, None, None, None # Return the four stems in order: drums, bass, other, vocals # (matching the typical Demucs output order) stems = [None, None, None, None] for i, path in enumerate(result): if i < 4: stems[i] = path return stems[0], stems[1], stems[3], stems[2] # drums, bass, vocals, other separate_btn.click( fn=process_and_display, inputs=[audio_input], outputs=[drums_output, bass_output, vocals_output, other_output], show_progress=True ) gr.HTML(""" """) return demo # Preload model when the module is imported print("Preloading Demucs model...") try: load_demucs_model() print("Model loaded successfully!") except Exception as e: print(f"Warning: Could not preload model: {e}") # Create and launch the interface if __name__ == "__main__": demo = create_interface() demo.launch( server_name="0.0.0.0", server_port=7860, share=True )