# Swara-Split / app.py
# Author: Tachyeon — commit 5cf74f9 "Update app.py" (verified)
import warnings
import zipfile

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
# ================= MODEL (unchanged) =================
# Local model code. Each import is guarded separately, and a failure is reported
# with its cause, so the root problem is visible instead of only a later
# NameError when the missing name is first used.
try:
    from bs_roformer import BSRoformer
except ImportError as exc:
    # load_model() below constructs BSRoformer, so it cannot succeed without this.
    warnings.warn(
        f"bs_roformer import failed ({exc}); BSRoformer is unavailable "
        "and the model cannot be built.",
        RuntimeWarning,
        stacklevel=1,
    )
try:
    from attend import Attend
except ImportError as exc:
    # Only the optional Attend.forward monkey-patch further down needs this.
    warnings.warn(
        f"attend import failed ({exc}); the Attend.forward patch will be skipped.",
        RuntimeWarning,
        stacklevel=1,
    )
# Prefer the CUDA GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def safe_attend_forward(self, q, k, v, mask=None):
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
# Point Attend (for every instance, since this sets the class attribute) at
# safe_attend_forward, i.e. plain scaled_dot_product_attention. This is
# best-effort: if it cannot be applied (e.g. the `attend` import failed and
# `Attend` is undefined), the app continues, but the failure is reported rather
# than silently swallowed.
try:
    Attend.forward = safe_attend_forward
except Exception as exc:  # deliberate best-effort patch: report, don't crash
    warnings.warn(
        f"Attend.forward monkey-patch not applied: {exc!r}",
        RuntimeWarning,
        stacklevel=1,
    )
def load_model():
    """Fetch the pretrained weights and build the separation network.

    Downloads ``v11_consensus_epoch_30.pt`` from the Hugging Face repo
    ``Tachyeon/IAM-RoFormer-Model-Weights`` via ``hf_hub_download``. It then
    builds a stereo ``BSRoformer`` with four output stems on ``DEVICE``, loads
    the checkpoint weights into it, and switches it to eval mode.

    Returns:
        The ``BSRoformer`` instance, in eval mode.
    """
    weights_repo = "Tachyeon/IAM-RoFormer-Model-Weights"
    weights_file = "v11_consensus_epoch_30.pt"
    checkpoint_path = hf_hub_download(repo_id=weights_repo, filename=weights_file)

    # These hyper-parameters must match the ones the checkpoint was trained
    # with; otherwise load_state_dict() rejects the weights.
    network = BSRoformer(
        dim=512,
        depth=12,
        stereo=True,
        num_stems=4,
        time_transformer_depth=1,
        freq_transformer_depth=1,
        flash_attn=True,
    )
    network = network.to(DEVICE)

    # NOTE(review): torch.load deserializes a pickle-based file; it is only as
    # trustworthy as the fixed repo above — consider weights_only=True if the
    # checkpoint format allows it.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    # The weights may be nested under a "model" key (training checkpoint);
    # otherwise the loaded object itself is used as the state dict.
    if "model" in checkpoint:
        state_dict = checkpoint["model"]
    else:
        state_dict = checkpoint
    network.load_state_dict(state_dict)
    network.eval()
    return network


# Built once at import time; separate_audio() uses this module-level instance.
model = load_model()
# ================= SEPARATION + ZIP =================
# Output WAV names; file i receives model output stem i.
# NOTE(review): the instrument named in each label is taken from these names and
# the UI captions — confirm against the checkpoint's training stem order.
_STEM_FILES = (
    "lead_vocals.wav",
    "mridangam_percussion.wav",
    "tanpura_drone.wav",
    "violin_accompaniment.wav",
)
_SAMPLE_RATE = 44100     # input is resampled to this rate before inference
_CHUNK_SECONDS = 10      # length of each inference window
_OVERLAP_SECONDS = 1     # overlap between consecutive windows


def _load_stereo(path):
    """Load *path* at 44.1 kHz as a 2-row (channels, samples) array plus its rate."""
    mix, sr = librosa.load(path, sr=_SAMPLE_RATE, mono=False)
    if mix.ndim == 1:
        # Mono file: duplicate the channel so the stereo model gets two inputs.
        mix = np.stack([mix, mix])
    elif mix.shape[0] > 2:
        # More than two channels: the model is built with stereo=True, so keep
        # only the first two rather than crashing inside the model.
        mix = mix[:2]
    return mix, sr


def _separate_chunked(x):
    """Run the global ``model`` over *x* in overlapping windows and average.

    Args:
        x: Tensor of shape (1, 2, samples) on ``DEVICE``.

    Returns:
        Tensor of shape (1, len(_STEM_FILES), 2, samples) on ``DEVICE``. Each
        sample is the mean of the predictions from every window that covers it.
    """
    length = x.shape[-1]
    chunk = _SAMPLE_RATE * _CHUNK_SECONDS
    hop = chunk - _SAMPLE_RATE * _OVERLAP_SECONDS
    out = torch.zeros(1, len(_STEM_FILES), 2, length, device=DEVICE)
    # One counter per sample is enough: every stem and channel sees exactly the
    # same windows, and the 1-D count broadcasts over the last axis below.
    cnt = torch.zeros(length, device=DEVICE)
    with torch.no_grad(), torch.autocast("cuda", enabled=torch.cuda.is_available()):
        for start in range(0, length, hop):
            end = min(start + chunk, length)
            part = x[:, :, start:end]
            if part.shape[-1] < chunk:
                # Zero-pad the tail window so the model always gets `chunk` samples.
                part = F.pad(part, (0, chunk - part.shape[-1]))
            pred = model(part)
            out[:, :, :, start:end] += pred[:, :, :, : end - start]
            cnt[start:end] += 1
            if end == length:
                # Any later window would lie entirely inside this one, so it
                # would only cost another model call; stop here.
                break
    return out / cnt.clamp(min=1)


def _write_stems(stems, sr):
    """Write each stem WAV and bundle them into ``stems.zip``; return the paths.

    NOTE(review): files go to fixed names in the working directory and are
    overwritten on every call; concurrent requests could clash.
    """
    for stem, name in zip(stems, _STEM_FILES):
        # Stems are (channels, samples); soundfile expects (samples, channels).
        sf.write(name, stem.T, sr)
    zip_path = "stems.zip"
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for name in _STEM_FILES:
            archive.write(name)
    return list(_STEM_FILES), zip_path


def separate_audio(path):
    """Split the track at *path* into four stems and bundle them as a zip.

    Args:
        path: Filepath from the input ``gr.Audio`` component; may be empty/None.

    Returns:
        A 6-item list, one item per output wired in the UI: four stem WAV
        paths, the zip path, and a ``gr.update`` that sets the download button's
        visibility.
    """
    if not path:
        # Nothing uploaded: clear all five file outputs and keep the download
        # button hidden. This must be 6 values, one per wired output component.
        return [None] * 5 + [gr.update(visible=False)]
    mix, sr = _load_stereo(path)
    x = torch.tensor(mix).float().to(DEVICE)[None]
    stems = _separate_chunked(x).cpu().numpy()[0]
    names, zip_path = _write_stems(stems, sr)
    return [*names, zip_path, gr.update(visible=True)]
# ================= UI (VISUALLY UNCHANGED) =================
css = """
@import url('https://fonts.googleapis.com/css2?family=Anton&family=Poppins:wght@400;500;600&display=swap');
:root{
--bg1:#2a141d;
--bg2:#14080d;
--ink:#f3ece6;
--muted:#b6aeb0;
--accent:#ff6f9f;
--panel: rgba(255,255,255,0.03);
--panel-2: rgba(255,255,255,0.02);
--radius: 12px;
}
html, body, .gradio-container {
height: 100%;
margin: 0;
background: linear-gradient(180deg, var(--bg1), var(--bg2)) !important;
color: var(--ink);
font-family: Poppins, sans-serif;
}
.app {
max-width: 1160px;
margin: 0 auto;
padding: 48px 40px;
display: grid;
grid-template-rows: auto 1fr;
gap: 36px;
}
.brand { display:flex; flex-direction:column; gap:6px; }
.logo { font-family: Anton, sans-serif; font-size:46px; }
.tagline { font-size:14px; color:var(--accent); opacity:0.9; }
.main {
display:grid;
grid-template-columns: 1fr 420px;
gap: 48px;
align-items: start;
}
.left h3 { margin: 0; font-size:18px; font-weight:600; }
.left p { margin:6px 0 18px; color:var(--muted); font-size:13px; }
.left .gradio-audio {
background: var(--panel) !important;
border-radius: var(--radius);
min-height: 260px;
display:flex;
align-items:center;
justify-content:center;
}
.button-primary {
margin-top: 18px;
height:46px;
width:100%;
font-size:15px !important;
font-weight:600 !important;
background: linear-gradient(90deg,#ff6f9f,#ffbf7a) !important;
color: #14080d !important;
border-radius: 10px !important;
}
.stems {
display:grid;
grid-template-columns: 1fr 1fr;
gap: 22px;
}
.stem-surface {
background: var(--panel-2);
border-radius: 14px;
padding: 12px;
min-height: 140px;
display:flex;
flex-direction:column;
justify-content:center;
gap:6px;
}
.stem-label {
font-size:13px;
font-weight:600;
color: var(--accent);
}
.stem-info {
font-size:11px;
color:var(--muted);
opacity:0.85;
}
.stem-surface .gradio-audio label { display:none !important; }
.stem-surface audio { width:92%; max-height:36px; }
"""
# Two-column layout: upload + actions on the left, a 2x2 grid of stem players
# on the right (the grid itself comes from the .stems CSS rule).
with gr.Blocks() as demo:
    with gr.Column(elem_classes="app"):
        with gr.Row(elem_classes="brand"):
            gr.HTML('<div class="logo">SWARA STUDIO</div>')
            gr.HTML('<div class="tagline">Separating Music Into Its elements</div>')
        with gr.Row(elem_classes="main"):
            # LEFT: track upload and actions.
            with gr.Column(elem_classes="left"):
                gr.HTML("""
<h3>Select a track</h3>
<p>We’ll break it down into individual parts</p>
""")
                input_audio = gr.Audio(type="filepath")
                run_btn = gr.Button("Separate", elem_classes="button-primary")
                # Both start hidden: separate_audio() reveals the button, and
                # the button reveals the zip file below.
                download_btn = gr.Button("Download all stems", visible=False)
                zip_out = gr.File(visible=False)
            # RIGHT: one player per separated stem. The variable names are kept
            # as-is, even though out_drums/out_bass/out_other hold mridangam,
            # tanpura and violin.
            with gr.Column():
                with gr.Row(elem_classes="stems"):
                    with gr.Column(elem_classes="stem-surface"):
                        gr.HTML("""
<div class="stem-label">Lead Vocals</div>
<div class="stem-info">Primary melodic voice and lyrical content</div>
""")
                        out_vocals = gr.Audio(interactive=False)
                    with gr.Column(elem_classes="stem-surface"):
                        gr.HTML("""
<div class="stem-label">Mridangam / Percussion</div>
<div class="stem-info">Rhythmic transients and percussive articulation</div>
""")
                        out_drums = gr.Audio(interactive=False)
                    with gr.Column(elem_classes="stem-surface"):
                        gr.HTML("""
<div class="stem-label">Tanpura / Drone</div>
<div class="stem-info">Sustained harmonic bed and tonal reference</div>
""")
                        out_bass = gr.Audio(interactive=False)
                    with gr.Column(elem_classes="stem-surface"):
                        gr.HTML("""
<div class="stem-label">Violin / Accompaniment</div>
<div class="stem-info">Melodic support and expressive ornamentation</div>
""")
                        out_other = gr.Audio(interactive=False)
    # separate_audio returns one value per output: 4 stems, the zip path and a
    # visibility update for the download button.
    run_btn.click(
        separate_audio,
        input_audio,
        [out_vocals, out_drums, out_bass, out_other, zip_out, download_btn]
    )
    # zip_out was created hidden, and separate_audio only sets its value, so
    # echoing the value back (the old `lambda z: z`) left the file invisible and
    # the button did nothing visible. Reveal the file so it can be downloaded.
    download_btn.click(
        lambda z: gr.update(value=z, visible=True),
        zip_out,
        zip_out,
    )
if __name__ == "__main__":
demo.launch(css=css, theme=gr.themes.Base())