"""Gradio demo that splits a music track into four stems with a BS-RoFormer model."""

import os
import tempfile
import zipfile

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

# ================= MODEL =================

# BSRoformer is required; Attend is only needed for the optional attention
# patch below. A missing module is recorded here and reported where it matters
# instead of being silently swallowed.
try:
    from bs_roformer import BSRoformer
except ImportError:
    BSRoformer = None

try:
    from attend import Attend
except ImportError:
    Attend = None

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

SAMPLE_RATE = 44100
CHUNK_SAMPLES = SAMPLE_RATE * 10  # the model is run on 10-second windows
HOP_SAMPLES = CHUNK_SAMPLES - SAMPLE_RATE  # consecutive windows overlap by 1 second

# Output file names, in the order the model's four stems are written.
STEM_FILENAMES = [
    "lead_vocals.wav",
    "mridangam_percussion.wav",
    "tanpura_drone.wav",
    "violin_accompaniment.wav",
]


def safe_attend_forward(self, q, k, v, mask=None):
    """Replacement for ``Attend.forward``.

    Delegates to ``F.scaled_dot_product_attention`` with dropout disabled,
    passing ``mask`` through as ``attn_mask``.
    """
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)


# Best-effort patch: applied only when the `attend` module could be imported.
if Attend is not None:
    Attend.forward = safe_attend_forward


def load_model():
    """Download the checkpoint from the Hugging Face Hub and build the model.

    Returns:
        The ``BSRoformer`` instance, moved to ``DEVICE`` and set to eval mode.

    Raises:
        ImportError: if the ``bs_roformer`` package could not be imported.
    """
    if BSRoformer is None:
        raise ImportError(
            "The 'bs_roformer' package is required to build the separation model."
        )

    ckpt = hf_hub_download(
        repo_id="Tachyeon/IAM-RoFormer-Model-Weights",
        filename="v11_consensus_epoch_30.pt",
    )

    model = BSRoformer(
        dim=512,
        depth=12,
        stereo=True,
        num_stems=4,
        time_transformer_depth=1,
        freq_transformer_depth=1,
        flash_attn=True,
    ).to(DEVICE)

    # NOTE(review): torch.load unpickles the downloaded file; only use
    # checkpoints from a trusted repository (consider weights_only=True if the
    # checkpoint format allows it).
    state = torch.load(ckpt, map_location=DEVICE)
    # The checkpoint is either a bare state dict or wraps it under "model".
    model.load_state_dict(state["model"] if "model" in state else state)
    model.eval()
    return model


model = load_model()

# ================= SEPARATION + ZIP =================


def separate_audio(path):
    """Separate the audio file at ``path`` into four stems and zip them.

    The track is loaded at 44.1 kHz (a mono file is duplicated into two
    channels), run through the model in overlapping 10-second windows, and the
    overlapping predictions are averaged.

    Args:
        path: File path of the uploaded audio, or a falsy value when nothing
            was uploaded.

    Returns:
        A six-item list matching the outputs wired to the "Separate" button:
        four stem WAV paths, the zip path, and a visibility update for the
        download button. With no input, the five file slots are ``None`` and
        the download button is hidden.
    """
    if not path:
        # One value per wired output (six in total); returning fewer makes
        # Gradio reject the handler result.
        return [None] * 5 + [gr.update(visible=False)]

    mix, sr = librosa.load(path, sr=SAMPLE_RATE, mono=False)
    if mix.ndim == 1:
        # librosa returns a 1-D array for mono input; duplicate it to stereo.
        mix = np.stack([mix, mix])

    x = torch.tensor(mix).float().to(DEVICE)[None]
    total = x.shape[-1]

    out = torch.zeros(1, 4, 2, total, device=DEVICE)
    cnt = torch.zeros_like(out)  # how many windows contributed to each sample

    with torch.no_grad(), torch.autocast(DEVICE, enabled=DEVICE == "cuda"):
        for start in range(0, total, HOP_SAMPLES):
            end = min(start + CHUNK_SAMPLES, total)
            part = x[:, :, start:end]
            if part.shape[-1] < CHUNK_SAMPLES:
                # Zero-pad the final window up to the fixed chunk length.
                part = F.pad(part, (0, CHUNK_SAMPLES - part.shape[-1]))

            pred = model(part)
            # Drop the padded tail before accumulating.
            out[:, :, :, start:end] += pred[:, :, :, : end - start]
            cnt[:, :, :, start:end] += 1

    stems = (out / cnt.clamp(min=1)).cpu().numpy()[0]

    # Write into a fresh directory per call so requests never overwrite each
    # other's files and the working directory does not need to be writable.
    out_dir = tempfile.mkdtemp(prefix="stems_")
    stem_paths = [os.path.join(out_dir, name) for name in STEM_FILENAMES]
    for stem_path, audio in zip(stem_paths, stems):
        sf.write(stem_path, audio.T, sr)

    zip_path = os.path.join(out_dir, "stems.zip")
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for stem_path, name in zip(stem_paths, STEM_FILENAMES):
            # arcname keeps plain file names inside the archive.
            archive.write(stem_path, arcname=name)

    # Stems + zip + signal to show the download button.
    return [*stem_paths, zip_path, gr.update(visible=True)]


# ================= UI =================

# The stylesheet is assembled from adjacent literals; its text (including
# whitespace) is unchanged.
css = (
    " @import url('https://fonts.googleapis.com/css2?family=Anton"
    "&family=Poppins:wght@400;500;600&display=swap'); "
    ":root{ --bg1:#2a141d; --bg2:#14080d; --ink:#f3ece6; --muted:#b6aeb0; "
    "--accent:#ff6f9f; --panel: rgba(255,255,255,0.03); "
    "--panel-2: rgba(255,255,255,0.02); --radius: 12px; } "
    "html, body, .gradio-container { height: 100%; margin: 0; "
    "background: linear-gradient(180deg, var(--bg1), var(--bg2)) !important; "
    "color: var(--ink); font-family: Poppins, sans-serif; } "
    ".app { max-width: 1160px; margin: 0 auto; padding: 48px 40px; "
    "display: grid; grid-template-rows: auto 1fr; gap: 36px; } "
    ".brand { display:flex; flex-direction:column; gap:6px; } "
    ".logo { font-family: Anton, sans-serif; font-size:46px; } "
    ".tagline { font-size:14px; color:var(--accent); opacity:0.9; } "
    ".main { display:grid; grid-template-columns: 1fr 420px; gap: 48px; "
    "align-items: start; } "
    ".left h3 { margin: 0; font-size:18px; font-weight:600; } "
    ".left p { margin:6px 0 18px; color:var(--muted); font-size:13px; } "
    ".left .gradio-audio { background: var(--panel) !important; "
    "border-radius: var(--radius); min-height: 260px; display:flex; "
    "align-items:center; justify-content:center; } "
    ".button-primary { margin-top: 18px; height:46px; width:100%; "
    "font-size:15px !important; font-weight:600 !important; "
    "background: linear-gradient(90deg,#ff6f9f,#ffbf7a) !important; "
    "color: #14080d !important; border-radius: 10px !important; } "
    ".stems { display:grid; grid-template-columns: 1fr 1fr; gap: 22px; } "
    ".stem-surface { background: var(--panel-2); border-radius: 14px; "
    "padding: 12px; min-height: 140px; display:flex; flex-direction:column; "
    "justify-content:center; gap:6px; } "
    ".stem-label { font-size:13px; font-weight:600; color: var(--accent); } "
    ".stem-info { font-size:11px; color:var(--muted); opacity:0.85; } \n"
    ".stem-surface .gradio-audio label { display:none !important; } "
    ".stem-surface audio { width:92%; max-height:36px; } "
)


def _stem_panel(title, blurb):
    """Build one stem card (caption HTML + read-only player); return the player."""
    with gr.Column(elem_classes="stem-surface"):
        gr.HTML(f"\n{title}\n{blurb}\n")
        return gr.Audio(interactive=False)


# NOTE(review): the HTML snippets below carry no tags, while the stylesheet
# defines .logo, .tagline, .stem-label, .stem-info, `.left h3` and `.left p`.
# The markup looks like it was stripped somewhere upstream - confirm against
# the original file before relying on those rules.
with gr.Blocks() as demo:
    with gr.Column(elem_classes="app"):
        with gr.Row(elem_classes="brand"):
            gr.HTML('')
            gr.HTML('\nSeparating Music Into Its elements\n')

        with gr.Row(elem_classes="main"):
            # LEFT
            with gr.Column(elem_classes="left"):
                gr.HTML(
                    "\n\nSelect a track\n\n"
                    "We’ll break it down into individual parts\n\n"
                )
                input_audio = gr.Audio(type="filepath")
                run_btn = gr.Button("Separate", elem_classes="button-primary")

                # Download action: the button appears after a separation run,
                # and clicking it reveals the zip file for download.
                download_btn = gr.Button("Download all stems", visible=False)
                zip_out = gr.File(visible=False)

            # RIGHT
            with gr.Column():
                with gr.Row(elem_classes="stems"):
                    out_vocals = _stem_panel(
                        "Lead Vocals",
                        "Primary melodic voice and lyrical content",
                    )
                    out_drums = _stem_panel(
                        "Mridangam / Percussion",
                        "Rhythmic transients and percussive articulation",
                    )
                    out_bass = _stem_panel(
                        "Tanpura / Drone",
                        "Sustained harmonic bed and tonal reference",
                    )
                    out_other = _stem_panel(
                        "Violin / Accompaniment",
                        "Melodic support and expressive ornamentation",
                    )

    run_btn.click(
        separate_audio,
        input_audio,
        [out_vocals, out_drums, out_bass, out_other, zip_out, download_btn],
    )

    # The file component starts hidden, so passing the value straight through
    # would leave nothing to click; make it visible together with the value.
    download_btn.click(
        lambda z: gr.update(value=z, visible=True),
        zip_out,
        zip_out,
    )

if __name__ == "__main__":
    demo.launch(css=css, theme=gr.themes.Base())