import math

import gradio as gr
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchaudio.models import HDemucs
from torchaudio.transforms import Fade
from tqdm import tqdm

# STFT constants shared by the Spleeter-style models
WIN_LENGTH, HOP_LENGTH, SR = 4096, 1024, 44100


class Crop2d(nn.Module):
    """Crop a (batch, channel, H, W) tensor by fixed margins on each side."""

    def __init__(self, l, r, t, b):
        super().__init__()
        self.l, self.r, self.t, self.b = l, r, t, b

    def forward(self, x):
        return x[:, :, self.t:-self.b, self.l:-self.r]


class EncoderBlock(nn.Module):
    """Strided 5x5 conv; returns the raw conv output (kept for the skip
    connection) and the activated output (fed to the next encoder layer)."""

    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = nn.Conv2d(in_c, out_c, 5, 2)
        self.bn = nn.BatchNorm2d(out_c, 0.001, 0.01)
        self.relu = nn.LeakyReLU(0.2)
        self.pad = nn.ConstantPad2d((1, 2, 1, 2), 0)

    def forward(self, x):
        down = self.conv(self.pad(x))
        return down, self.relu(self.bn(down))


class DecoderBlock(nn.Module):
    """Transposed 5x5 conv upsampling, cropped back to the encoder's size."""

    def __init__(self, in_c, out_c):
        super().__init__()
        self.tconv = nn.ConvTranspose2d(in_c, out_c, 5, 2)
        self.crop = Crop2d(1, 2, 1, 2)
        self.bn = nn.BatchNorm2d(out_c, 0.001, 0.01)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.bn(self.relu(self.crop(self.tconv(x))))


class UNet(nn.Module):
    """Spleeter-style U-Net that predicts a soft magnitude mask per source."""

    def __init__(self, n_layers=6, in_c=2):
        super().__init__()
        down = [in_c] + [2 ** (i + 4) for i in range(n_layers)]
        self.encoder_layers = nn.ModuleList(
            [EncoderBlock(i, o) for i, o in zip(down[:-1], down[1:])]
        )
        up = [1] + [2 ** (i + 4) for i in range(n_layers)]
        up.reverse()
        # Every decoder layer except the first sees its input concatenated
        # with the matching encoder skip connection, hence the doubled width.
        self.decoder_layers = nn.ModuleList(
            [DecoderBlock(i if idx == 0 else i * 2, o)
             for idx, (i, o) in enumerate(zip(up[:-1], up[1:]))]
        )
        self.up_final = nn.Conv2d(1, in_c, 4, dilation=2, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        enc_out = []
        for down in self.encoder_layers:
            conv, x = down(x)
            enc_out.append(conv)
        for i, up in enumerate(self.decoder_layers):
            x = up(enc_out.pop()) if i == 0 else up(torch.cat([enc_out.pop(), x], 1))
        mask = self.sigmoid(self.up_final(x))
        # Crop both tensors to a common size; the two-channel mask broadcasts
        # against the single-channel decoder output.
        min_f, min_t = min(mask.size(-2), x.size(-2)), min(mask.size(-1), x.size(-1))
        return mask[..., :min_f, :min_t] * x[..., :min_f, :min_t]


class STFTChunkDataset(Dataset):
    """Slices a spectrogram into fixed-size (512 frames x 1024 bins) chunks
    for batched U-Net inference, then reassembles the ratio-masked result."""

    def __init__(self, wav, win):
        self.win, self.T, self.F = win, 512, 1024
        wav = wav.unsqueeze(0) if wav.dim() == 1 else wav
        # Note: the models expect stereo input, so squeeze() keeps (2, L) intact.
        stft = torch.stft(
            wav.squeeze(), WIN_LENGTH, HOP_LENGTH, window=win,
            return_complex=False, pad_mode="constant",
        )[:, :self.F]
        self.stft, self.L = stft, stft.size(2)
        mag = torch.sqrt(stft[:, :, :, 0] ** 2 + stft[:, :, :, 1] ** 2 + 1e-10)
        mag = mag.unsqueeze(-1).permute(3, 0, 1, 2)
        self.stft_mag = self._batch(mag)

    def __len__(self):
        return self.stft_mag.size(0)

    def __getitem__(self, i):
        return self.stft_mag[i]

    def _batch(self, x):
        # Pad the time axis to a multiple of T, then stack T-frame chunks.
        new_size = math.ceil(x.size(-1) / self.T) * self.T
        x = F.pad(x, [0, new_size - x.size(-1)])
        return torch.cat(torch.split(x, self.T, -1), 0).transpose(2, 3)

    def apply_mask(self, mask, mask_sum):
        # Ratio mask: this source's squared mask over the sum across sources.
        mask = (mask ** 2 + 1e-10 / 2) / mask_sum
        mask = torch.cat(torch.split(mask.transpose(2, 3), 1, 0), 3).squeeze(0)[:, :, :self.L].unsqueeze(-1)
        stft = self.stft * mask
        # Zero-pad the discarded upper bins back to WIN_LENGTH//2 + 1 for the iSTFT.
        if self.stft.size(1) < WIN_LENGTH // 2 + 1:
            stft = F.pad(stft, (0, 0, 0, 0, 0, WIN_LENGTH // 2 + 1 - self.stft.size(1)))
        return torch.istft(torch.view_as_complex(stft), WIN_LENGTH, HOP_LENGTH, WIN_LENGTH, self.win, True)

    def decoder(self, masks):
        mask_sum = sum(m ** 2 for m in masks.values()) + 1e-10
        return {n: self.apply_mask(m, mask_sum) for n, m in masks.items()}
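# For reference, the Wiener-like combination in decoder()/apply_mask(): with
# per-source masks m_k, each source gets S_k = X * m_k^2 / (sum_j m_j^2 + eps),
# so the estimates sum back (approximately) to the mixture X. A minimal,
# standalone sketch of that invariant (names and shapes are illustrative only):
#
#   masks = {"vocals": torch.rand(1, 2, 512, 1024), "other": torch.rand(1, 2, 512, 1024)}
#   denom = sum(m ** 2 for m in masks.values()) + 1e-10
#   ratio = {k: m ** 2 / denom for k, m in masks.items()}
#   assert torch.allclose(sum(ratio.values()), torch.ones(1), atol=1e-4)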
class Splitter(nn.Module):
    """Spleeter-style separator: one U-Net per stem, weights fetched from the Hub."""

    CFG = {
        2: ['2_other', '2_vocals'],
        4: ['4_bass', '4_drums', '4_other', '4_vocals'],
        5: ['5_piano', '5_bass', '5_drums', '5_other', '5_vocals'],
    }

    def __init__(self, stem=2):
        super().__init__()
        self.win = nn.Parameter(torch.hann_window(WIN_LENGTH), requires_grad=False)
        self.stems = nn.ModuleDict({n: UNet() for n in self.CFG[stem]})
        for n in self.stems:
            self.stems[n].load_state_dict(load_file(hf_hub_download("shethjenil/spleeter", f"{n}.safetensors")))
        self.eval()

    @torch.inference_mode()
    def forward(self, wav, sr, bs):
        dev = next(self.parameters()).device
        wav = torchaudio.functional.resample(wav, sr, SR).to(dev) if sr != SR else wav.to(dev)
        ds = STFTChunkDataset(wav, self.win)
        masks = {n: [] for n in self.stems}
        for batch in DataLoader(ds, bs, shuffle=False):
            for n, net in self.stems.items():
                masks[n].append(net(batch.to(dev)))
        return ds.decoder({k: torch.cat(v, 0) for k, v in masks.items()})


class DemucsChunkDataset(Dataset):
    """Splits a waveform into overlapping segments for HDemucs and rebuilds the
    full-length output by faded overlap-add (mirroring the torchaudio tutorial)."""

    def __init__(self, wav, seg=10.0, ovlp=0.1, sr=SR, n_src=4):
        super().__init__()
        # Normalize the mixture; undone in get_output().
        self.mean, self.std = wav.mean(), wav.std()
        self.mix = (wav - self.mean) / self.std
        self.c, self.len = self.mix.shape
        self.ovlp_f = int(ovlp * sr)
        self.chunk = int(sr * seg * (1 + ovlp))
        self.starts, start, idx = [], 0, 0
        while start < self.len - self.ovlp_f:
            self.starts.append(start)
            # The first hop is shortened so consecutive chunks overlap by ovlp_f.
            start += self.chunk - self.ovlp_f if idx == 0 else self.chunk
            idx += 1
        self.final = torch.zeros(n_src, self.c, self.len, device=wav.device)
        self.fade = Fade(0, self.ovlp_f, "linear").to(wav.device)

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, i):
        s, e = self.starts[i], min(self.starts[i] + self.chunk, self.len)
        chunk = self.mix[:, s:e]
        if chunk.shape[-1] < self.chunk:
            chunk = F.pad(chunk, (0, self.chunk - chunk.shape[-1]))
        return {"chunk": chunk, "start": s, "idx": i}

    def decode_and_append(self, out, meta):
        # Must be called in dataset order (shuffle=False): the fade lengths are
        # stateful, with no fade-in on the first chunk and no fade-out on the last.
        for i in range(out.size(0)):
            s, idx = meta["start"][i], meta["idx"][i]
            e = min(meta["start"][i] + self.chunk, self.len)
            self.final[:, :, s:e] += self.fade(out[i:i + 1])[0, :, :, :e - s]
            if idx == 0:
                self.fade.fade_in_len = self.ovlp_f
            if e >= self.len:
                self.fade.fade_out_len = 0

    def get_output(self):
        return self.final * self.std + self.mean


class Demucs(nn.Module):
    """HDemucs separator using torchaudio's pretrained high-quality checkpoint."""

    CFG = {4: ["drums", "bass", "other", "vocals"]}

    def __init__(self, stem=4):
        super().__init__()
        self.model = HDemucs(self.CFG[stem])
        self.model.load_state_dict(
            torch.load(torchaudio.utils._download_asset("models/hdemucs_high_trained.pt", progress=False))
        )
        self.eval()

    @torch.inference_mode()
    def forward(self, wav, sr, bs):
        dev = next(self.parameters()).device
        wav = torchaudio.functional.resample(wav, sr, SR).to(dev) if sr != SR else wav.to(dev)
        ds = DemucsChunkDataset(wav)
        for b in tqdm(DataLoader(ds, bs), desc="Separating"):
            ds.decode_and_append(self.model(b["chunk"]), b)
        return dict(zip(self.model.sources, ds.get_output()))
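# Overlap-add in miniature: adjacent chunks share ovlp_f samples, the earlier
# chunk fades out while the next fades in over that region, and the two linear
# ramps sum to 1, so the crossfaded region keeps unit gain. A standalone sketch
# of that invariant (toy numbers, not used by the app):
#
#   ovlp = 4
#   ramp_out = torch.linspace(1, 0, ovlp)  # tail of chunk k
#   ramp_in = torch.linspace(0, 1, ovlp)   # head of chunk k + 1
#   assert torch.allclose(ramp_out + ramp_in, torch.ones(ovlp))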
def separate_audio_spleeter(path, bs, stem, progress=gr.Progress(track_tqdm=True)):
    wav, sr = torchaudio.load(path)
    model = Splitter(stem).to("cuda" if torch.cuda.is_available() else "cpu")
    res = model(wav, sr, int(bs))  # gr.Number yields a float; DataLoader needs an int
    for name in res:
        torchaudio.save(f"{name}.mp3", res[name].cpu(), SR)
    return [f"{name}.mp3" for name in res]


def separate_audio_demucs(path, bs, stem, progress=gr.Progress(track_tqdm=True)):
    wav, sr = torchaudio.load(path)
    model = Demucs(stem).to("cuda" if torch.cuda.is_available() else "cpu")
    res = model(wav, sr, int(bs))
    for name in res:
        torchaudio.save(f"{name}.mp3", res[name].cpu(), SR)
    return [f"{name}.mp3" for name in res]


gr.TabbedInterface(
    [
        gr.Interface(
            separate_audio_spleeter,
            [gr.Audio(type="filepath"), gr.Number(16, label="Batch size"), gr.Radio([2, 4, 5], label="STEM")],
            gr.Files(),
        ),
        gr.Interface(
            separate_audio_demucs,
            [gr.Audio(type="filepath"), gr.Number(16, label="Batch size"), gr.Radio([4], label="STEM")],
            gr.Files(),
        ),
    ],
    ["spleeter", "demucs"],
).launch()
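# launch() blocks, so everything below is a usage sketch only. Programmatic use
# without the UI (the filename "mix.wav" and the batch size of 4 are illustrative):
#
#   wav, sr = torchaudio.load("mix.wav")
#   stems = Demucs(4)(wav, sr, 4)  # {"drums": ..., "bass": ..., "other": ..., "vocals": ...}
#   torchaudio.save("vocals.wav", stems["vocals"].cpu(), SR)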