# Hugging Face Space: SpecTacles (page status at capture time: "Sleeping")
# --- Dependencies -----------------------------------------------------------
import gc
import os
import time

import gradio as gr
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import torchaudio
from demucs.apply import apply_model
from demucs.pretrained import get_model

print("🚀 Starting SpecTacles...")

# Free hosting tier has no GPU — pin everything to CPU.
device = "cpu"
# ==========================================
# 1. DEFINE BRAIN
# ==========================================
class StemMixer(nn.Module):
    """Learned per-bin blender for two stem-separation results.

    Given the 4-stem waveform outputs of two separators for the same audio
    (shape ``(batch, stems, channels, time)`` each), predicts a soft mask in
    [0, 1] for every (stem, channel, freq, frame) cell and returns the blended
    magnitude spectrogram ``mask * |d| + (1 - mask) * |m|``.

    NOTE(review): the MLP input width of 4 implies stereo input
    (2 channels from each model) — confirm callers always pass 2 channels.
    """

    def __init__(self):
        super().__init__()
        self.n_fft, self.hop = 2048, 512
        # Cache the analysis window once instead of re-allocating it on every
        # forward pass.  persistent=False keeps it out of the state dict, so
        # checkpoints saved before this change still load with strict=True.
        self.register_buffer("window", torch.hann_window(self.n_fft), persistent=False)
        # Tiny MLP applied independently to each time-frequency cell; input is
        # the 4-vector of magnitudes [d_ch0, d_ch1, m_ch0, m_ch1].
        self.mixer = nn.Sequential(
            nn.Linear(4, 64), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Linear(64, 32), nn.BatchNorm1d(32), nn.ReLU(),
            nn.Linear(32, 16), nn.BatchNorm1d(16), nn.ReLU(),
            nn.Linear(16, 2), nn.Sigmoid()
        )

    def compute_spec(self, x):
        """Magnitude STFT of a (batch, stems, channels, time) tensor,
        flattened to (batch*stems*channels, freq, frames)."""
        b, s, c, t = x.shape
        x = x.reshape(b * s * c, t)
        return torch.abs(torch.stft(x, self.n_fft, self.hop,
                                    window=self.window.to(x.device),
                                    return_complex=True))

    def forward(self, d, m):
        """Blend the magnitude spectrograms of separations ``d`` and ``m``.

        Args:
            d, m: waveforms of shape (batch, stems, channels, time) produced
                by the two source-separation models.

        Returns:
            Blended magnitude spectrogram of shape
            (batch, stems, channels, freq, frames).
        """
        ds, ms = self.compute_spec(d), self.compute_spec(m)
        b, s, c, t = d.shape
        f, fr = ds.shape[1], ds.shape[2]
        ds, ms = ds.reshape(b, s, c, f, fr), ms.reshape(b, s, c, f, fr)
        # Stack both models' channel magnitudes -> one 4-vector per T-F cell.
        inp = torch.cat([ds, ms], dim=2).permute(0, 1, 3, 4, 2).reshape(-1, 4)
        mask = self.mixer(inp).reshape(b, s, f, fr, c).permute(0, 1, 4, 2, 3)
        return mask * ds + (1 - mask) * ms
# ==========================================
# 2. LIGHTWEIGHT STARTUP
# ==========================================
print("Loading Brain...")

# Instantiate the mixing network and pull in trained weights when available.
mixer = StemMixer().to(device)
weights_file = "best_stem_mixer.pth"
if os.path.exists(weights_file):
    state = torch.load(weights_file, map_location=torch.device('cpu'))
    mixer.load_state_dict(state)
else:
    print("⚠️ Model file missing!")
mixer.eval()

# The two separation backbones are heavy; they are loaded lazily on first use.
model_d = None
model_m = None
def load_heavy_models():
    """Lazily instantiate both separation backbones (idempotent; safe to
    call before every request — already-loaded models are kept)."""
    global model_d, model_m
    if model_d is None:
        print("⏳ Loading Demucs...")
        model_d = get_model('htdemucs').to(device).eval()
    if model_m is None:
        print("⏳ Loading MDX...")
        model_m = get_model('mdx_extra').to(device).eval()
# ==========================================
# 3. PROCESS FUNCTION
# ==========================================
def process(audio):
    """Separate an uploaded track into four stems and write them to disk.

    Args:
        audio: filepath from the Gradio Audio component, or None.

    Returns:
        Tuple of (vocals, drums, bass, other) output file paths, or None when
        there is no input or the clip is too short to process.
    """
    if audio is None:
        return None
    load_heavy_models()
    print("Processing...")

    wav, sr = torchaudio.load(audio)
    if sr != 44100:
        wav = torchaudio.transforms.Resample(sr, 44100)(wav)
    # Demucs models expect stereo; a mono file would crash — duplicate it.
    if wav.shape[0] == 1:
        wav = wav.repeat(2, 1)
    peak = wav.abs().max()
    if peak > 1.0:
        wav = wav / peak

    chunk = 44100 * 10  # 10-second windows keep peak CPU memory bounded
    window = torch.hann_window(2048).to(device)  # hoisted out of the loop
    stems = []
    for s in range(0, wav.shape[1], chunk):
        e = min(s + chunk, wav.shape[1])
        c = wav[:, s:e]
        if c.shape[1] < 2048:  # shorter than one FFT frame — drop the tail
            continue
        with torch.no_grad():
            ct = c.unsqueeze(0).to(device)
            d = apply_model(model_d, ct, shifts=0)
            m = apply_model(model_m, ct, shifts=0)
            # Blend magnitudes with the learned mask, then resynthesize
            # using the phase of the Demucs estimate.
            mix = mixer(d, m)
            b, st, ch, t = d.shape
            dc = torch.stft(d.reshape(b * st * ch, t), 2048, 512,
                            window=window, return_complex=True)
            rec = torch.istft(
                torch.polar(mix.reshape(b * st * ch, dc.shape[-2], dc.shape[-1]),
                            torch.angle(dc)),
                2048, 512, window=window, length=t)
            stems.append(rec.reshape(st, ch, -1).cpu())

    # Guard: a clip shorter than 2048 samples produces no chunks at all;
    # torch.cat on an empty list would raise.
    if not stems:
        return None

    full = torch.cat(stems, dim=2).numpy()
    del stems
    gc.collect()  # release chunk tensors before file writes (small CPU host)

    paths = []
    ts = int(time.time())
    for i, n in enumerate(['Drums', 'Bass', 'Other', 'Vocals']):
        # NOTE(review): sf.write to .mp3 requires libsndfile with MPEG
        # support — confirm it is present on the deployment image.
        p = f"{n}_{ts}.mp3"
        sf.write(p, np.clip(full[i].T, -0.99, 0.99), 44100)
        paths.append(p)
    # htdemucs stem order is [drums, bass, other, vocals]; UI wants vocals first.
    return paths[3], paths[0], paths[1], paths[2]
# ==========================================
# 4. UI (Nonchalant & Themed)
# ==========================================

# JavaScript to toggle Dark Mode
js_toggle = """
function() {
document.body.classList.toggle('dark');
}
"""

# CSS for the button and clean look
css = """
.gr-button-primary {
background: linear-gradient(90deg, #6366f1, #a855f7) !important;
color: white !important;
border: none !important;
}
h1 {
font-family: 'Helvetica', sans-serif;
font-weight: 700;
}
"""

# Soft theme reads well in both light and dark mode.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
    with gr.Row():
        with gr.Column(scale=10):
            gr.Markdown("# 👓 SpecTacles")
        with gr.Column(scale=1):
            # The Theme Toggle Button
            theme_btn = gr.Button("🌗", variant="secondary")

    with gr.Row():
        inp = gr.Audio(type="filepath", label="Source")
        btn = gr.Button("SEPARATE", variant="primary")

    with gr.Row():
        v = gr.Audio(label="Vocals")
        d = gr.Audio(label="Drums")
        b = gr.Audio(label="Bass")
        o = gr.Audio(label="Other")

    # Wire separation, then the client-side theme toggle (no server round-trip).
    btn.click(process, inputs=inp, outputs=[v, d, b, o])
    theme_btn.click(None, None, None, js=js_toggle)

app.launch()