import os import tempfile import zipfile import gradio as gr import torch import numpy as np import wave from demucs.pretrained import get_model from demucs.audio import AudioFile from demucs.apply import apply_model # Configuration MODEL_NAME = os.environ.get("DEMUCS_MODEL", "htdemucs") TARGET_SAMPLE_RATE = 44100 TARGET_NUM_CHANNELS = 2 MAX_DURATION_SECONDS = int(os.environ.get("MAX_DURATION_SECONDS", "30")) # Global model cache _loaded_model = None _inference_device = "cuda" if torch.cuda.is_available() else "cpu" def load_demucs_model(): """Load and cache the Demucs model""" global _loaded_model if _loaded_model is None: print(f"Loading model {MODEL_NAME} on device {_inference_device}") model = get_model(MODEL_NAME) model.to(_inference_device) model.eval() _loaded_model = model return _loaded_model def separate_stems(audio_file): """ Separate audio into stems using Demucs Args: audio_file: Gradio audio input (tuple of sample_rate, audio_data) Returns: List of separated stem audio files """ if audio_file is None: return "Please upload an audio file first." try: # Load model model = load_demucs_model() # Extract audio data and sample rate from Gradio input sample_rate, audio_data = audio_file # Convert to torch tensor and normalize if len(audio_data.shape) == 1: # Mono to stereo audio_data = np.stack([audio_data, audio_data]) elif len(audio_data.shape) == 2 and audio_data.shape[0] > audio_data.shape[1]: # Transpose if needed (samples, channels) -> (channels, samples) audio_data = audio_data.T # Resample to target sample rate if needed if sample_rate != TARGET_SAMPLE_RATE: # Simple resampling - in production you'd want proper resampling ratio = TARGET_SAMPLE_RATE / sample_rate new_length = int(audio_data.shape[1] * ratio) audio_data = np.array([np.interp(np.linspace(0, len(channel), new_length), np.arange(len(channel)), channel) for channel in audio_data]) # Limit duration max_samples = TARGET_SAMPLE_RATE * MAX_DURATION_SECONDS if audio_data.shape[1] > max_samples: audio_data = audio_data[:, :max_samples] # Normalize audio reference_channel = audio_data.mean(0) audio_data = (audio_data - reference_channel.mean()) / (reference_channel.std() + 1e-8) # Convert to tensor audio_tensor = torch.tensor(audio_data, dtype=torch.float32, device=_inference_device) audio_tensor = audio_tensor.unsqueeze(0) # Add batch dimension # Separate stems with torch.no_grad(): separated_sources = apply_model( model, audio_tensor, split=True, overlap=0.10, shifts=0, )[0].to("cpu") # Get source names source_names = getattr(model, "sources", ["drums", "bass", "other", "vocals"]) # Create temporary directory for output files with tempfile.TemporaryDirectory() as tmp_dir: output_files = [] for source_index, source_name in enumerate(source_names): stem_tensor = separated_sources[source_index] # De-normalize stem_tensor = stem_tensor * (reference_channel.std() + 1e-8) + reference_channel.mean() stem_np = stem_tensor.transpose(0, 1).numpy() # [samples, channels] # Clip and convert to int16 stem_np = np.clip(stem_np, -1.0, 1.0) # Save as audio file that Gradio can handle output_path = os.path.join(tmp_dir, f"{source_name}.wav") with wave.open(output_path, "wb") as wf: wf.setnchannels(TARGET_NUM_CHANNELS) wf.setsampwidth(2) # 16-bit wf.setframerate(TARGET_SAMPLE_RATE) pcm = (stem_np * 32767.0).astype(np.int16) wf.writeframes(pcm.tobytes()) # Copy to a permanent location for Gradio import shutil permanent_path = f"/tmp/{source_name}.wav" shutil.copy2(output_path, permanent_path) output_files.append(permanent_path) return output_files except Exception as e: return f"Error during separation: {str(e)}" def create_interface(): """Create the Gradio interface""" # Custom CSS for better styling css = """ .gradio-container { max-width: 800px; margin: auto; } .header { text-align: center; margin-bottom: 2rem; } .footer { text-align: center; margin-top: 2rem; color: #666; } """ with gr.Blocks(css=css, title="RiffRaff - AI Music Stem Separation") as demo: gr.HTML("""
Upload a song and separate it into individual stems (drums, bass, vocals, other)
Powered by Facebook's Demucs AI model