# SpecTacles — Hugging Face Space entry point (app.py)
# Upstream page header: Erudesu / "Update app.py" / commit 6eb335e (verified)
import gradio as gr
import torch
import torch.nn as nn
import soundfile as sf
import numpy as np
import torchaudio
import os
import gc
import time
from demucs.apply import apply_model
from demucs.pretrained import get_model
print("🚀 Starting SpecTacles...")
# Force CPU for free hosting — the free Space tier has no GPU.
device = "cpu"
# ==========================================
# 1. DEFINE BRAIN
# ==========================================
class StemMixer(nn.Module):
    """Learned per-bin blender of two source-separation results.

    Given the stems predicted by two separators for the same mix (passed to
    ``forward`` as waveform tensors), it predicts a sigmoid mask for every
    (stem, channel, frequency, frame) bin and returns the convex combination
    of the two magnitude spectrograms.
    """

    def __init__(self):
        super().__init__()
        # STFT parameters shared with the inference code that resynthesizes audio.
        self.n_fft, self.hop = 2048, 512
        # Tiny per-bin MLP: input is the 4 magnitudes of one time-frequency bin
        # (2 channels x 2 separators), output is a 2-channel blend weight in (0, 1).
        self.mixer = nn.Sequential(
            nn.Linear(4, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 32), nn.BatchNorm1d(32), nn.ReLU(),
            nn.Linear(32, 16), nn.BatchNorm1d(16), nn.ReLU(),
            nn.Linear(16, 2), nn.Sigmoid()
        )

    def compute_spec(self, x):
        """Magnitude STFT of a (batch, stems, channels, time) waveform tensor.

        Returns a tensor of shape (batch*stems*channels, n_fft//2 + 1, frames).
        """
        b, s, c, t = x.shape
        x = x.reshape(b * s * c, t)
        window = torch.hann_window(self.n_fft).to(x.device)
        return torch.abs(torch.stft(x, self.n_fft, self.hop, window=window,
                                    return_complex=True))

    def forward(self, d, m):
        """Blend the two stem estimates ``d`` and ``m``.

        Both inputs are (batch, stems, channels, time) waveforms; the result is
        the blended magnitude spectrogram, shape (batch, stems, channels, freq, frames).
        """
        ds, ms = self.compute_spec(d), self.compute_spec(m)
        b, s, c, t = d.shape
        f, fr = ds.shape[1], ds.shape[2]
        ds, ms = ds.reshape(b, s, c, f, fr), ms.reshape(b, s, c, f, fr)
        # One MLP row per (batch, stem, freq, frame) bin, 4 features per row.
        inp = torch.cat([ds, ms], dim=2).permute(0, 1, 3, 4, 2).reshape(-1, 4)
        mask = self.mixer(inp).reshape(b, s, f, fr, c).permute(0, 1, 4, 2, 3)
        # Convex combination: mask weights ``d``'s spectrum, (1 - mask) weights ``m``'s.
        return mask * ds + (1 - mask) * ms
# ==========================================
# 2. LIGHTWEIGHT STARTUP
# ==========================================
print("Loading Brain...")
# The mixing network is tiny, so it is loaded eagerly at startup.
mixer = StemMixer().to(device)
if os.path.exists("best_stem_mixer.pth"):
    # map_location forces CPU tensors even if the checkpoint was saved on GPU.
    mixer.load_state_dict(torch.load("best_stem_mixer.pth", map_location=torch.device('cpu')))
else:
    # Best-effort startup: keep running with random weights rather than crash.
    print("⚠️ Model file missing!")
mixer.eval()
# Lazy-loading globals: the heavy separation models are only pulled into
# memory on the first request, keeping Space startup fast and light.
model_d = None
model_m = None


def load_heavy_models():
    """Load the Demucs and MDX separators once, caching them in module globals."""
    global model_d, model_m
    if model_d is None:
        print("⏳ Loading Demucs...")
        model_d = get_model('htdemucs').to(device).eval()
    if model_m is None:
        print("⏳ Loading MDX...")
        model_m = get_model('mdx_extra').to(device).eval()
# ==========================================
# 3. PROCESS FUNCTION
# ==========================================
def process(audio):
    """Separate ``audio`` (a file path) into four stems and write them to disk.

    Returns the tuple of paths (vocals, drums, bass, other), or ``None`` when
    there is no input or the clip is too short to yield any output.
    """
    if audio is None:
        return None
    load_heavy_models()
    print("Processing...")
    wav, sr = torchaudio.load(audio)
    # Both separators and the mixer's STFT assume 44.1 kHz audio.
    if sr != 44100:
        wav = torchaudio.transforms.Resample(sr, 44100)(wav)
    # Peak-normalize only when the signal clips, leaving quiet material untouched.
    if wav.abs().max() > 1.0:
        wav = wav / wav.abs().max()
    chunk = 44100 * 10  # 10-second chunks keep CPU memory bounded
    window = torch.hann_window(2048).to(device)  # hoisted: reused every chunk
    stems = []
    for s in range(0, wav.shape[1], chunk):
        e = min(s + chunk, wav.shape[1])
        c = wav[:, s:e]
        if c.shape[1] < 2048:  # shorter than one FFT frame — skip the tail
            continue
        with torch.no_grad():
            ct = c.unsqueeze(0).to(device)
            d = apply_model(model_d, ct, shifts=0)
            m = apply_model(model_m, ct, shifts=0)
            mix = mixer(d, m)  # blended magnitude spectrogram
            b, st, ch, t = d.shape
            # Resynthesize audio from the blended magnitudes, borrowing the
            # phase of the Demucs estimate.
            dc = torch.stft(d.reshape(b * st * ch, t), 2048, 512,
                            window=window, return_complex=True)
            rec = torch.istft(
                torch.polar(mix.reshape(b * st * ch, dc.shape[-2], dc.shape[-1]),
                            torch.angle(dc)),
                2048, 512, window=window, length=t)
        stems.append(rec.reshape(st, ch, -1).cpu())
    if not stems:
        # Whole input was shorter than one FFT frame: nothing to separate.
        return None
    full = torch.cat(stems, dim=2).numpy()
    paths = []
    ts = int(time.time())
    for i, n in enumerate(['Drums', 'Bass', 'Other', 'Vocals']):
        p = f"{n}_{ts}.mp3"
        # Clip slightly below full scale to avoid codec clipping artifacts.
        sf.write(p, np.clip(full[i].T, -0.99, 0.99), 44100)
        paths.append(p)
    # Reorder to match the UI outputs: Vocals, Drums, Bass, Other.
    return paths[3], paths[0], paths[1], paths[2]
# ==========================================
# 4. UI (Nonchalant & Themed)
# ==========================================
# JavaScript passed to the theme button's click handler: toggles Gradio's
# built-in `dark` class on <body>, switching light/dark mode client-side.
js_toggle = """
function() {
document.body.classList.toggle('dark');
}
"""
# CSS overrides: gradient primary button plus a heavier sans-serif title.
css = """
.gr-button-primary {
background: linear-gradient(90deg, #6366f1, #a855f7) !important;
color: white !important;
border: none !important;
}
h1 {
font-family: 'Helvetica', sans-serif;
font-weight: 700;
}
"""
# Soft theme renders acceptably in both light and dark mode.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
    with gr.Row():
        with gr.Column(scale=10):
            gr.Markdown("# 👓 SpecTacles")
        with gr.Column(scale=1):
            # Client-side light/dark toggle button.
            theme_btn = gr.Button("🌗", variant="secondary")
    with gr.Row():
        inp = gr.Audio(type="filepath", label="Source")
        btn = gr.Button("SEPARATE", variant="primary")
    with gr.Row():
        v = gr.Audio(label="Vocals")
        d = gr.Audio(label="Drums")
        b = gr.Audio(label="Bass")
        o = gr.Audio(label="Other")
    # Wire the separation pipeline; output order matches process()'s return tuple.
    btn.click(process, inputs=inp, outputs=[v, d, b, o])
    # Pure-JS handler (fn=None): runs js_toggle in the browser only.
    theme_btn.click(None, None, None, js=js_toggle)

app.launch()