# SpecTacles — AI stem-separation demo (Gradio app, CPU-only hosting).
import gradio as gr
import torch
import torch.nn as nn
import soundfile as sf
import numpy as np
import torchaudio
import os
import gc
import time
from demucs.apply import apply_model
from demucs.pretrained import get_model
# Startup banner so hosting logs show the app actually booted.
print("🚀 Starting SpecTacles...")
# Force CPU for free hosting
device = "cpu"
# ==========================================
# 1. DEFINE BRAIN
# ==========================================
class StemMixer(nn.Module):
    """Per-bin gating network that blends two magnitude spectrograms.

    Given the 4-stem outputs of two separators (shape
    ``(batch, stems, channels, time)`` each), it predicts a sigmoid
    weight per time-frequency bin and channel, and returns the convex
    combination ``gate * spec_d + (1 - gate) * spec_m`` of the two
    magnitude spectrograms.
    """

    def __init__(self):
        super().__init__()
        # STFT analysis parameters shared by both inputs.
        self.n_fft, self.hop = 2048, 512
        # Maps 4 per-bin magnitudes (2 channels x 2 separators) to a
        # 2-channel blend weight in (0, 1).
        self.mixer = nn.Sequential(
            nn.Linear(4, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 32), nn.BatchNorm1d(32), nn.ReLU(),
            nn.Linear(32, 16), nn.BatchNorm1d(16), nn.ReLU(),
            nn.Linear(16, 2), nn.Sigmoid(),
        )

    def compute_spec(self, x):
        """Magnitude STFT of ``x`` (batch, stems, channels, time),
        flattened to (batch*stems*channels, freq, frames)."""
        batch, stems, chans, samples = x.shape
        flat = x.reshape(batch * stems * chans, samples)
        win = torch.hann_window(self.n_fft).to(flat.device)
        return torch.stft(flat, self.n_fft, self.hop,
                          window=win, return_complex=True).abs()

    def forward(self, d, m):
        """Return the blended magnitude spectrogram of ``d`` and ``m``.

        Both arguments are (batch, stems, channels, time) waveforms;
        the result is (batch, stems, channels, freq, frames).
        """
        spec_d = self.compute_spec(d)
        spec_m = self.compute_spec(m)
        batch, stems, chans, _ = d.shape
        freq, frames = spec_d.shape[1], spec_d.shape[2]
        spec_d = spec_d.reshape(batch, stems, chans, freq, frames)
        spec_m = spec_m.reshape(batch, stems, chans, freq, frames)
        # Stack both separators along the channel axis -> 4 features/bin,
        # then flatten every bin into a row for the per-bin MLP.
        feats = torch.cat([spec_d, spec_m], dim=2)
        feats = feats.permute(0, 1, 3, 4, 2).reshape(-1, 4)
        gate = self.mixer(feats)
        gate = gate.reshape(batch, stems, freq, frames, chans)
        gate = gate.permute(0, 1, 4, 2, 3)
        return gate * spec_d + (1 - gate) * spec_m
# ==========================================
# 2. LIGHTWEIGHT STARTUP
# ==========================================
# Instantiate the mixing network and restore trained weights if present.
print("Loading Brain...")
mixer = StemMixer().to(device)
if os.path.exists("best_stem_mixer.pth"):
    # map_location lets a GPU-trained checkpoint load on CPU hosting.
    mixer.load_state_dict(torch.load("best_stem_mixer.pth", map_location=torch.device('cpu')))
else:
    # App still runs with random weights, but output quality will suffer.
    print("⚠️ Model file missing!")
mixer.eval()  # inference mode: BatchNorm uses running statistics
# Lazy-loading slots: both separators stay unloaded until first request.
model_d = None
model_m = None


def load_heavy_models():
    """Fetch both pretrained separators the first time they are needed.

    Keeps startup light on free hosting — the heavy Demucs/MDX weights
    are only loaded when the first separation request arrives. Safe to
    call repeatedly; already-loaded models are left untouched.
    """
    global model_d, model_m
    if model_d is None:
        print("⏳ Loading Demucs...")
        loaded = get_model('htdemucs')
        model_d = loaded.to(device).eval()
    if model_m is None:
        print("⏳ Loading MDX...")
        loaded = get_model('mdx_extra')
        model_m = loaded.to(device).eval()
# ==========================================
# 3. PROCESS FUNCTION
# ==========================================
def process(audio):
    """Separate *audio* into four stems and export them as files.

    Args:
        audio: filepath delivered by the ``gr.Audio`` input, or ``None``.

    Returns:
        A ``(vocals, drums, bass, other)`` tuple of output file paths,
        or ``None`` when there is no input or the clip is too short to
        yield a single analysable chunk.
    """
    if audio is None:
        return None
    load_heavy_models()
    print("Processing...")  # was a placeholder-free f-string
    wav, sr = torchaudio.load(audio)
    # Both separators operate at 44.1 kHz.
    if sr != 44100:
        wav = torchaudio.transforms.Resample(sr, 44100)(wav)
    # Peak-normalise to avoid clipping inside the models.
    if wav.abs().max() > 1.0:
        wav = wav / wav.abs().max()
    chunk = 44100 * 10  # 10-second chunks keep CPU memory bounded
    stems = []
    for start in range(0, wav.shape[1], chunk):
        end = min(start + chunk, wav.shape[1])
        piece = wav[:, start:end]
        # Shorter than one FFT window (n_fft=2048): nothing to analyse.
        if piece.shape[1] < 2048:
            continue
        with torch.no_grad():
            ct = piece.unsqueeze(0).to(device)
            d = apply_model(model_d, ct, shifts=0)
            m = apply_model(model_m, ct, shifts=0)
            mix = mixer(d, m)  # blended magnitude spectrogram
            b, st, ch, t = d.shape
            # Hoisted: one window per chunk instead of one per stft/istft.
            window = torch.hann_window(2048).to(device)
            dc = torch.stft(d.reshape(b * st * ch, t), 2048, 512,
                            window=window, return_complex=True)
            # Re-synthesise using the mixed magnitude and Demucs' phase.
            rec = torch.istft(
                torch.polar(
                    mix.reshape(b * st * ch, dc.shape[-2], dc.shape[-1]),
                    torch.angle(dc)),
                2048, 512, window=window, length=t)
            stems.append(rec.reshape(st, ch, -1).cpu())
    # Guard: every chunk was skipped (input shorter than one FFT window);
    # torch.cat on an empty list would raise, so bail out like no-input.
    if not stems:
        return None
    full = torch.cat(stems, dim=2).numpy()
    paths = []
    ts = int(time.time())
    for i, n in enumerate(['Drums', 'Bass', 'Other', 'Vocals']):
        p = f"{n}_{ts}.mp3"
        # NOTE(review): soundfile writes MP3 only with libsndfile >= 1.1 —
        # confirm the runtime ships it, otherwise switch to .wav.
        sf.write(p, np.clip(full[i].T, -0.99, 0.99), 44100)
        paths.append(p)
    # UI expects (Vocals, Drums, Bass, Other).
    return paths[3], paths[0], paths[1], paths[2]
# ==========================================
# 4. UI (Nonchalant & Themed)
# ==========================================
# JavaScript to toggle Dark Mode
# Gradio executes this snippet client-side when the theme button fires.
js_toggle = """
function() {
document.body.classList.toggle('dark');
}
"""
# CSS for the button and clean look
# Injected into the Blocks page: gradient primary button + header font.
css = """
.gr-button-primary {
background: linear-gradient(90deg, #6366f1, #a855f7) !important;
color: white !important;
border: none !important;
}
h1 {
font-family: 'Helvetica', sans-serif;
font-weight: 700;
}
"""
# Use Soft theme (Looks good in both Light and Dark)
with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
    # Header row: title on the left, dark-mode toggle on the right.
    with gr.Row():
        with gr.Column(scale=10):
            gr.Markdown("# 👓 SpecTacles")
        with gr.Column(scale=1):
            # The Theme Toggle Button
            theme_btn = gr.Button("🌗", variant="secondary")

    # Input row: source audio and the action button.
    with gr.Row():
        inp = gr.Audio(type="filepath", label="Source")
        btn = gr.Button("SEPARATE", variant="primary")

    # Output row: one player per separated stem.
    with gr.Row():
        out_vocals = gr.Audio(label="Vocals")
        out_drums = gr.Audio(label="Drums")
        out_bass = gr.Audio(label="Bass")
        out_other = gr.Audio(label="Other")

    # Wire the pipeline: process() returns (vocals, drums, bass, other).
    btn.click(process, inputs=inp,
              outputs=[out_vocals, out_drums, out_bass, out_other])
    # Client-side dark-mode flip; no Python callback involved.
    theme_btn.click(None, None, None, js=js_toggle)

app.launch()