import gradio as gr
import torch
import torch.nn as nn
import soundfile as sf
import numpy as np
import torchaudio
import os
import gc
import time
from demucs.apply import apply_model
from demucs.pretrained import get_model

print("🚀 Starting SpecTacles...")

# Force CPU for free hosting
device = "cpu"


# ==========================================
# 1. DEFINE BRAIN
# ==========================================
class StemMixer(nn.Module):
    """Blend two separators' magnitude spectrograms per (freq, frame) bin.

    For every stem/channel, a small MLP looks at the pair of magnitudes
    produced by Demucs and MDX for each spectrogram bin and predicts a
    sigmoid mask; the output is the mask-weighted interpolation of the
    two magnitude spectrograms.
    """

    def __init__(self):
        super().__init__()
        self.n_fft, self.hop = 2048, 512
        # Hoisted Hann window: previously re-allocated on every
        # compute_spec() call. persistent=False keeps it out of the
        # state_dict so checkpoints saved before this change still load
        # with strict=True.
        self.register_buffer(
            "window", torch.hann_window(self.n_fft), persistent=False
        )
        self.mixer = nn.Sequential(
            nn.Linear(4, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 32), nn.BatchNorm1d(32), nn.ReLU(),
            nn.Linear(32, 16), nn.BatchNorm1d(16), nn.ReLU(),
            nn.Linear(16, 2), nn.Sigmoid()
        )

    def compute_spec(self, x):
        """Magnitude STFT of ``x`` shaped (batch, stems, channels, time).

        Returns a tensor of shape (batch*stems*channels, freq, frames).
        """
        b, s, c, t = x.shape
        x = x.reshape(b * s * c, t)
        return torch.abs(
            torch.stft(
                x, self.n_fft, self.hop,
                window=self.window.to(x.device),
                return_complex=True,
            )
        )

    def forward(self, d, m):
        """Blend Demucs stems ``d`` with MDX stems ``m``.

        Both inputs are waveforms shaped (batch, stems, channels, time);
        the result is the blended magnitude spectrogram shaped
        (batch, stems, channels, freq, frames).
        """
        ds, ms = self.compute_spec(d), self.compute_spec(m)
        b, s, c, t = d.shape
        f, fr = ds.shape[1], ds.shape[2]
        ds, ms = ds.reshape(b, s, c, f, fr), ms.reshape(b, s, c, f, fr)
        # Stack the two models' magnitudes along the channel axis, then
        # flatten so each spectrogram bin becomes one 4-feature sample.
        inp = torch.cat([ds, ms], dim=2).permute(0, 1, 3, 4, 2).reshape(-1, 4)
        mask = self.mixer(inp).reshape(b, s, f, fr, c).permute(0, 1, 4, 2, 3)
        return mask * ds + (1 - mask) * ms


# ==========================================
# 2. LIGHTWEIGHT STARTUP
# ==========================================
print("Loading Brain...")
mixer = StemMixer().to(device)
if os.path.exists("best_stem_mixer.pth"):
    mixer.load_state_dict(
        torch.load("best_stem_mixer.pth", map_location=torch.device('cpu'))
    )
else:
    print("⚠️ Model file missing!")
mixer.eval()

# Lazy-loading slots: the heavy separators are only fetched on the first
# request so the app boots fast on free hosting.
model_d = None
model_m = None


def load_heavy_models():
    """Load the Demucs and MDX separators once, on first use (idempotent)."""
    global model_d, model_m
    if model_d is None:
        print("⏳ Loading Demucs...")
        model_d = get_model('htdemucs').to(device).eval()
    if model_m is None:
        print("⏳ Loading MDX...")
        model_m = get_model('mdx_extra').to(device).eval()


# ==========================================
# 3.
# PROCESS FUNCTION
# ==========================================
def process(audio):
    """Separate an audio file into four stems and write them as mp3s.

    Parameters
    ----------
    audio : str | None
        Filepath from the Gradio Audio input; None when nothing uploaded.

    Returns
    -------
    tuple[str, str, str, str] | None
        Paths for (Vocals, Drums, Bass, Other), or None when there is no
        input or the input is too short to process.

    The waveform is resampled to 44.1 kHz, peak-normalised, and processed
    in 10-second chunks to bound CPU memory on free hosting.
    """
    if audio is None:
        return None
    load_heavy_models()
    # Fixed: was an f-string with no placeholders.
    print("Processing...")
    wav, sr = torchaudio.load(audio)
    if sr != 44100:
        wav = torchaudio.transforms.Resample(sr, 44100)(wav)
    if wav.abs().max() > 1.0:
        wav = wav / wav.abs().max()

    chunk = 44100 * 10
    n_fft, hop = 2048, 512
    # Hoisted: one window allocation instead of two fresh windows per chunk.
    window = torch.hann_window(n_fft).to(device)
    stems = []
    for s in range(0, wav.shape[1], chunk):
        e = min(s + chunk, wav.shape[1])
        c = wav[:, s:e]
        if c.shape[1] < n_fft:
            # NOTE(review): a trailing chunk shorter than one FFT frame is
            # dropped, so up to ~46 ms can be lost at the end of the track.
            continue
        with torch.no_grad():
            ct = c.unsqueeze(0).to(device)
            d = apply_model(model_d, ct, shifts=0)
            m = apply_model(model_m, ct, shifts=0)
            mix = mixer(d, m)  # blended magnitude spectrogram
            b, st, ch, t = d.shape
            # Re-synthesise audio using the blended magnitude with the
            # phase taken from the Demucs estimate.
            dc = torch.stft(
                d.reshape(b * st * ch, t), n_fft, hop,
                window=window, return_complex=True,
            )
            rec = torch.istft(
                torch.polar(
                    mix.reshape(b * st * ch, dc.shape[-2], dc.shape[-1]),
                    torch.angle(dc),
                ),
                n_fft, hop, window=window, length=t,
            )
            stems.append(rec.reshape(st, ch, -1).cpu())

    # Guard: input shorter than one FFT frame produces no chunks at all;
    # torch.cat on an empty list would raise. Treat it like "no input".
    if not stems:
        return None
    full = torch.cat(stems, dim=2).numpy()

    paths = []
    ts = int(time.time())
    for i, n in enumerate(['Drums', 'Bass', 'Other', 'Vocals']):
        p = f"{n}_{ts}.mp3"
        # Clip slightly inside [-1, 1] before encoding.
        sf.write(p, np.clip(full[i].T, -0.99, 0.99), 44100)
        paths.append(p)
    # Stems come out in drums/bass/other/vocals order; UI wants vocals first.
    return paths[3], paths[0], paths[1], paths[2]


# ==========================================
# 4.
# UI (Nonchalant & Themed)
# ==========================================

# Client-side dark-mode toggle: flips the 'dark' class on <body>.
js_toggle = """
function() {
    document.body.classList.toggle('dark');
}
"""

# Gradient primary button plus clean heading typography.
css = """
.gr-button-primary {
    background: linear-gradient(90deg, #6366f1, #a855f7) !important;
    color: white !important;
    border: none !important;
}
h1 {
    font-family: 'Helvetica', sans-serif;
    font-weight: 700;
}
"""

# Soft theme renders well in both light and dark modes.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    # Header row: title on the left, theme toggle tucked to the right.
    with gr.Row():
        with gr.Column(scale=10):
            gr.Markdown("# 👓 SpecTacles")
        with gr.Column(scale=1):
            theme_btn = gr.Button("🌗", variant="secondary")

    # Input row: source file and the action button.
    with gr.Row():
        source_audio = gr.Audio(type="filepath", label="Source")
        separate_btn = gr.Button("SEPARATE", variant="primary")

    # Output row: one player per stem.
    with gr.Row():
        vocals_out = gr.Audio(label="Vocals")
        drums_out = gr.Audio(label="Drums")
        bass_out = gr.Audio(label="Bass")
        other_out = gr.Audio(label="Other")

    # Wire the separation pipeline to the button.
    separate_btn.click(
        process,
        inputs=source_audio,
        outputs=[vocals_out, drums_out, bass_out, other_out],
    )
    # Pure front-end event: no Python inputs/outputs, JS only.
    theme_btn.click(None, None, None, js=js_toggle)

demo.launch()