# Hugging Face Spaces scrape header (Space status: Paused)
| import torchaudio | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from tqdm import tqdm | |
| from safetensors.torch import load_file | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchaudio.transforms import Fade | |
| from torchaudio.models import HDemucs | |
| import gradio as gr | |
| import math | |
# Global audio / STFT constants.
WIN_LENGTH = 4096   # STFT window length in samples
HOP_LENGTH = 1024   # STFT hop length in samples
SR = 44100          # working sample rate in Hz
class Crop2d(nn.Module):
    """Crop a (N, C, H, W) tensor by fixed margins on each spatial side.

    Args:
        l, r: columns removed from the left / right edge.
        t, b: rows removed from the top / bottom edge.
    """

    def __init__(self, l, r, t, b):
        super().__init__()
        self.l, self.r, self.t, self.b = l, r, t, b

    def forward(self, x):
        # Use explicit end indices instead of negative ones so zero margins
        # work too (the original x[..., t:-b, l:-r] returns an empty tensor
        # when b or r is 0, because -0 == 0).
        return x[:, :, self.t:x.size(2) - self.b, self.l:x.size(3) - self.r]
class EncoderBlock(nn.Module):
    """One U-Net encoder stage: pad -> strided 5x5 conv -> BatchNorm -> LeakyReLU.

    ``forward`` returns both the raw convolution output (kept as the skip
    connection) and the activated output fed to the next stage.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # Attribute names are fixed: pretrained state_dict keys depend on them.
        self.conv = nn.Conv2d(in_c, out_c, 5, 2)
        self.bn = nn.BatchNorm2d(out_c, 0.001, 0.01)  # eps=1e-3, momentum=0.01
        self.relu = nn.LeakyReLU(0.2)
        self.pad = nn.ConstantPad2d((1, 2, 1, 2), 0)

    def forward(self, x):
        skip = self.conv(self.pad(x))
        activated = self.relu(self.bn(skip))
        return skip, activated
class DecoderBlock(nn.Module):
    """One U-Net decoder stage: transposed 5x5 conv (stride 2) -> crop -> ReLU -> BatchNorm.

    The unusual ordering (BatchNorm applied after the ReLU) is what the
    pretrained weights were trained with, so it is preserved exactly.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # Attribute names are fixed: pretrained state_dict keys depend on them.
        self.tconv = nn.ConvTranspose2d(in_c, out_c, 5, 2)
        self.crop = Crop2d(1, 2, 1, 2)
        self.bn = nn.BatchNorm2d(out_c, 0.001, 0.01)  # eps=1e-3, momentum=0.01
        self.relu = nn.ReLU()

    def forward(self, x):
        up = self.crop(self.tconv(x))
        return self.bn(self.relu(up))
class UNet(nn.Module):
    """Spleeter-style magnitude-mask U-Net.

    The encoder halves the spatial dims at every stage; the decoder mirrors
    it with transposed convolutions plus skip connections down to a single
    channel.  That 1-channel map is projected back to ``in_c`` channels and
    squashed through a sigmoid to form a soft mask, which is multiplied with
    the decoder output.
    """

    def __init__(self, n_layers=6, in_c=2):
        super().__init__()
        # Encoder channels: in_c -> 16 -> 32 -> ... -> 2**(n_layers + 3).
        enc_ch = [in_c] + [2 ** (i + 4) for i in range(n_layers)]
        self.encoder_layers = nn.ModuleList(
            [EncoderBlock(ci, co) for ci, co in zip(enc_ch[:-1], enc_ch[1:])]
        )
        # Decoder channels run in reverse, ending at 1.
        dec_ch = [1] + [2 ** (i + 4) for i in range(n_layers)]
        dec_ch.reverse()
        # Every decoder stage except the first consumes the concatenation of
        # a skip connection with the previous stage's output, doubling its
        # input channel count.
        self.decoder_layers = nn.ModuleList(
            [DecoderBlock(ci if depth == 0 else ci * 2, co)
             for depth, (ci, co) in enumerate(zip(dec_ch[:-1], dec_ch[1:]))]
        )
        self.up_final = nn.Conv2d(1, in_c, 4, dilation=2, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        skips = []
        for enc in self.encoder_layers:
            skip, x = enc(x)
            skips.append(skip)
        for depth, dec in enumerate(self.decoder_layers):
            skip = skips.pop()
            x = dec(skip) if depth == 0 else dec(torch.cat([skip, x], 1))
        mask = self.sigmoid(self.up_final(x))
        # Guard against small spatial mismatches between mask and features.
        n_f = min(mask.size(-2), x.size(-2))
        n_t = min(mask.size(-1), x.size(-1))
        return mask[..., :n_f, :n_t] * x[..., :n_f, :n_t]
class STFTChunkDataset(Dataset):
    """Chops a waveform's STFT magnitude into fixed-size time chunks for
    batched UNet inference, and reassembles predicted masks back into audio.

    The frequency axis is truncated to the first ``self.F`` (=1024) bins and
    time is split into ``self.T`` (=512)-frame chunks, zero-padded at the end.
    """

    def __init__(self, wav, win):
        # win: analysis/synthesis window (length WIN_LENGTH).
        # T: frames per chunk; F: number of frequency bins kept.
        self.win, self.T, self.F = win, 512, 1024
        # Promote mono (frames,) to (1, frames).
        wav = wav.unsqueeze(0) if wav.dim() == 1 else wav
        # Real-valued STFT layout: (channels, freq, frames, 2) with the last
        # dim holding (real, imag); only the first F frequency bins are kept.
        # NOTE(review): return_complex=False is deprecated in recent PyTorch
        # releases — confirm it still works on the pinned torch version.
        stft = torch.stft(wav.squeeze(), WIN_LENGTH, HOP_LENGTH, window=win, return_complex=False, pad_mode="constant")[:, :self.F]
        # L: original number of STFT frames, used to trim padding on decode.
        self.stft, self.L = stft, stft.size(2)
        # Magnitude spectrogram; the epsilon guards sqrt at zero.
        mag = torch.sqrt(stft[:,:,:,0]**2 + stft[:,:,:,1]**2 + 1e-10)
        # (channels, F, frames) -> (1, channels, F, frames).
        mag = mag.unsqueeze(-1).permute(3, 0, 1, 2)
        self.stft_mag = self._batch(mag)

    def __len__(self): return self.stft_mag.size(0)

    def __getitem__(self, i): return self.stft_mag[i]

    def _batch(self, x):
        # Zero-pad the time axis up to a multiple of T, then stack the
        # T-frame chunks along the batch dim: (n_chunks, channels, T, F).
        # (F here is torch.nn.functional — distinct from self.F.)
        new_size = math.ceil(x.size(-1) / self.T) * self.T
        x = F.pad(x, [0, new_size - x.size(-1)])
        return torch.cat(torch.split(x, self.T, -1), 0).transpose(2, 3)

    def apply_mask(self, mask, mask_sum):
        # Wiener-style ratio mask: this stem's squared mask over the sum of
        # all stems' squared masks; epsilon split matches decoder()'s.
        mask = (mask**2 + 1e-10/2) / mask_sum
        # Undo _batch(): re-join chunks along time, trim padding back to L
        # frames, and add a trailing axis to broadcast over (real, imag).
        mask = torch.cat(torch.split(mask.transpose(2, 3), 1, 0), 3).squeeze(0)[:,:,:self.L].unsqueeze(-1)
        # istft needs the full WIN_LENGTH//2+1 bins; zero-pad the truncated ones.
        stft = F.pad(self.stft * mask, (0,0,0,0,0,WIN_LENGTH//2+1-self.stft.size(1))) if self.stft.size(1) < WIN_LENGTH//2+1 else self.stft * mask
        return torch.istft(torch.view_as_complex(stft), WIN_LENGTH, HOP_LENGTH, WIN_LENGTH, self.win,True)

    def decoder(self, masks):
        """Turn the per-stem mask stacks into per-stem waveforms."""
        mask_sum = sum([m**2 for m in masks.values()]) + 1e-10
        return {n: self.apply_mask(m, mask_sum) for n, m in masks.items()}
class Splitter(nn.Module):
    """Spleeter-style separator: one pretrained magnitude-mask UNet per stem.

    ``stem`` selects the 2/4/5-stem configuration; weights are downloaded
    from the Hugging Face Hub at construction time.
    """

    CFG = {2: ['2_other', '2_vocals'], 4: ['4_bass', '4_drums', '4_other', '4_vocals'], 5: ['5_piano', '5_bass', '5_drums', '5_other', '5_vocals']}

    def __init__(self, stem=2):
        super().__init__()
        # Stored as a Parameter (not a plain tensor) so it follows
        # .to(device); requires_grad=False keeps it out of optimisation.
        self.win = nn.Parameter(torch.hann_window(WIN_LENGTH), requires_grad=False)
        self.stems = nn.ModuleDict({n: UNet() for n in self.CFG[stem]})
        for n in self.stems:
            self.stems[n].load_state_dict(load_file(hf_hub_download("shethjenil/spleeter", f"{n}.safetensors")))
        self.eval()

    def forward(self, wav, sr, bs):
        """Separate ``wav`` (recorded at ``sr``) into stems with batch size ``bs``.

        Returns a dict mapping stem name -> waveform tensor at SR.
        """
        dev = next(self.parameters()).device
        wav = torchaudio.functional.resample(wav, sr, SR).to(dev) if sr != SR else wav.to(dev)
        ds = STFTChunkDataset(wav, self.win)
        masks = {n: [] for n in self.stems}
        # Inference only: without no_grad the autograd graph over every
        # chunk x stem forward pass is retained, so memory grows with
        # track length until the run dies.
        with torch.no_grad():
            for batch in DataLoader(ds, bs, shuffle=False):
                for n, net in self.stems.items():
                    masks[n].append(net(batch.to(dev)))
            return ds.decoder({k: torch.cat(v, 0) for k, v in masks.items()})
class DemucsChunkDataset(Dataset):
    """Splits a (channels, frames) mixture into overlapping chunks for
    windowed HDemucs inference and recombines the separated chunks with
    linear cross-fades (mirrors the torchaudio HDemucs separation tutorial).

    Args:
        wav: mixture waveform, shape (channels, frames).
        seg: segment length in seconds.
        ovlp: fractional overlap between consecutive segments.
        sr: sample rate of ``wav`` in Hz.
        n_src: number of sources the separation model outputs.
    """

    def __init__(self, wav, seg=10.0, ovlp=0.1, sr=SR, n_src=4):
        super().__init__()
        # Standardise the mixture; get_output() undoes this.
        self.mean, self.std = wav.mean(), wav.std()
        self.mix = (wav - self.mean) / self.std
        self.c, self.len = self.mix.shape
        self.ovlp_f = int(ovlp * sr)             # overlap length in frames
        self.chunk = int(sr * seg * (1 + ovlp))  # fixed window length in frames
        # Fixed stride of (chunk - overlap) so EVERY consecutive pair of
        # windows shares exactly ovlp_f frames for the cross-fade.  The
        # original advanced by a full chunk after the first step, which left
        # later window boundaries with no overlap: the stateful fade-out /
        # fade-in then dipped the signal toward zero at each seam.
        self.starts, start = [], 0
        while start < self.len - self.ovlp_f:
            self.starts.append(start)
            start += self.chunk - self.ovlp_f
        self.final = torch.zeros(n_src, self.c, self.len, device=wav.device)
        # First chunk gets no fade-in; the last gets no fade-out — handled by
        # mutating the Fade transform as chunks stream through in order.
        self.fade = Fade(0, self.ovlp_f, "linear").to(wav.device)

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, i):
        s = self.starts[i]
        e = min(s + self.chunk, self.len)
        chunk = self.mix[:, s:e]
        if chunk.shape[-1] < self.chunk:
            # Only the final chunk can be short; pad it to the window size.
            chunk = F.pad(chunk, (0, self.chunk - chunk.shape[-1]))
        return {"chunk": chunk, "start": s, "idx": i}

    def decode_and_append(self, out, meta):
        """Cross-fade a batch of separated chunks into the output buffer.

        ``out`` is (batch, n_src, channels, chunk); ``meta`` is the collated
        dict from ``__getitem__``.  Requires DataLoader(shuffle=False) so the
        chunks arrive in order — the Fade transform is stateful.
        """
        for i in range(out.size(0)):
            s, idx = int(meta["start"][i]), int(meta["idx"][i])
            e = min(s + self.chunk, self.len)
            if e >= self.len:
                # Last chunk: disable the fade-out BEFORE applying the fade,
                # otherwise the tail of the track is attenuated to silence.
                self.fade.fade_out_len = 0
            self.final[:, :, s:e] += self.fade(out[i:i + 1])[0, :, :, :e - s]
            if idx == 0:
                # All chunks after the first fade in over the overlap region.
                self.fade.fade_in_len = self.ovlp_f

    def get_output(self):
        """Undo the input standardisation; returns (n_src, channels, frames)."""
        return self.final * self.std + self.mean
class Demucs(nn.Module):
    """Wrapper around torchaudio's pretrained HDemucs 4-stem separator."""

    CFG = {4: ["drums", "bass", "other", "vocals"]}

    def __init__(self, stem=4):
        super().__init__()
        self.model = HDemucs(self.CFG[stem])
        # NOTE(review): torchaudio.utils._download_asset is a private API and
        # may break across torchaudio releases — consider the public
        # HDEMUCS_HIGH_MUSDB_PLUS pipeline or pinning the version.
        self.model.load_state_dict(torch.load(torchaudio.utils._download_asset("models/hdemucs_high_trained.pt", progress=False)))
        self.eval()

    def forward(self, wav, sr, bs):
        """Separate ``wav`` (channels, frames) recorded at ``sr`` into stems.

        Returns a dict mapping source name -> (channels, frames) tensor at SR.
        """
        dev = next(self.parameters()).device
        wav = torchaudio.functional.resample(wav, sr, SR).to(dev) if sr != SR else wav.to(dev)
        ds = DemucsChunkDataset(wav)
        # Inference only: no_grad stops autograd from retaining the graph
        # over the whole track, which would exhaust memory on long inputs.
        with torch.no_grad():
            for b in tqdm(DataLoader(ds, bs), desc="Separating"):
                ds.decode_and_append(self.model(b["chunk"]), b)
        return dict(zip(self.model.sources, ds.get_output()))
def separate_audio_spleeter(path, bs, stem, progress=gr.Progress(True)):
    """Gradio handler: split the audio file at ``path`` into Spleeter stems.

    Saves one mp3 per stem and returns the list of written paths.
    ``bs`` arrives from gr.Number as a float, so it is cast to int before
    being used as a DataLoader batch size (which requires an integer).
    """
    wav, sr = torchaudio.load(path)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    res = Splitter(stem).to(dev)(wav, sr, int(bs))
    for name in res:
        torchaudio.save(f"{name}.mp3", res[name].cpu(), SR)
    return [f"{name}.mp3" for name in res]
def separate_audio_demucs(path, bs, stem, progress=gr.Progress(True)):
    """Gradio handler: split the audio file at ``path`` into Demucs stems.

    Saves one mp3 per stem and returns the list of written paths.
    ``bs`` arrives from gr.Number as a float, so it is cast to int before
    being used as a DataLoader batch size (which requires an integer).
    """
    wav, sr = torchaudio.load(path)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    res = Demucs(stem).to(dev)(wav, sr, int(bs))
    for name in res:
        torchaudio.save(f"{name}.mp3", res[name].cpu(), SR)
    return [f"{name}.mp3" for name in res]
# Two-tab UI: one tab per separation backend.
spleeter_ui = gr.Interface(
    separate_audio_spleeter,
    [gr.Audio(type="filepath"), gr.Number(16), gr.Radio([2, 4, 5], label="STEM")],
    gr.Files(),
)
demucs_ui = gr.Interface(
    separate_audio_demucs,
    [gr.Audio(type="filepath"), gr.Number(16), gr.Radio([4], label="STEM")],
    gr.Files(),
)
gr.TabbedInterface([spleeter_ui, demucs_ui], ['spleeter', 'demucs']).launch()