Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,9 +6,9 @@ import soundfile as sf
|
|
| 6 |
import numpy as np
|
| 7 |
from huggingface_hub import hf_hub_download
|
| 8 |
|
| 9 |
-
# ==========================================
|
| 10 |
-
# 1.
|
| 11 |
-
# ==========================================
|
| 12 |
try:
|
| 13 |
from bs_roformer import BSRoformer
|
| 14 |
from attend import Attend
|
|
@@ -18,7 +18,9 @@ except ImportError:
|
|
| 18 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 19 |
|
| 20 |
def safe_attend_forward(self, q, k, v, mask=None):
|
| 21 |
-
return F.scaled_dot_product_attention(
|
|
|
|
|
|
|
| 22 |
|
| 23 |
try:
|
| 24 |
Attend.forward = safe_attend_forward
|
|
@@ -27,7 +29,7 @@ except Exception:
|
|
| 27 |
|
| 28 |
def load_model():
|
| 29 |
print("Connecting to model...")
|
| 30 |
-
|
| 31 |
repo_id="Tachyeon/IAM-RoFormer-Model-Weights",
|
| 32 |
filename="v11_consensus_epoch_30.pt"
|
| 33 |
)
|
|
@@ -42,169 +44,215 @@ def load_model():
|
|
| 42 |
flash_attn=True
|
| 43 |
).to(DEVICE)
|
| 44 |
|
| 45 |
-
|
| 46 |
-
model.load_state_dict(
|
| 47 |
model.eval()
|
| 48 |
return model
|
| 49 |
|
| 50 |
model = load_model()
|
| 51 |
|
| 52 |
def separate_audio(audio_path):
|
| 53 |
-
if
|
| 54 |
return [None] * 4
|
| 55 |
|
| 56 |
mix, sr = librosa.load(audio_path, sr=44100, mono=False)
|
| 57 |
if mix.ndim == 1:
|
| 58 |
-
mix = np.stack([mix, mix]
|
| 59 |
|
| 60 |
-
|
| 61 |
overlap = 44100
|
| 62 |
|
| 63 |
-
|
| 64 |
-
length =
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
|
| 69 |
with torch.no_grad(), torch.autocast("cuda", enabled=torch.cuda.is_available()):
|
| 70 |
-
for
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
if
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
stems = (output / count.clamp(min=1)).cpu().numpy()[0]
|
| 82 |
-
|
| 83 |
files = []
|
| 84 |
for i in range(4):
|
| 85 |
-
|
| 86 |
-
sf.write(
|
| 87 |
-
files.append(
|
| 88 |
-
|
| 89 |
return files
|
| 90 |
|
| 91 |
-
# ==========================================
|
| 92 |
-
# 2. UI (
|
| 93 |
-
# ==========================================
|
| 94 |
css = """
|
| 95 |
@import url('https://fonts.googleapis.com/css2?family=Anton&family=Playfair+Display:ital@1&family=Poppins:wght@400;600;700&display=swap');
|
| 96 |
|
| 97 |
-
:root{
|
| 98 |
-
--
|
| 99 |
-
--
|
|
|
|
|
|
|
| 100 |
--ink:#f6efe8;
|
| 101 |
--muted:#c7bfbf;
|
| 102 |
--accent:#ff73a6;
|
| 103 |
}
|
| 104 |
|
|
|
|
| 105 |
html, body, .gradio-container {
|
| 106 |
height:100%;
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
font-family:Poppins,sans-serif;
|
| 110 |
}
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
max-width:
|
| 115 |
-
|
| 116 |
-
|
|
|
|
| 117 |
display:grid;
|
| 118 |
grid-template-rows:auto 1fr;
|
| 119 |
-
gap:
|
|
|
|
| 120 |
}
|
| 121 |
|
| 122 |
-
|
|
|
|
| 123 |
display:flex;
|
| 124 |
justify-content:space-between;
|
| 125 |
align-items:center;
|
| 126 |
-
border:1px solid rgba(255,255,255,.05);
|
| 127 |
-
padding:16px;
|
| 128 |
}
|
| 129 |
|
| 130 |
-
.
|
| 131 |
font-family:Anton,sans-serif;
|
| 132 |
-
font-size:
|
|
|
|
| 133 |
}
|
| 134 |
|
| 135 |
-
.subtitle{
|
| 136 |
font-family:'Playfair Display',serif;
|
| 137 |
font-style:italic;
|
| 138 |
color:var(--accent);
|
|
|
|
| 139 |
}
|
| 140 |
|
| 141 |
-
|
|
|
|
| 142 |
display:grid;
|
| 143 |
grid-template-columns:1fr 1fr;
|
| 144 |
-
gap:
|
| 145 |
height:100%;
|
| 146 |
}
|
| 147 |
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
| 151 |
display:flex;
|
| 152 |
flex-direction:column;
|
| 153 |
-
gap:
|
|
|
|
| 154 |
}
|
| 155 |
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
|
|
|
|
|
|
| 159 |
text-align:center;
|
| 160 |
}
|
| 161 |
|
| 162 |
-
.
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
}
|
| 167 |
|
| 168 |
-
|
|
|
|
| 169 |
display:grid;
|
| 170 |
grid-template-columns:1fr 1fr;
|
| 171 |
-
gap:
|
| 172 |
}
|
| 173 |
|
| 174 |
-
.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
font-family:'Playfair Display',serif;
|
| 176 |
font-style:italic;
|
| 177 |
color:var(--accent);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
}
|
| 179 |
"""
|
| 180 |
|
| 181 |
with gr.Blocks() as demo:
|
| 182 |
-
with gr.Column(elem_classes="
|
| 183 |
|
| 184 |
with gr.Row(elem_classes="header"):
|
| 185 |
-
gr.HTML('<div class="
|
| 186 |
-
gr.HTML('<div class="subtitle">
|
| 187 |
-
|
| 188 |
-
with gr.Row(elem_classes="
|
| 189 |
-
|
| 190 |
-
with gr.Column(elem_classes="
|
| 191 |
-
gr.HTML(
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
gr.HTML('<div class="label">STEMS</div>')
|
| 197 |
with gr.Row(elem_classes="stems"):
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
if __name__ == "__main__":
|
| 210 |
demo.launch(css=css, theme=gr.themes.Base())
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
from huggingface_hub import hf_hub_download
|
| 8 |
|
| 9 |
+
# =====================================================
|
| 10 |
+
# 1. MODEL LOGIC (UNCHANGED)
|
| 11 |
+
# =====================================================
|
| 12 |
try:
|
| 13 |
from bs_roformer import BSRoformer
|
| 14 |
from attend import Attend
|
|
|
|
| 18 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 19 |
|
| 20 |
def safe_attend_forward(self, q, k, v, mask=None):
    """Compute attention via PyTorch's fused scaled-dot-product kernel.

    Intended as a drop-in replacement for ``Attend.forward`` (it is
    monkey-patched onto that class right after this definition).  ``self``
    is accepted only to satisfy the method signature; no instance state is
    read.

    Args:
        q: Query tensor.
        k: Key tensor.
        v: Value tensor.
        mask: Optional attention mask forwarded unchanged as ``attn_mask``.

    Returns:
        The attention output produced by
        ``torch.nn.functional.scaled_dot_product_attention``.
    """
    # Inference only: dropout is disabled and attention is non-causal.
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )
|
| 24 |
|
| 25 |
try:
|
| 26 |
Attend.forward = safe_attend_forward
|
|
|
|
| 29 |
|
| 30 |
def load_model():
|
| 31 |
print("Connecting to model...")
|
| 32 |
+
ckpt = hf_hub_download(
|
| 33 |
repo_id="Tachyeon/IAM-RoFormer-Model-Weights",
|
| 34 |
filename="v11_consensus_epoch_30.pt"
|
| 35 |
)
|
|
|
|
| 44 |
flash_attn=True
|
| 45 |
).to(DEVICE)
|
| 46 |
|
| 47 |
+
state = torch.load(ckpt, map_location=DEVICE)
|
| 48 |
+
model.load_state_dict(state["model"] if "model" in state else state)
|
| 49 |
model.eval()
|
| 50 |
return model
|
| 51 |
|
| 52 |
model = load_model()
|
| 53 |
|
| 54 |
def separate_audio(audio_path):
    """Separate an audio file into four stems and write each one as a WAV file.

    The mix is loaded at 44.1 kHz, processed by the module-level ``model`` in
    overlapping 10-second windows, and the overlapping predictions are
    averaged (overlap-add with a per-sample hit counter).

    Args:
        audio_path: Path of the audio file to separate.  A falsy value
            (``None`` or ``""``) means "nothing uploaded".

    Returns:
        A list of four entries.  When ``audio_path`` is falsy the entries are
        all ``None``; otherwise they are the paths of the written WAV files,
        in the order the model emits its stems (the UI binds them to
        Vocals, Drums, Bass, Other -- presumably the model's stem order;
        verify against the checkpoint).
    """
    # Local stdlib imports keep this block self-contained.
    import os
    import tempfile

    if not audio_path:
        return [None] * 4

    mix, sr = librosa.load(audio_path, sr=44100, mono=False)
    if mix.ndim == 1:
        # Mono input: duplicate the channel so the model always sees stereo.
        mix = np.stack([mix, mix])
    elif mix.shape[0] > 2:
        # The accumulator below is fixed at two channels; keep the first two.
        mix = mix[:2]

    chunk = 44100 * 10
    overlap = 44100
    hop = chunk - overlap

    x = torch.as_tensor(mix, dtype=torch.float32, device=DEVICE)[None]
    length = x.shape[-1]

    out = torch.zeros(1, 4, 2, length, device=DEVICE)
    # The hit count only varies along time, so one broadcastable row suffices.
    cnt = torch.zeros(1, 1, 1, length, device=DEVICE)

    with torch.no_grad(), torch.autocast("cuda", enabled=torch.cuda.is_available()):
        for s in range(0, length, hop):
            e = min(s + chunk, length)
            part = x[:, :, s:e]
            if part.shape[-1] < chunk:
                # The final window is zero-padded up to the fixed chunk size.
                part = F.pad(part, (0, chunk - part.shape[-1]))
            pred = model(part)
            out[:, :, :, s:e] += pred[:, :, :, : e - s]
            cnt[:, :, :, s:e] += 1
            if e == length:
                # Every later window would lie entirely inside this one and
                # would be mostly padding; skip the redundant model calls.
                break

    stems = (out / cnt.clamp(min=1)).cpu().numpy()[0]

    # A fresh directory per call prevents concurrent or successive requests
    # from overwriting each other's stem files.
    out_dir = tempfile.mkdtemp(prefix="stems_")
    files = []
    for i in range(4):
        f = os.path.join(out_dir, f"stem_{i}.wav")
        sf.write(f, stems[i].T, sr)
        files.append(f)
    return files
|
| 88 |
|
| 89 |
+
# =====================================================
|
| 90 |
+
# 2. POLISHED UI (FIXED LAYOUT, NO SCROLL)
|
| 91 |
+
# =====================================================
|
| 92 |
css = """
|
| 93 |
@import url('https://fonts.googleapis.com/css2?family=Anton&family=Playfair+Display:ital@1&family=Poppins:wght@400;600;700&display=swap');
|
| 94 |
|
| 95 |
+
:root {
|
| 96 |
+
--bg1:#2b1620;
|
| 97 |
+
--bg2:#1c0d14;
|
| 98 |
+
--panel:rgba(255,255,255,0.04);
|
| 99 |
+
--border:rgba(255,255,255,0.08);
|
| 100 |
--ink:#f6efe8;
|
| 101 |
--muted:#c7bfbf;
|
| 102 |
--accent:#ff73a6;
|
| 103 |
}
|
| 104 |
|
| 105 |
+
/* HARD RESET */
|
| 106 |
html, body, .gradio-container {
|
| 107 |
height:100%;
|
| 108 |
+
width:100%;
|
| 109 |
+
margin:0;
|
| 110 |
+
padding:0;
|
| 111 |
+
overflow:hidden !important;
|
| 112 |
+
background:linear-gradient(180deg,var(--bg1),var(--bg2)) !important;
|
| 113 |
+
color:var(--ink);
|
| 114 |
font-family:Poppins,sans-serif;
|
| 115 |
}
|
| 116 |
|
| 117 |
+
/* CENTERED APP */
|
| 118 |
+
.app {
|
| 119 |
+
max-width:1100px;
|
| 120 |
+
height:100%;
|
| 121 |
+
margin:0 auto;
|
| 122 |
+
padding:32px;
|
| 123 |
display:grid;
|
| 124 |
grid-template-rows:auto 1fr;
|
| 125 |
+
gap:28px;
|
| 126 |
+
box-sizing:border-box;
|
| 127 |
}
|
| 128 |
|
| 129 |
+
/* HEADER */
|
| 130 |
+
.header {
|
| 131 |
display:flex;
|
| 132 |
justify-content:space-between;
|
| 133 |
align-items:center;
|
|
|
|
|
|
|
| 134 |
}
|
| 135 |
|
| 136 |
+
.title {
|
| 137 |
font-family:Anton,sans-serif;
|
| 138 |
+
font-size:44px;
|
| 139 |
+
letter-spacing:1px;
|
| 140 |
}
|
| 141 |
|
| 142 |
+
.subtitle {
|
| 143 |
font-family:'Playfair Display',serif;
|
| 144 |
font-style:italic;
|
| 145 |
color:var(--accent);
|
| 146 |
+
margin-left:14px;
|
| 147 |
}
|
| 148 |
|
| 149 |
+
/* MAIN GRID */
|
| 150 |
+
.main {
|
| 151 |
display:grid;
|
| 152 |
grid-template-columns:1fr 1fr;
|
| 153 |
+
gap:32px;
|
| 154 |
height:100%;
|
| 155 |
}
|
| 156 |
|
| 157 |
+
/* PANELS */
|
| 158 |
+
.panel {
|
| 159 |
+
background:var(--panel);
|
| 160 |
+
border:1px solid var(--border);
|
| 161 |
+
border-radius:16px;
|
| 162 |
+
padding:28px;
|
| 163 |
display:flex;
|
| 164 |
flex-direction:column;
|
| 165 |
+
gap:22px;
|
| 166 |
+
box-sizing:border-box;
|
| 167 |
}
|
| 168 |
|
| 169 |
+
/* INPUT */
|
| 170 |
+
.drop {
|
| 171 |
+
border:1px dashed var(--border);
|
| 172 |
+
border-radius:12px;
|
| 173 |
+
padding:32px;
|
| 174 |
text-align:center;
|
| 175 |
}
|
| 176 |
|
| 177 |
+
.drop h3 {
|
| 178 |
+
margin:0;
|
| 179 |
+
font-size:18px;
|
| 180 |
+
letter-spacing:1px;
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
/* BUTTON */
|
| 184 |
+
.run {
|
| 185 |
+
background:linear-gradient(90deg,#ff73a6,#ffd58a) !important;
|
| 186 |
+
color:#160c10 !important;
|
| 187 |
+
font-weight:800 !important;
|
| 188 |
+
border-radius:10px !important;
|
| 189 |
+
border:none !important;
|
| 190 |
}
|
| 191 |
|
| 192 |
+
/* STEMS */
|
| 193 |
+
.stems {
|
| 194 |
display:grid;
|
| 195 |
grid-template-columns:1fr 1fr;
|
| 196 |
+
gap:18px;
|
| 197 |
}
|
| 198 |
|
| 199 |
+
.stem {
|
| 200 |
+
background:rgba(255,255,255,0.03);
|
| 201 |
+
border:1px solid var(--border);
|
| 202 |
+
border-radius:12px;
|
| 203 |
+
padding:16px;
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
.label {
|
| 207 |
font-family:'Playfair Display',serif;
|
| 208 |
font-style:italic;
|
| 209 |
color:var(--accent);
|
| 210 |
+
margin-bottom:6px;
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
/* AUDIO FIX */
|
| 214 |
+
audio {
|
| 215 |
+
width:100%;
|
| 216 |
+
max-height:36px;
|
| 217 |
}
|
| 218 |
"""
|
| 219 |
|
| 220 |
# Page layout: a header row above a two-column main area
# (left: upload + run button, right: the four separated stems).
with gr.Blocks() as demo:
    with gr.Column(elem_classes="app"):

        with gr.Row(elem_classes="header"):
            gr.HTML('<div class="title">SWARA STUDIO</div>')
            gr.HTML('<div class="subtitle">Audio Source Separation</div>')

        with gr.Row(elem_classes="main"):

            # Left panel: input audio and the trigger button.
            with gr.Column(elem_classes="panel"):
                gr.HTML("""
                <div class="drop">
                <h3>MASTER AUDIO</h3>
                <p>Drop or upload WAV / MP3</p>
                </div>
                """)
                inp = gr.Audio(type="filepath")
                btn = gr.Button("RUN SEPARATION", elem_classes="run")

            # Right panel: one read-only player per stem.
            with gr.Column(elem_classes="panel"):
                gr.HTML('<div class="label">STEMS</div>')
                with gr.Row(elem_classes="stems"):
                    with gr.Column(elem_classes="stem"):
                        gr.HTML('<div class="label">Vocals</div>')
                        o1 = gr.Audio(interactive=False)
                    with gr.Column(elem_classes="stem"):
                        gr.HTML('<div class="label">Drums</div>')
                        o2 = gr.Audio(interactive=False)
                    with gr.Column(elem_classes="stem"):
                        gr.HTML('<div class="label">Bass</div>')
                        o3 = gr.Audio(interactive=False)
                    with gr.Column(elem_classes="stem"):
                        gr.HTML('<div class="label">Other</div>')
                        o4 = gr.Audio(interactive=False)

    # The four returned file paths fill the players in this order.
    btn.click(separate_audio, inp, [o1, o2, o3, o4])

if __name__ == "__main__":
    demo.launch(css=css, theme=gr.themes.Base())
|