# serve.py
import os
import tempfile

import numpy as np
import torch
import librosa
import soundfile as sf
import gradio as gr
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download

from src.models.unet import UNet
# 1) Load your config and model once at startup
CFG = OmegaConf.load("config/default.yaml")

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
MODEL = UNet(
    in_ch=1,
    num_sources=len(CFG.data.sources) - 1,
    chans=CFG.model.chans,
    num_pool_layers=CFG.model.num_pool_layers,
).to(DEVICE)

# point this at your best checkpoint in the Space
ckpt_file = hf_hub_download(
    repo_id="theadityamittal/music-separator-unet",
    filename="checkpoints/unet_best.pt",
)
MODEL.load_state_dict(torch.load(ckpt_file, map_location=DEVICE))
MODEL.eval()
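# Note: the load above assumes the checkpoint file is a bare state_dict;
# if it was saved as e.g. {"model": state_dict, ...}, index into the dict
# before calling load_state_dict.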

def separate_file(mix_path):
    """
    Separate the uploaded mixture WAV at `mix_path` and return a list of
    paths to the separated .wav stems, in the order given by
    CFG.data.sources[1:] (drums, bass, other, vocals).
    """
    # 1. Load audio & STFT
    wav, sr = librosa.load(mix_path, sr=CFG.data.sample_rate, mono=True)
    stft = librosa.stft(
        wav, n_fft=CFG.data.n_fft, hop_length=CFG.data.hop_length
    )
    mag, phase = np.abs(stft), np.angle(stft)
    F, T = mag.shape
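    # mag and phase have shape (F, T) with F = n_fft // 2 + 1
    # (e.g. n_fft=1024 gives F=513 frequency bins).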

    # 2. Pad the time axis to a multiple of segment_length
    SEG = CFG.data.segment_length
    pad = (SEG - (T % SEG)) % SEG
    if pad:
        mag = np.pad(mag, ((0, 0), (0, pad)), constant_values=0)
        phase = np.pad(phase, ((0, 0), (0, pad)), constant_values=0)
    n_seg = mag.shape[1] // SEG
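    # e.g. T=1000 frames with SEG=256: pad = (256 - 232) % 256 = 24,
    # giving a padded width of 1024 and n_seg = 4.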

    # 3. Run inference one segment at a time
    preds = []
    with torch.no_grad():
        for i in range(n_seg):
            mseg = mag[:, i * SEG:(i + 1) * SEG]
            x = torch.from_numpy(mseg).unsqueeze(0).unsqueeze(0).to(DEVICE).float()
            y = MODEL(x)  # (1, S, F, SEG)
            preds.append(y.squeeze(0).cpu().numpy())
    pred_mag = np.concatenate(preds, axis=2)[:, :, :T]
    phase = phase[:, :T]
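    # pred_mag has shape (S, F, T): segments concatenated along time,
    # with the zero-padding trimmed back off.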

    # 4. Reconstruct waveforms and write them to temp files
    out_paths = {}
    for idx, src in enumerate(CFG.data.sources[1:]):
        spec = pred_mag[idx] * np.exp(1j * phase)
        est = librosa.istft(
            spec,
            hop_length=CFG.data.hop_length,
            win_length=CFG.data.n_fft,
        )
        # write to a temp WAV file
        fd, path = tempfile.mkstemp(suffix=f"_{src}.wav")
        os.close(fd)
        sf.write(path, est, sr)
        out_paths[src] = path

    # return stems in source order: drums, bass, other, vocals
    return [out_paths[src] for src in CFG.data.sources[1:]]
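
# Quick local smoke test, bypassing Gradio (hypothetical sample path):
#   stems = separate_file("samples/mix.wav")
#   print(stems)  # four temp-file paths, one per stem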

# 2) Build the Gradio interface
description = """
## Music Source Separation
Upload a mix `.wav` and get back **drums**, **bass**, **other**, and **vocals** stems separated by a U-Net model.
"""

iface = gr.Interface(
    fn=separate_file,
    inputs=gr.Audio(label="Mixture (.wav)", type="filepath"),
    outputs=[
        gr.Audio(label="Drums", type="filepath"),
        gr.Audio(label="Bass", type="filepath"),
        gr.Audio(label="Other", type="filepath"),
        gr.Audio(label="Vocals", type="filepath"),
    ],
    title="U-Net Music Separator",
    description=description,
    allow_flagging="never",
)

if __name__ == "__main__":
    iface.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", 7860)),
        share=True,
    )
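# Run locally with `python serve.py`; PORT defaults to 7860, the port a
# Gradio Space expects. share=True only matters outside Spaces, where it
# opens a temporary public gradio.live tunnel.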