Spaces:
Running
Running
import warnings
import zipfile

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
| # ================= MODEL (unchanged) ================= | |
# BSRoformer is required: load_model() below cannot build the separator
# without it, so let a missing module fail here with its real ImportError
# instead of a NameError later.
from bs_roformer import BSRoformer

# Attend is only needed for the optional attention patch; tolerate its absence.
try:
    from attend import Attend
except ImportError:
    Attend = None
# Use the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def safe_attend_forward(self, q, k, v, mask=None):
    """Compute attention with PyTorch's fused scaled-dot-product kernel.

    Intended to be installed as ``Attend.forward``. ``self`` is never read.
    ``mask`` is forwarded unchanged as ``attn_mask``, and dropout is always
    disabled.
    """
    no_dropout = 0.0
    return F.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=mask,
        dropout_p=no_dropout,
    )
# Best-effort: route attention through safe_attend_forward (PyTorch SDPA) by
# replacing Attend.forward. If Attend is unavailable, keep the upstream
# attention, but say so instead of failing silently.
# NOTE(review): this only affects BSRoformer if bs_roformer uses this same
# `attend.Attend` class object -- confirm against the bs_roformer sources.
try:
    Attend.forward = safe_attend_forward
except (NameError, AttributeError, TypeError) as err:
    warnings.warn(
        f"Attend.forward was not patched; using upstream attention ({err!r})"
    )
def load_model(
    *,
    repo_id="Tachyeon/IAM-RoFormer-Model-Weights",
    filename="v11_consensus_epoch_30.pt",
    device=None,
):
    """Download the checkpoint from the Hugging Face Hub and build the separator.

    Args:
        repo_id: Hub repository that holds the weights.
        filename: Checkpoint file inside ``repo_id``.
        device: Device to place the model on; ``None`` means the module-level
            ``DEVICE``.

    Returns:
        A 4-stem stereo ``BSRoformer`` with the checkpoint weights, in eval mode,
        on ``device``.
    """
    if device is None:
        device = DEVICE
    ckpt = hf_hub_download(repo_id=repo_id, filename=filename)
    model = BSRoformer(
        dim=512, depth=12, stereo=True, num_stems=4,
        time_transformer_depth=1, freq_transformer_depth=1,
        flash_attn=True,
    ).to(device)
    # Deserialize on the CPU: load_state_dict() copies the tensors into the
    # parameters that already live on `device`, so this avoids briefly holding
    # a second full copy of the weights in GPU memory.
    # NOTE(review): torch.load unpickles the file. It comes from a fixed Hub
    # repo; consider weights_only=True if the installed torch supports it.
    state = torch.load(ckpt, map_location="cpu")
    # Accept either {"model": state_dict, ...} or a bare state_dict.
    model.load_state_dict(state["model"] if "model" in state else state)
    model.eval()
    return model


model = load_model()
| # ================= SEPARATION + ZIP ================= | |
def separate_audio(path):
    """Separate an uploaded track into four stems and bundle them into a zip.

    The track is loaded at 44.1 kHz and forced to two channels. It is then run
    through the module-level ``model`` in 10-second windows that overlap by
    1 second, and predictions from overlapping windows are averaged.

    Args:
        path: Filepath from the ``gr.Audio`` input; falsy when nothing was
            uploaded.

    Returns:
        Six values, one per output wired to ``run_btn.click``: the four stem WAV
        paths, the zip path, and a visibility update for the download button.
        With no input, every file slot is ``None`` and the button is hidden.
    """
    if not path:
        # One value per output component (4 stems + zip + button = 6);
        # returning fewer makes Gradio reject the response.
        return [None] * 5 + [gr.update(visible=False)]

    mix, sr = librosa.load(path, sr=44100, mono=False)
    if mix.ndim == 1:
        # Mono: duplicate the channel so the stereo model gets two inputs.
        mix = np.stack([mix, mix])
    elif mix.shape[0] > 2:
        # More than two channels: keep the first two for the stereo model.
        mix = mix[:2]

    x = torch.as_tensor(mix, dtype=torch.float32, device=DEVICE)[None]  # (1, 2, L)
    length = x.shape[-1]
    chunk = sr * 10   # window length: 10 s
    hop = chunk - sr  # step: 9 s, i.e. 1 s of overlap between windows
    out = torch.zeros(1, 4, 2, length, device=DEVICE)
    # Per-sample window-coverage count. It is identical for every stem and
    # channel, so a 1-D vector broadcasts in the division below instead of a
    # second full-size (1, 4, 2, L) buffer.
    cnt = torch.zeros(length, device=DEVICE)

    with torch.no_grad(), torch.autocast("cuda", enabled=torch.cuda.is_available()):
        for start in range(0, length, hop):
            end = min(start + chunk, length)
            part = x[:, :, start:end]
            if part.shape[-1] < chunk:
                # Zero-pad a short final window up to the fixed input length.
                part = F.pad(part, (0, chunk - part.shape[-1]))
            pred = model(part)  # assumes (1, 4, 2, chunk), matching `out`
            out[:, :, :, start:end] += pred[:, :, :, : end - start]
            cnt[start:end] += 1
            if end == length:
                # Whole signal covered; further windows would be mostly padding.
                break

    stems = (out / cnt.clamp(min=1)).cpu().numpy()[0]

    names = [
        "lead_vocals.wav",
        "mridangam_percussion.wav",
        "tanpura_drone.wav",
        "violin_accompaniment.wav"
    ]
    for name, stem in zip(names, stems):
        # soundfile expects (frames, channels); stems are (channels, frames).
        sf.write(name, stem.T, sr)

    zip_path = "stems.zip"
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as z:
        for name in names:
            z.write(name)

    # Stems, zip, and a signal to show the download button.
    return [*names, zip_path, gr.update(visible=True)]
| # ================= UI (VISUALLY UNCHANGED) ================= | |
# Custom stylesheet passed to demo.launch(css=...) in the __main__ guard.
# - @import loads the Anton (used by .logo) and Poppins (page font) families
#   from Google Fonts.
# - :root defines the palette and radius as CSS variables reused below.
# - .app, .brand, .main, .left, .button-primary, .stems and .stem-surface
#   match the elem_classes set on the layout components. .logo, .tagline,
#   .stem-label and .stem-info style the markup inside the gr.HTML blocks.
# - .main is a two-column grid (flexible left column, fixed 420px right
#   column). .stems lays the four stem panels out two per row.
# NOTE(review): .gradio-container and .gradio-audio target Gradio-generated
# DOM classes; confirm they exist in the installed Gradio version.
css = """
@import url('https://fonts.googleapis.com/css2?family=Anton&family=Poppins:wght@400;500;600&display=swap');
:root{
--bg1:#2a141d;
--bg2:#14080d;
--ink:#f3ece6;
--muted:#b6aeb0;
--accent:#ff6f9f;
--panel: rgba(255,255,255,0.03);
--panel-2: rgba(255,255,255,0.02);
--radius: 12px;
}
html, body, .gradio-container {
height: 100%;
margin: 0;
background: linear-gradient(180deg, var(--bg1), var(--bg2)) !important;
color: var(--ink);
font-family: Poppins, sans-serif;
}
.app {
max-width: 1160px;
margin: 0 auto;
padding: 48px 40px;
display: grid;
grid-template-rows: auto 1fr;
gap: 36px;
}
.brand { display:flex; flex-direction:column; gap:6px; }
.logo { font-family: Anton, sans-serif; font-size:46px; }
.tagline { font-size:14px; color:var(--accent); opacity:0.9; }
.main {
display:grid;
grid-template-columns: 1fr 420px;
gap: 48px;
align-items: start;
}
.left h3 { margin: 0; font-size:18px; font-weight:600; }
.left p { margin:6px 0 18px; color:var(--muted); font-size:13px; }
.left .gradio-audio {
background: var(--panel) !important;
border-radius: var(--radius);
min-height: 260px;
display:flex;
align-items:center;
justify-content:center;
}
.button-primary {
margin-top: 18px;
height:46px;
width:100%;
font-size:15px !important;
font-weight:600 !important;
background: linear-gradient(90deg,#ff6f9f,#ffbf7a) !important;
color: #14080d !important;
border-radius: 10px !important;
}
.stems {
display:grid;
grid-template-columns: 1fr 1fr;
gap: 22px;
}
.stem-surface {
background: var(--panel-2);
border-radius: 14px;
padding: 12px;
min-height: 140px;
display:flex;
flex-direction:column;
justify-content:center;
gap:6px;
}
.stem-label {
font-size:13px;
font-weight:600;
color: var(--accent);
}
.stem-info {
font-size:11px;
color:var(--muted);
opacity:0.85;
}
.stem-surface .gradio-audio label { display:none !important; }
.stem-surface audio { width:92%; max-height:36px; }
"""
# Page layout: a header row, then a two-column "main" row. The left column
# holds the input track and actions; the right holds one player per stem.
# The elem_classes values are the hooks used by the `css` stylesheet.
with gr.Blocks() as demo:
    with gr.Column(elem_classes="app"):
        with gr.Row(elem_classes="brand"):
            gr.HTML('<div class="logo">SWARA STUDIO</div>')
            gr.HTML('<div class="tagline">Separating Music Into Its elements</div>')
        with gr.Row(elem_classes="main"):
            # LEFT: track input and actions.
            with gr.Column(elem_classes="left"):
                gr.HTML("""
<h3>Select a track</h3>
<p>We’ll break it down into individual parts</p>
""")
                input_audio = gr.Audio(type="filepath")
                run_btn = gr.Button("Separate", elem_classes="button-primary")
                # Hidden until separate_audio() returns a finished zip.
                download_btn = gr.Button("Download all stems", visible=False)
                # Holds the zip path; created hidden and revealed by download_btn.
                zip_out = gr.File(visible=False)
            # RIGHT: one player per separated stem.
            with gr.Column():
                with gr.Row(elem_classes="stems"):
                    with gr.Column(elem_classes="stem-surface"):
                        gr.HTML("""
<div class="stem-label">Lead Vocals</div>
<div class="stem-info">Primary melodic voice and lyrical content</div>
""")
                        out_vocals = gr.Audio(interactive=False)
                    with gr.Column(elem_classes="stem-surface"):
                        gr.HTML("""
<div class="stem-label">Mridangam / Percussion</div>
<div class="stem-info">Rhythmic transients and percussive articulation</div>
""")
                        out_drums = gr.Audio(interactive=False)
                    with gr.Column(elem_classes="stem-surface"):
                        gr.HTML("""
<div class="stem-label">Tanpura / Drone</div>
<div class="stem-info">Sustained harmonic bed and tonal reference</div>
""")
                        out_bass = gr.Audio(interactive=False)
                    with gr.Column(elem_classes="stem-surface"):
                        gr.HTML("""
<div class="stem-label">Violin / Accompaniment</div>
<div class="stem-info">Melodic support and expressive ornamentation</div>
""")
                        out_other = gr.Audio(interactive=False)

    # separate_audio returns one value per output listed here (6 in total).
    run_btn.click(
        separate_audio,
        input_audio,
        [out_vocals, out_drums, out_bass, out_other, zip_out, download_btn],
    )
    # zip_out was created with visible=False. Echoing its value back alone
    # leaves it hidden, so the handler must also make it visible for the user
    # to reach the zip.
    download_btn.click(
        lambda z: gr.update(value=z, visible=True),
        zip_out,
        zip_out,
    )
def _launch_app():
    """Start the Gradio server, passing the custom stylesheet and Base theme to launch()."""
    demo.launch(css=css, theme=gr.themes.Base())


if __name__ == "__main__":
    _launch_app()